1 /* 2 * linux/net/sunrpc/auth.c 3 * 4 * Generic RPC client authentication API. 5 * 6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> 7 */ 8 9 #include <linux/types.h> 10 #include <linux/sched.h> 11 #include <linux/module.h> 12 #include <linux/slab.h> 13 #include <linux/errno.h> 14 #include <linux/sunrpc/clnt.h> 15 #include <linux/spinlock.h> 16 17 #ifdef RPC_DEBUG 18 # define RPCDBG_FACILITY RPCDBG_AUTH 19 #endif 20 21 static struct rpc_authops * auth_flavors[RPC_AUTH_MAXFLAVOR] = { 22 &authnull_ops, /* AUTH_NULL */ 23 &authunix_ops, /* AUTH_UNIX */ 24 NULL, /* others can be loadable modules */ 25 }; 26 27 static u32 28 pseudoflavor_to_flavor(u32 flavor) { 29 if (flavor >= RPC_AUTH_MAXFLAVOR) 30 return RPC_AUTH_GSS; 31 return flavor; 32 } 33 34 int 35 rpcauth_register(struct rpc_authops *ops) 36 { 37 rpc_authflavor_t flavor; 38 39 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 40 return -EINVAL; 41 if (auth_flavors[flavor] != NULL) 42 return -EPERM; /* what else? */ 43 auth_flavors[flavor] = ops; 44 return 0; 45 } 46 47 int 48 rpcauth_unregister(struct rpc_authops *ops) 49 { 50 rpc_authflavor_t flavor; 51 52 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 53 return -EINVAL; 54 if (auth_flavors[flavor] != ops) 55 return -EPERM; /* what else? */ 56 auth_flavors[flavor] = NULL; 57 return 0; 58 } 59 60 struct rpc_auth * 61 rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) 62 { 63 struct rpc_auth *auth; 64 struct rpc_authops *ops; 65 u32 flavor = pseudoflavor_to_flavor(pseudoflavor); 66 67 auth = ERR_PTR(-EINVAL); 68 if (flavor >= RPC_AUTH_MAXFLAVOR) 69 goto out; 70 71 /* FIXME - auth_flavors[] really needs an rw lock, 72 * and module refcounting. */ 73 #ifdef CONFIG_KMOD 74 if ((ops = auth_flavors[flavor]) == NULL) 75 request_module("rpc-auth-%u", flavor); 76 #endif 77 if ((ops = auth_flavors[flavor]) == NULL) 78 goto out; 79 auth = ops->create(clnt, pseudoflavor); 80 if (IS_ERR(auth)) 81 return auth; 82 if (clnt->cl_auth) 83 rpcauth_destroy(clnt->cl_auth); 84 clnt->cl_auth = auth; 85 86 out: 87 return auth; 88 } 89 90 void 91 rpcauth_destroy(struct rpc_auth *auth) 92 { 93 if (!atomic_dec_and_test(&auth->au_count)) 94 return; 95 auth->au_ops->destroy(auth); 96 } 97 98 static DEFINE_SPINLOCK(rpc_credcache_lock); 99 100 /* 101 * Initialize RPC credential cache 102 */ 103 int 104 rpcauth_init_credcache(struct rpc_auth *auth, unsigned long expire) 105 { 106 struct rpc_cred_cache *new; 107 int i; 108 109 new = kmalloc(sizeof(*new), GFP_KERNEL); 110 if (!new) 111 return -ENOMEM; 112 for (i = 0; i < RPC_CREDCACHE_NR; i++) 113 INIT_HLIST_HEAD(&new->hashtable[i]); 114 new->expire = expire; 115 new->nextgc = jiffies + (expire >> 1); 116 auth->au_credcache = new; 117 return 0; 118 } 119 120 /* 121 * Destroy a list of credentials 122 */ 123 static inline 124 void rpcauth_destroy_credlist(struct hlist_head *head) 125 { 126 struct rpc_cred *cred; 127 128 while (!hlist_empty(head)) { 129 cred = hlist_entry(head->first, struct rpc_cred, cr_hash); 130 hlist_del_init(&cred->cr_hash); 131 put_rpccred(cred); 132 } 133 } 134 135 /* 136 * Clear the RPC credential cache, and delete those credentials 137 * that are not referenced. 138 */ 139 void 140 rpcauth_free_credcache(struct rpc_auth *auth) 141 { 142 struct rpc_cred_cache *cache = auth->au_credcache; 143 HLIST_HEAD(free); 144 struct hlist_node *pos, *next; 145 struct rpc_cred *cred; 146 int i; 147 148 spin_lock(&rpc_credcache_lock); 149 for (i = 0; i < RPC_CREDCACHE_NR; i++) { 150 hlist_for_each_safe(pos, next, &cache->hashtable[i]) { 151 cred = hlist_entry(pos, struct rpc_cred, cr_hash); 152 __hlist_del(&cred->cr_hash); 153 hlist_add_head(&cred->cr_hash, &free); 154 } 155 } 156 spin_unlock(&rpc_credcache_lock); 157 rpcauth_destroy_credlist(&free); 158 } 159 160 static void 161 rpcauth_prune_expired(struct rpc_auth *auth, struct rpc_cred *cred, struct hlist_head *free) 162 { 163 if (atomic_read(&cred->cr_count) != 1) 164 return; 165 if (time_after(jiffies, cred->cr_expire + auth->au_credcache->expire)) 166 cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE; 167 if (!(cred->cr_flags & RPCAUTH_CRED_UPTODATE)) { 168 __hlist_del(&cred->cr_hash); 169 hlist_add_head(&cred->cr_hash, free); 170 } 171 } 172 173 /* 174 * Remove stale credentials. Avoid sleeping inside the loop. 175 */ 176 static void 177 rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free) 178 { 179 struct rpc_cred_cache *cache = auth->au_credcache; 180 struct hlist_node *pos, *next; 181 struct rpc_cred *cred; 182 int i; 183 184 dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth); 185 for (i = 0; i < RPC_CREDCACHE_NR; i++) { 186 hlist_for_each_safe(pos, next, &cache->hashtable[i]) { 187 cred = hlist_entry(pos, struct rpc_cred, cr_hash); 188 rpcauth_prune_expired(auth, cred, free); 189 } 190 } 191 cache->nextgc = jiffies + cache->expire; 192 } 193 194 /* 195 * Look up a process' credentials in the authentication cache 196 */ 197 struct rpc_cred * 198 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, 199 int flags) 200 { 201 struct rpc_cred_cache *cache = auth->au_credcache; 202 HLIST_HEAD(free); 203 struct hlist_node *pos, *next; 204 struct rpc_cred *new = NULL, 205 *cred = NULL; 206 int nr = 0; 207 208 if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS)) 209 nr = acred->uid & RPC_CREDCACHE_MASK; 210 retry: 211 spin_lock(&rpc_credcache_lock); 212 if (time_before(cache->nextgc, jiffies)) 213 rpcauth_gc_credcache(auth, &free); 214 hlist_for_each_safe(pos, next, &cache->hashtable[nr]) { 215 struct rpc_cred *entry; 216 entry = hlist_entry(pos, struct rpc_cred, cr_hash); 217 if (entry->cr_ops->crmatch(acred, entry, flags)) { 218 hlist_del(&entry->cr_hash); 219 cred = entry; 220 break; 221 } 222 rpcauth_prune_expired(auth, entry, &free); 223 } 224 if (new) { 225 if (cred) 226 hlist_add_head(&new->cr_hash, &free); 227 else 228 cred = new; 229 } 230 if (cred) { 231 hlist_add_head(&cred->cr_hash, &cache->hashtable[nr]); 232 get_rpccred(cred); 233 } 234 spin_unlock(&rpc_credcache_lock); 235 236 rpcauth_destroy_credlist(&free); 237 238 if (!cred) { 239 new = auth->au_ops->crcreate(auth, acred, flags); 240 if (!IS_ERR(new)) { 241 #ifdef RPC_DEBUG 242 new->cr_magic = RPCAUTH_CRED_MAGIC; 243 #endif 244 goto retry; 245 } else 246 cred = new; 247 } else if ((cred->cr_flags & RPCAUTH_CRED_NEW) 248 && cred->cr_ops->cr_init != NULL 249 && !(flags & RPCAUTH_LOOKUP_NEW)) { 250 int res = cred->cr_ops->cr_init(auth, cred); 251 if (res < 0) { 252 put_rpccred(cred); 253 cred = ERR_PTR(res); 254 } 255 } 256 257 return (struct rpc_cred *) cred; 258 } 259 260 struct rpc_cred * 261 rpcauth_lookupcred(struct rpc_auth *auth, int flags) 262 { 263 struct auth_cred acred = { 264 .uid = current->fsuid, 265 .gid = current->fsgid, 266 .group_info = current->group_info, 267 }; 268 struct rpc_cred *ret; 269 270 dprintk("RPC: looking up %s cred\n", 271 auth->au_ops->au_name); 272 get_group_info(acred.group_info); 273 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 274 put_group_info(acred.group_info); 275 return ret; 276 } 277 278 struct rpc_cred * 279 rpcauth_bindcred(struct rpc_task *task) 280 { 281 struct rpc_auth *auth = task->tk_auth; 282 struct auth_cred acred = { 283 .uid = current->fsuid, 284 .gid = current->fsgid, 285 .group_info = current->group_info, 286 }; 287 struct rpc_cred *ret; 288 int flags = 0; 289 290 dprintk("RPC: %5u looking up %s cred\n", 291 task->tk_pid, task->tk_auth->au_ops->au_name); 292 get_group_info(acred.group_info); 293 if (task->tk_flags & RPC_TASK_ROOTCREDS) 294 flags |= RPCAUTH_LOOKUP_ROOTCREDS; 295 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 296 if (!IS_ERR(ret)) 297 task->tk_msg.rpc_cred = ret; 298 else 299 task->tk_status = PTR_ERR(ret); 300 put_group_info(acred.group_info); 301 return ret; 302 } 303 304 void 305 rpcauth_holdcred(struct rpc_task *task) 306 { 307 dprintk("RPC: %5u holding %s cred %p\n", 308 task->tk_pid, task->tk_auth->au_ops->au_name, 309 task->tk_msg.rpc_cred); 310 if (task->tk_msg.rpc_cred) 311 get_rpccred(task->tk_msg.rpc_cred); 312 } 313 314 void 315 put_rpccred(struct rpc_cred *cred) 316 { 317 cred->cr_expire = jiffies; 318 if (!atomic_dec_and_test(&cred->cr_count)) 319 return; 320 cred->cr_ops->crdestroy(cred); 321 } 322 323 void 324 rpcauth_unbindcred(struct rpc_task *task) 325 { 326 struct rpc_cred *cred = task->tk_msg.rpc_cred; 327 328 dprintk("RPC: %5u releasing %s cred %p\n", 329 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 330 331 put_rpccred(cred); 332 task->tk_msg.rpc_cred = NULL; 333 } 334 335 __be32 * 336 rpcauth_marshcred(struct rpc_task *task, __be32 *p) 337 { 338 struct rpc_cred *cred = task->tk_msg.rpc_cred; 339 340 dprintk("RPC: %5u marshaling %s cred %p\n", 341 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 342 343 return cred->cr_ops->crmarshal(task, p); 344 } 345 346 __be32 * 347 rpcauth_checkverf(struct rpc_task *task, __be32 *p) 348 { 349 struct rpc_cred *cred = task->tk_msg.rpc_cred; 350 351 dprintk("RPC: %5u validating %s cred %p\n", 352 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 353 354 return cred->cr_ops->crvalidate(task, p); 355 } 356 357 int 358 rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, 359 __be32 *data, void *obj) 360 { 361 struct rpc_cred *cred = task->tk_msg.rpc_cred; 362 363 dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", 364 task->tk_pid, cred->cr_ops->cr_name, cred); 365 if (cred->cr_ops->crwrap_req) 366 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); 367 /* By default, we encode the arguments normally. */ 368 return encode(rqstp, data, obj); 369 } 370 371 int 372 rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, 373 __be32 *data, void *obj) 374 { 375 struct rpc_cred *cred = task->tk_msg.rpc_cred; 376 377 dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", 378 task->tk_pid, cred->cr_ops->cr_name, cred); 379 if (cred->cr_ops->crunwrap_resp) 380 return cred->cr_ops->crunwrap_resp(task, decode, rqstp, 381 data, obj); 382 /* By default, we decode the arguments normally. */ 383 return decode(rqstp, data, obj); 384 } 385 386 int 387 rpcauth_refreshcred(struct rpc_task *task) 388 { 389 struct rpc_cred *cred = task->tk_msg.rpc_cred; 390 int err; 391 392 dprintk("RPC: %5u refreshing %s cred %p\n", 393 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 394 395 err = cred->cr_ops->crrefresh(task); 396 if (err < 0) 397 task->tk_status = err; 398 return err; 399 } 400 401 void 402 rpcauth_invalcred(struct rpc_task *task) 403 { 404 dprintk("RPC: %5u invalidating %s cred %p\n", 405 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred); 406 spin_lock(&rpc_credcache_lock); 407 if (task->tk_msg.rpc_cred) 408 task->tk_msg.rpc_cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE; 409 spin_unlock(&rpc_credcache_lock); 410 } 411 412 int 413 rpcauth_uptodatecred(struct rpc_task *task) 414 { 415 return !(task->tk_msg.rpc_cred) || 416 (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_UPTODATE); 417 } 418