1 /* 2 * linux/net/sunrpc/auth.c 3 * 4 * Generic RPC client authentication API. 5 * 6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> 7 */ 8 9 #include <linux/types.h> 10 #include <linux/sched.h> 11 #include <linux/cred.h> 12 #include <linux/module.h> 13 #include <linux/slab.h> 14 #include <linux/errno.h> 15 #include <linux/hash.h> 16 #include <linux/sunrpc/clnt.h> 17 #include <linux/sunrpc/gss_api.h> 18 #include <linux/spinlock.h> 19 20 #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) 21 # define RPCDBG_FACILITY RPCDBG_AUTH 22 #endif 23 24 #define RPC_CREDCACHE_DEFAULT_HASHBITS (4) 25 struct rpc_cred_cache { 26 struct hlist_head *hashtable; 27 unsigned int hashbits; 28 spinlock_t lock; 29 }; 30 31 static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; 32 33 static DEFINE_SPINLOCK(rpc_authflavor_lock); 34 static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { 35 &authnull_ops, /* AUTH_NULL */ 36 &authunix_ops, /* AUTH_UNIX */ 37 NULL, /* others can be loadable modules */ 38 }; 39 40 static LIST_HEAD(cred_unused); 41 static unsigned long number_cred_unused; 42 43 #define MAX_HASHTABLE_BITS (14) 44 static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) 45 { 46 unsigned long num; 47 unsigned int nbits; 48 int ret; 49 50 if (!val) 51 goto out_inval; 52 ret = kstrtoul(val, 0, &num); 53 if (ret == -EINVAL) 54 goto out_inval; 55 nbits = fls(num - 1); 56 if (nbits > MAX_HASHTABLE_BITS || nbits < 2) 57 goto out_inval; 58 *(unsigned int *)kp->arg = nbits; 59 return 0; 60 out_inval: 61 return -EINVAL; 62 } 63 64 static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) 65 { 66 unsigned int nbits; 67 68 nbits = *(unsigned int *)kp->arg; 69 return sprintf(buffer, "%u", 1U << nbits); 70 } 71 72 #define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); 73 74 static const struct kernel_param_ops param_ops_hashtbl_sz = { 75 .set = param_set_hashtbl_sz, 76 .get = param_get_hashtbl_sz, 77 }; 78 79 module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); 80 MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); 81 82 static unsigned long auth_max_cred_cachesize = ULONG_MAX; 83 module_param(auth_max_cred_cachesize, ulong, 0644); 84 MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size"); 85 86 static u32 87 pseudoflavor_to_flavor(u32 flavor) { 88 if (flavor > RPC_AUTH_MAXFLAVOR) 89 return RPC_AUTH_GSS; 90 return flavor; 91 } 92 93 int 94 rpcauth_register(const struct rpc_authops *ops) 95 { 96 rpc_authflavor_t flavor; 97 int ret = -EPERM; 98 99 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 100 return -EINVAL; 101 spin_lock(&rpc_authflavor_lock); 102 if (auth_flavors[flavor] == NULL) { 103 auth_flavors[flavor] = ops; 104 ret = 0; 105 } 106 spin_unlock(&rpc_authflavor_lock); 107 return ret; 108 } 109 EXPORT_SYMBOL_GPL(rpcauth_register); 110 111 int 112 rpcauth_unregister(const struct rpc_authops *ops) 113 { 114 rpc_authflavor_t flavor; 115 int ret = -EPERM; 116 117 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 118 return -EINVAL; 119 spin_lock(&rpc_authflavor_lock); 120 if (auth_flavors[flavor] == ops) { 121 auth_flavors[flavor] = NULL; 122 ret = 0; 123 } 124 spin_unlock(&rpc_authflavor_lock); 125 return ret; 126 } 127 EXPORT_SYMBOL_GPL(rpcauth_unregister); 128 129 /** 130 * rpcauth_get_pseudoflavor - check if security flavor is supported 131 * @flavor: a security flavor 132 * @info: a GSS mech OID, quality of protection, and service value 133 * 134 * Verifies that an appropriate kernel module is available or already loaded. 135 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is 136 * not supported locally. 137 */ 138 rpc_authflavor_t 139 rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) 140 { 141 const struct rpc_authops *ops; 142 rpc_authflavor_t pseudoflavor; 143 144 ops = auth_flavors[flavor]; 145 if (ops == NULL) 146 request_module("rpc-auth-%u", flavor); 147 spin_lock(&rpc_authflavor_lock); 148 ops = auth_flavors[flavor]; 149 if (ops == NULL || !try_module_get(ops->owner)) { 150 spin_unlock(&rpc_authflavor_lock); 151 return RPC_AUTH_MAXFLAVOR; 152 } 153 spin_unlock(&rpc_authflavor_lock); 154 155 pseudoflavor = flavor; 156 if (ops->info2flavor != NULL) 157 pseudoflavor = ops->info2flavor(info); 158 159 module_put(ops->owner); 160 return pseudoflavor; 161 } 162 EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); 163 164 /** 165 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor 166 * @pseudoflavor: GSS pseudoflavor to match 167 * @info: rpcsec_gss_info structure to fill in 168 * 169 * Returns zero and fills in "info" if pseudoflavor matches a 170 * supported mechanism. 171 */ 172 int 173 rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) 174 { 175 rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); 176 const struct rpc_authops *ops; 177 int result; 178 179 if (flavor >= RPC_AUTH_MAXFLAVOR) 180 return -EINVAL; 181 182 ops = auth_flavors[flavor]; 183 if (ops == NULL) 184 request_module("rpc-auth-%u", flavor); 185 spin_lock(&rpc_authflavor_lock); 186 ops = auth_flavors[flavor]; 187 if (ops == NULL || !try_module_get(ops->owner)) { 188 spin_unlock(&rpc_authflavor_lock); 189 return -ENOENT; 190 } 191 spin_unlock(&rpc_authflavor_lock); 192 193 result = -ENOENT; 194 if (ops->flavor2info != NULL) 195 result = ops->flavor2info(pseudoflavor, info); 196 197 module_put(ops->owner); 198 return result; 199 } 200 EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); 201 202 /** 203 * rpcauth_list_flavors - discover registered flavors and pseudoflavors 204 * @array: array to fill in 205 * @size: size of "array" 206 * 207 * Returns the number of array items filled in, or a negative errno. 208 * 209 * The returned array is not sorted by any policy. Callers should not 210 * rely on the order of the items in the returned array. 211 */ 212 int 213 rpcauth_list_flavors(rpc_authflavor_t *array, int size) 214 { 215 rpc_authflavor_t flavor; 216 int result = 0; 217 218 spin_lock(&rpc_authflavor_lock); 219 for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { 220 const struct rpc_authops *ops = auth_flavors[flavor]; 221 rpc_authflavor_t pseudos[4]; 222 int i, len; 223 224 if (result >= size) { 225 result = -ENOMEM; 226 break; 227 } 228 229 if (ops == NULL) 230 continue; 231 if (ops->list_pseudoflavors == NULL) { 232 array[result++] = ops->au_flavor; 233 continue; 234 } 235 len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); 236 if (len < 0) { 237 result = len; 238 break; 239 } 240 for (i = 0; i < len; i++) { 241 if (result >= size) { 242 result = -ENOMEM; 243 break; 244 } 245 array[result++] = pseudos[i]; 246 } 247 } 248 spin_unlock(&rpc_authflavor_lock); 249 250 dprintk("RPC: %s returns %d\n", __func__, result); 251 return result; 252 } 253 EXPORT_SYMBOL_GPL(rpcauth_list_flavors); 254 255 struct rpc_auth * 256 rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt) 257 { 258 struct rpc_auth *auth; 259 const struct rpc_authops *ops; 260 u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); 261 262 auth = ERR_PTR(-EINVAL); 263 if (flavor >= RPC_AUTH_MAXFLAVOR) 264 goto out; 265 266 if ((ops = auth_flavors[flavor]) == NULL) 267 request_module("rpc-auth-%u", flavor); 268 spin_lock(&rpc_authflavor_lock); 269 ops = auth_flavors[flavor]; 270 if (ops == NULL || !try_module_get(ops->owner)) { 271 spin_unlock(&rpc_authflavor_lock); 272 goto out; 273 } 274 spin_unlock(&rpc_authflavor_lock); 275 auth = ops->create(args, clnt); 276 module_put(ops->owner); 277 if (IS_ERR(auth)) 278 return auth; 279 if (clnt->cl_auth) 280 rpcauth_release(clnt->cl_auth); 281 clnt->cl_auth = auth; 282 283 out: 284 return auth; 285 } 286 EXPORT_SYMBOL_GPL(rpcauth_create); 287 288 void 289 rpcauth_release(struct rpc_auth *auth) 290 { 291 if (!atomic_dec_and_test(&auth->au_count)) 292 return; 293 auth->au_ops->destroy(auth); 294 } 295 296 static DEFINE_SPINLOCK(rpc_credcache_lock); 297 298 static void 299 rpcauth_unhash_cred_locked(struct rpc_cred *cred) 300 { 301 hlist_del_rcu(&cred->cr_hash); 302 smp_mb__before_atomic(); 303 clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); 304 } 305 306 static int 307 rpcauth_unhash_cred(struct rpc_cred *cred) 308 { 309 spinlock_t *cache_lock; 310 int ret; 311 312 cache_lock = &cred->cr_auth->au_credcache->lock; 313 spin_lock(cache_lock); 314 ret = atomic_read(&cred->cr_count) == 0; 315 if (ret) 316 rpcauth_unhash_cred_locked(cred); 317 spin_unlock(cache_lock); 318 return ret; 319 } 320 321 /* 322 * Initialize RPC credential cache 323 */ 324 int 325 rpcauth_init_credcache(struct rpc_auth *auth) 326 { 327 struct rpc_cred_cache *new; 328 unsigned int hashsize; 329 330 new = kmalloc(sizeof(*new), GFP_KERNEL); 331 if (!new) 332 goto out_nocache; 333 new->hashbits = auth_hashbits; 334 hashsize = 1U << new->hashbits; 335 new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL); 336 if (!new->hashtable) 337 goto out_nohashtbl; 338 spin_lock_init(&new->lock); 339 auth->au_credcache = new; 340 return 0; 341 out_nohashtbl: 342 kfree(new); 343 out_nocache: 344 return -ENOMEM; 345 } 346 EXPORT_SYMBOL_GPL(rpcauth_init_credcache); 347 348 /* 349 * Setup a credential key lifetime timeout notification 350 */ 351 int 352 rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred) 353 { 354 if (!cred->cr_auth->au_ops->key_timeout) 355 return 0; 356 return cred->cr_auth->au_ops->key_timeout(auth, cred); 357 } 358 EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify); 359 360 bool 361 rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred) 362 { 363 if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT) 364 return false; 365 if (!cred->cr_ops->crkey_to_expire) 366 return false; 367 return cred->cr_ops->crkey_to_expire(cred); 368 } 369 EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire); 370 371 char * 372 rpcauth_stringify_acceptor(struct rpc_cred *cred) 373 { 374 if (!cred->cr_ops->crstringify_acceptor) 375 return NULL; 376 return cred->cr_ops->crstringify_acceptor(cred); 377 } 378 EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor); 379 380 /* 381 * Destroy a list of credentials 382 */ 383 static inline 384 void rpcauth_destroy_credlist(struct list_head *head) 385 { 386 struct rpc_cred *cred; 387 388 while (!list_empty(head)) { 389 cred = list_entry(head->next, struct rpc_cred, cr_lru); 390 list_del_init(&cred->cr_lru); 391 put_rpccred(cred); 392 } 393 } 394 395 /* 396 * Clear the RPC credential cache, and delete those credentials 397 * that are not referenced. 398 */ 399 void 400 rpcauth_clear_credcache(struct rpc_cred_cache *cache) 401 { 402 LIST_HEAD(free); 403 struct hlist_head *head; 404 struct rpc_cred *cred; 405 unsigned int hashsize = 1U << cache->hashbits; 406 int i; 407 408 spin_lock(&rpc_credcache_lock); 409 spin_lock(&cache->lock); 410 for (i = 0; i < hashsize; i++) { 411 head = &cache->hashtable[i]; 412 while (!hlist_empty(head)) { 413 cred = hlist_entry(head->first, struct rpc_cred, cr_hash); 414 get_rpccred(cred); 415 if (!list_empty(&cred->cr_lru)) { 416 list_del(&cred->cr_lru); 417 number_cred_unused--; 418 } 419 list_add_tail(&cred->cr_lru, &free); 420 rpcauth_unhash_cred_locked(cred); 421 } 422 } 423 spin_unlock(&cache->lock); 424 spin_unlock(&rpc_credcache_lock); 425 rpcauth_destroy_credlist(&free); 426 } 427 428 /* 429 * Destroy the RPC credential cache 430 */ 431 void 432 rpcauth_destroy_credcache(struct rpc_auth *auth) 433 { 434 struct rpc_cred_cache *cache = auth->au_credcache; 435 436 if (cache) { 437 auth->au_credcache = NULL; 438 rpcauth_clear_credcache(cache); 439 kfree(cache->hashtable); 440 kfree(cache); 441 } 442 } 443 EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); 444 445 446 #define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ) 447 448 /* 449 * Remove stale credentials. Avoid sleeping inside the loop. 450 */ 451 static long 452 rpcauth_prune_expired(struct list_head *free, int nr_to_scan) 453 { 454 spinlock_t *cache_lock; 455 struct rpc_cred *cred, *next; 456 unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; 457 long freed = 0; 458 459 list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { 460 461 if (nr_to_scan-- == 0) 462 break; 463 /* 464 * Enforce a 60 second garbage collection moratorium 465 * Note that the cred_unused list must be time-ordered. 466 */ 467 if (time_in_range(cred->cr_expire, expired, jiffies) && 468 test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { 469 freed = SHRINK_STOP; 470 break; 471 } 472 473 list_del_init(&cred->cr_lru); 474 number_cred_unused--; 475 freed++; 476 if (atomic_read(&cred->cr_count) != 0) 477 continue; 478 479 cache_lock = &cred->cr_auth->au_credcache->lock; 480 spin_lock(cache_lock); 481 if (atomic_read(&cred->cr_count) == 0) { 482 get_rpccred(cred); 483 list_add_tail(&cred->cr_lru, free); 484 rpcauth_unhash_cred_locked(cred); 485 } 486 spin_unlock(cache_lock); 487 } 488 return freed; 489 } 490 491 static unsigned long 492 rpcauth_cache_do_shrink(int nr_to_scan) 493 { 494 LIST_HEAD(free); 495 unsigned long freed; 496 497 spin_lock(&rpc_credcache_lock); 498 freed = rpcauth_prune_expired(&free, nr_to_scan); 499 spin_unlock(&rpc_credcache_lock); 500 rpcauth_destroy_credlist(&free); 501 502 return freed; 503 } 504 505 /* 506 * Run memory cache shrinker. 507 */ 508 static unsigned long 509 rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) 510 511 { 512 if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) 513 return SHRINK_STOP; 514 515 /* nothing left, don't come back */ 516 if (list_empty(&cred_unused)) 517 return SHRINK_STOP; 518 519 return rpcauth_cache_do_shrink(sc->nr_to_scan); 520 } 521 522 static unsigned long 523 rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) 524 525 { 526 return number_cred_unused * sysctl_vfs_cache_pressure / 100; 527 } 528 529 static void 530 rpcauth_cache_enforce_limit(void) 531 { 532 unsigned long diff; 533 unsigned int nr_to_scan; 534 535 if (number_cred_unused <= auth_max_cred_cachesize) 536 return; 537 diff = number_cred_unused - auth_max_cred_cachesize; 538 nr_to_scan = 100; 539 if (diff < nr_to_scan) 540 nr_to_scan = diff; 541 rpcauth_cache_do_shrink(nr_to_scan); 542 } 543 544 /* 545 * Look up a process' credentials in the authentication cache 546 */ 547 struct rpc_cred * 548 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, 549 int flags, gfp_t gfp) 550 { 551 LIST_HEAD(free); 552 struct rpc_cred_cache *cache = auth->au_credcache; 553 struct rpc_cred *cred = NULL, 554 *entry, *new; 555 unsigned int nr; 556 557 nr = auth->au_ops->hash_cred(acred, cache->hashbits); 558 559 rcu_read_lock(); 560 hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { 561 if (!entry->cr_ops->crmatch(acred, entry, flags)) 562 continue; 563 if (flags & RPCAUTH_LOOKUP_RCU) { 564 if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) && 565 !test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags)) 566 cred = entry; 567 break; 568 } 569 spin_lock(&cache->lock); 570 if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { 571 spin_unlock(&cache->lock); 572 continue; 573 } 574 cred = get_rpccred(entry); 575 spin_unlock(&cache->lock); 576 break; 577 } 578 rcu_read_unlock(); 579 580 if (cred != NULL) 581 goto found; 582 583 if (flags & RPCAUTH_LOOKUP_RCU) 584 return ERR_PTR(-ECHILD); 585 586 new = auth->au_ops->crcreate(auth, acred, flags, gfp); 587 if (IS_ERR(new)) { 588 cred = new; 589 goto out; 590 } 591 592 spin_lock(&cache->lock); 593 hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) { 594 if (!entry->cr_ops->crmatch(acred, entry, flags)) 595 continue; 596 cred = get_rpccred(entry); 597 break; 598 } 599 if (cred == NULL) { 600 cred = new; 601 set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); 602 hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); 603 } else 604 list_add_tail(&new->cr_lru, &free); 605 spin_unlock(&cache->lock); 606 rpcauth_cache_enforce_limit(); 607 found: 608 if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && 609 cred->cr_ops->cr_init != NULL && 610 !(flags & RPCAUTH_LOOKUP_NEW)) { 611 int res = cred->cr_ops->cr_init(auth, cred); 612 if (res < 0) { 613 put_rpccred(cred); 614 cred = ERR_PTR(res); 615 } 616 } 617 rpcauth_destroy_credlist(&free); 618 out: 619 return cred; 620 } 621 EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache); 622 623 struct rpc_cred * 624 rpcauth_lookupcred(struct rpc_auth *auth, int flags) 625 { 626 struct auth_cred acred; 627 struct rpc_cred *ret; 628 const struct cred *cred = current_cred(); 629 630 dprintk("RPC: looking up %s cred\n", 631 auth->au_ops->au_name); 632 633 memset(&acred, 0, sizeof(acred)); 634 acred.uid = cred->fsuid; 635 acred.gid = cred->fsgid; 636 acred.group_info = cred->group_info; 637 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 638 return ret; 639 } 640 EXPORT_SYMBOL_GPL(rpcauth_lookupcred); 641 642 void 643 rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, 644 struct rpc_auth *auth, const struct rpc_credops *ops) 645 { 646 INIT_HLIST_NODE(&cred->cr_hash); 647 INIT_LIST_HEAD(&cred->cr_lru); 648 atomic_set(&cred->cr_count, 1); 649 cred->cr_auth = auth; 650 cred->cr_ops = ops; 651 cred->cr_expire = jiffies; 652 cred->cr_uid = acred->uid; 653 } 654 EXPORT_SYMBOL_GPL(rpcauth_init_cred); 655 656 struct rpc_cred * 657 rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) 658 { 659 dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, 660 cred->cr_auth->au_ops->au_name, cred); 661 return get_rpccred(cred); 662 } 663 EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); 664 665 static struct rpc_cred * 666 rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) 667 { 668 struct rpc_auth *auth = task->tk_client->cl_auth; 669 struct auth_cred acred = { 670 .uid = GLOBAL_ROOT_UID, 671 .gid = GLOBAL_ROOT_GID, 672 }; 673 674 dprintk("RPC: %5u looking up %s cred\n", 675 task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); 676 return auth->au_ops->lookup_cred(auth, &acred, lookupflags); 677 } 678 679 static struct rpc_cred * 680 rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) 681 { 682 struct rpc_auth *auth = task->tk_client->cl_auth; 683 684 dprintk("RPC: %5u looking up %s cred\n", 685 task->tk_pid, auth->au_ops->au_name); 686 return rpcauth_lookupcred(auth, lookupflags); 687 } 688 689 static int 690 rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) 691 { 692 struct rpc_rqst *req = task->tk_rqstp; 693 struct rpc_cred *new; 694 int lookupflags = 0; 695 696 if (flags & RPC_TASK_ASYNC) 697 lookupflags |= RPCAUTH_LOOKUP_NEW; 698 if (cred != NULL) 699 new = cred->cr_ops->crbind(task, cred, lookupflags); 700 else if (flags & RPC_TASK_ROOTCREDS) 701 new = rpcauth_bind_root_cred(task, lookupflags); 702 else 703 new = rpcauth_bind_new_cred(task, lookupflags); 704 if (IS_ERR(new)) 705 return PTR_ERR(new); 706 put_rpccred(req->rq_cred); 707 req->rq_cred = new; 708 return 0; 709 } 710 711 void 712 put_rpccred(struct rpc_cred *cred) 713 { 714 if (cred == NULL) 715 return; 716 /* Fast path for unhashed credentials */ 717 if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) { 718 if (atomic_dec_and_test(&cred->cr_count)) 719 cred->cr_ops->crdestroy(cred); 720 return; 721 } 722 723 if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) 724 return; 725 if (!list_empty(&cred->cr_lru)) { 726 number_cred_unused--; 727 list_del_init(&cred->cr_lru); 728 } 729 if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { 730 if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { 731 cred->cr_expire = jiffies; 732 list_add_tail(&cred->cr_lru, &cred_unused); 733 number_cred_unused++; 734 goto out_nodestroy; 735 } 736 if (!rpcauth_unhash_cred(cred)) { 737 /* We were hashed and someone looked us up... */ 738 goto out_nodestroy; 739 } 740 } 741 spin_unlock(&rpc_credcache_lock); 742 cred->cr_ops->crdestroy(cred); 743 return; 744 out_nodestroy: 745 spin_unlock(&rpc_credcache_lock); 746 } 747 EXPORT_SYMBOL_GPL(put_rpccred); 748 749 __be32 * 750 rpcauth_marshcred(struct rpc_task *task, __be32 *p) 751 { 752 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 753 754 dprintk("RPC: %5u marshaling %s cred %p\n", 755 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 756 757 return cred->cr_ops->crmarshal(task, p); 758 } 759 760 __be32 * 761 rpcauth_checkverf(struct rpc_task *task, __be32 *p) 762 { 763 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 764 765 dprintk("RPC: %5u validating %s cred %p\n", 766 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 767 768 return cred->cr_ops->crvalidate(task, p); 769 } 770 771 static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, 772 __be32 *data, void *obj) 773 { 774 struct xdr_stream xdr; 775 776 xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); 777 encode(rqstp, &xdr, obj); 778 } 779 780 int 781 rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp, 782 __be32 *data, void *obj) 783 { 784 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 785 786 dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", 787 task->tk_pid, cred->cr_ops->cr_name, cred); 788 if (cred->cr_ops->crwrap_req) 789 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); 790 /* By default, we encode the arguments normally. */ 791 rpcauth_wrap_req_encode(encode, rqstp, data, obj); 792 return 0; 793 } 794 795 static int 796 rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, 797 __be32 *data, void *obj) 798 { 799 struct xdr_stream xdr; 800 801 xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); 802 return decode(rqstp, &xdr, obj); 803 } 804 805 int 806 rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, 807 __be32 *data, void *obj) 808 { 809 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 810 811 dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", 812 task->tk_pid, cred->cr_ops->cr_name, cred); 813 if (cred->cr_ops->crunwrap_resp) 814 return cred->cr_ops->crunwrap_resp(task, decode, rqstp, 815 data, obj); 816 /* By default, we decode the arguments normally. */ 817 return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); 818 } 819 820 int 821 rpcauth_refreshcred(struct rpc_task *task) 822 { 823 struct rpc_cred *cred; 824 int err; 825 826 cred = task->tk_rqstp->rq_cred; 827 if (cred == NULL) { 828 err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags); 829 if (err < 0) 830 goto out; 831 cred = task->tk_rqstp->rq_cred; 832 } 833 dprintk("RPC: %5u refreshing %s cred %p\n", 834 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 835 836 err = cred->cr_ops->crrefresh(task); 837 out: 838 if (err < 0) 839 task->tk_status = err; 840 return err; 841 } 842 843 void 844 rpcauth_invalcred(struct rpc_task *task) 845 { 846 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 847 848 dprintk("RPC: %5u invalidating %s cred %p\n", 849 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 850 if (cred) 851 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); 852 } 853 854 int 855 rpcauth_uptodatecred(struct rpc_task *task) 856 { 857 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 858 859 return cred == NULL || 860 test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; 861 } 862 863 static struct shrinker rpc_cred_shrinker = { 864 .count_objects = rpcauth_cache_shrink_count, 865 .scan_objects = rpcauth_cache_shrink_scan, 866 .seeks = DEFAULT_SEEKS, 867 }; 868 869 int __init rpcauth_init_module(void) 870 { 871 int err; 872 873 err = rpc_init_authunix(); 874 if (err < 0) 875 goto out1; 876 err = rpc_init_generic_auth(); 877 if (err < 0) 878 goto out2; 879 err = register_shrinker(&rpc_cred_shrinker); 880 if (err < 0) 881 goto out3; 882 return 0; 883 out3: 884 rpc_destroy_generic_auth(); 885 out2: 886 rpc_destroy_authunix(); 887 out1: 888 return err; 889 } 890 891 void rpcauth_remove_module(void) 892 { 893 rpc_destroy_authunix(); 894 rpc_destroy_generic_auth(); 895 unregister_shrinker(&rpc_cred_shrinker); 896 } 897