1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc. 4 * Author: Joerg Roedel <jroedel@suse.de> 5 */ 6 7 #define pr_fmt(fmt) "AMD-Vi: " fmt 8 9 #include <linux/refcount.h> 10 #include <linux/mmu_notifier.h> 11 #include <linux/amd-iommu.h> 12 #include <linux/mm_types.h> 13 #include <linux/profile.h> 14 #include <linux/module.h> 15 #include <linux/sched.h> 16 #include <linux/sched/mm.h> 17 #include <linux/wait.h> 18 #include <linux/pci.h> 19 #include <linux/gfp.h> 20 #include <linux/cc_platform.h> 21 22 #include "amd_iommu.h" 23 24 MODULE_LICENSE("GPL v2"); 25 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>"); 26 27 #define MAX_DEVICES 0x10000 28 #define PRI_QUEUE_SIZE 512 29 30 struct pri_queue { 31 atomic_t inflight; 32 bool finish; 33 int status; 34 }; 35 36 struct pasid_state { 37 struct list_head list; /* For global state-list */ 38 refcount_t count; /* Reference count */ 39 unsigned mmu_notifier_count; /* Counting nested mmu_notifier 40 calls */ 41 struct mm_struct *mm; /* mm_struct for the faults */ 42 struct mmu_notifier mn; /* mmu_notifier handle */ 43 struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */ 44 struct device_state *device_state; /* Link to our device_state */ 45 u32 pasid; /* PASID index */ 46 bool invalid; /* Used during setup and 47 teardown of the pasid */ 48 spinlock_t lock; /* Protect pri_queues and 49 mmu_notifer_count */ 50 wait_queue_head_t wq; /* To wait for count == 0 */ 51 }; 52 53 struct device_state { 54 struct list_head list; 55 u16 devid; 56 atomic_t count; 57 struct pci_dev *pdev; 58 struct pasid_state **states; 59 struct iommu_domain *domain; 60 int pasid_levels; 61 int max_pasids; 62 amd_iommu_invalid_ppr_cb inv_ppr_cb; 63 amd_iommu_invalidate_ctx inv_ctx_cb; 64 spinlock_t lock; 65 wait_queue_head_t wq; 66 }; 67 68 struct fault { 69 struct work_struct work; 70 struct device_state *dev_state; 71 struct pasid_state *state; 72 struct mm_struct *mm; 73 u64 address; 74 u16 devid; 75 u32 pasid; 76 u16 tag; 77 u16 finish; 78 u16 flags; 79 }; 80 81 static LIST_HEAD(state_list); 82 static DEFINE_SPINLOCK(state_lock); 83 84 static struct workqueue_struct *iommu_wq; 85 86 static void free_pasid_states(struct device_state *dev_state); 87 88 static u16 device_id(struct pci_dev *pdev) 89 { 90 u16 devid; 91 92 devid = pdev->bus->number; 93 devid = (devid << 8) | pdev->devfn; 94 95 return devid; 96 } 97 98 static struct device_state *__get_device_state(u16 devid) 99 { 100 struct device_state *dev_state; 101 102 list_for_each_entry(dev_state, &state_list, list) { 103 if (dev_state->devid == devid) 104 return dev_state; 105 } 106 107 return NULL; 108 } 109 110 static struct device_state *get_device_state(u16 devid) 111 { 112 struct device_state *dev_state; 113 unsigned long flags; 114 115 spin_lock_irqsave(&state_lock, flags); 116 dev_state = __get_device_state(devid); 117 if (dev_state != NULL) 118 atomic_inc(&dev_state->count); 119 spin_unlock_irqrestore(&state_lock, flags); 120 121 return dev_state; 122 } 123 124 static void free_device_state(struct device_state *dev_state) 125 { 126 struct iommu_group *group; 127 128 /* 129 * First detach device from domain - No more PRI requests will arrive 130 * from that device after it is unbound from the IOMMUv2 domain. 131 */ 132 group = iommu_group_get(&dev_state->pdev->dev); 133 if (WARN_ON(!group)) 134 return; 135 136 iommu_detach_group(dev_state->domain, group); 137 138 iommu_group_put(group); 139 140 /* Everything is down now, free the IOMMUv2 domain */ 141 iommu_domain_free(dev_state->domain); 142 143 /* Finally get rid of the device-state */ 144 kfree(dev_state); 145 } 146 147 static void put_device_state(struct device_state *dev_state) 148 { 149 if (atomic_dec_and_test(&dev_state->count)) 150 wake_up(&dev_state->wq); 151 } 152 153 /* Must be called under dev_state->lock */ 154 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state, 155 u32 pasid, bool alloc) 156 { 157 struct pasid_state **root, **ptr; 158 int level, index; 159 160 level = dev_state->pasid_levels; 161 root = dev_state->states; 162 163 while (true) { 164 165 index = (pasid >> (9 * level)) & 0x1ff; 166 ptr = &root[index]; 167 168 if (level == 0) 169 break; 170 171 if (*ptr == NULL) { 172 if (!alloc) 173 return NULL; 174 175 *ptr = (void *)get_zeroed_page(GFP_ATOMIC); 176 if (*ptr == NULL) 177 return NULL; 178 } 179 180 root = (struct pasid_state **)*ptr; 181 level -= 1; 182 } 183 184 return ptr; 185 } 186 187 static int set_pasid_state(struct device_state *dev_state, 188 struct pasid_state *pasid_state, 189 u32 pasid) 190 { 191 struct pasid_state **ptr; 192 unsigned long flags; 193 int ret; 194 195 spin_lock_irqsave(&dev_state->lock, flags); 196 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 197 198 ret = -ENOMEM; 199 if (ptr == NULL) 200 goto out_unlock; 201 202 ret = -ENOMEM; 203 if (*ptr != NULL) 204 goto out_unlock; 205 206 *ptr = pasid_state; 207 208 ret = 0; 209 210 out_unlock: 211 spin_unlock_irqrestore(&dev_state->lock, flags); 212 213 return ret; 214 } 215 216 static void clear_pasid_state(struct device_state *dev_state, u32 pasid) 217 { 218 struct pasid_state **ptr; 219 unsigned long flags; 220 221 spin_lock_irqsave(&dev_state->lock, flags); 222 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 223 224 if (ptr == NULL) 225 goto out_unlock; 226 227 *ptr = NULL; 228 229 out_unlock: 230 spin_unlock_irqrestore(&dev_state->lock, flags); 231 } 232 233 static struct pasid_state *get_pasid_state(struct device_state *dev_state, 234 u32 pasid) 235 { 236 struct pasid_state **ptr, *ret = NULL; 237 unsigned long flags; 238 239 spin_lock_irqsave(&dev_state->lock, flags); 240 ptr = __get_pasid_state_ptr(dev_state, pasid, false); 241 242 if (ptr == NULL) 243 goto out_unlock; 244 245 ret = *ptr; 246 if (ret) 247 refcount_inc(&ret->count); 248 249 out_unlock: 250 spin_unlock_irqrestore(&dev_state->lock, flags); 251 252 return ret; 253 } 254 255 static void free_pasid_state(struct pasid_state *pasid_state) 256 { 257 kfree(pasid_state); 258 } 259 260 static void put_pasid_state(struct pasid_state *pasid_state) 261 { 262 if (refcount_dec_and_test(&pasid_state->count)) 263 wake_up(&pasid_state->wq); 264 } 265 266 static void put_pasid_state_wait(struct pasid_state *pasid_state) 267 { 268 refcount_dec(&pasid_state->count); 269 wait_event(pasid_state->wq, !refcount_read(&pasid_state->count)); 270 free_pasid_state(pasid_state); 271 } 272 273 static void unbind_pasid(struct pasid_state *pasid_state) 274 { 275 struct iommu_domain *domain; 276 277 domain = pasid_state->device_state->domain; 278 279 /* 280 * Mark pasid_state as invalid, no more faults will we added to the 281 * work queue after this is visible everywhere. 282 */ 283 pasid_state->invalid = true; 284 285 /* Make sure this is visible */ 286 smp_wmb(); 287 288 /* After this the device/pasid can't access the mm anymore */ 289 amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid); 290 291 /* Make sure no more pending faults are in the queue */ 292 flush_workqueue(iommu_wq); 293 } 294 295 static void free_pasid_states_level1(struct pasid_state **tbl) 296 { 297 int i; 298 299 for (i = 0; i < 512; ++i) { 300 if (tbl[i] == NULL) 301 continue; 302 303 free_page((unsigned long)tbl[i]); 304 } 305 } 306 307 static void free_pasid_states_level2(struct pasid_state **tbl) 308 { 309 struct pasid_state **ptr; 310 int i; 311 312 for (i = 0; i < 512; ++i) { 313 if (tbl[i] == NULL) 314 continue; 315 316 ptr = (struct pasid_state **)tbl[i]; 317 free_pasid_states_level1(ptr); 318 } 319 } 320 321 static void free_pasid_states(struct device_state *dev_state) 322 { 323 struct pasid_state *pasid_state; 324 int i; 325 326 for (i = 0; i < dev_state->max_pasids; ++i) { 327 pasid_state = get_pasid_state(dev_state, i); 328 if (pasid_state == NULL) 329 continue; 330 331 put_pasid_state(pasid_state); 332 333 /* 334 * This will call the mn_release function and 335 * unbind the PASID 336 */ 337 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 338 339 put_pasid_state_wait(pasid_state); /* Reference taken in 340 amd_iommu_bind_pasid */ 341 342 /* Drop reference taken in amd_iommu_bind_pasid */ 343 put_device_state(dev_state); 344 } 345 346 if (dev_state->pasid_levels == 2) 347 free_pasid_states_level2(dev_state->states); 348 else if (dev_state->pasid_levels == 1) 349 free_pasid_states_level1(dev_state->states); 350 else 351 BUG_ON(dev_state->pasid_levels != 0); 352 353 free_page((unsigned long)dev_state->states); 354 } 355 356 static struct pasid_state *mn_to_state(struct mmu_notifier *mn) 357 { 358 return container_of(mn, struct pasid_state, mn); 359 } 360 361 static void mn_invalidate_range(struct mmu_notifier *mn, 362 struct mm_struct *mm, 363 unsigned long start, unsigned long end) 364 { 365 struct pasid_state *pasid_state; 366 struct device_state *dev_state; 367 368 pasid_state = mn_to_state(mn); 369 dev_state = pasid_state->device_state; 370 371 if ((start ^ (end - 1)) < PAGE_SIZE) 372 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, 373 start); 374 else 375 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid); 376 } 377 378 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm) 379 { 380 struct pasid_state *pasid_state; 381 struct device_state *dev_state; 382 bool run_inv_ctx_cb; 383 384 might_sleep(); 385 386 pasid_state = mn_to_state(mn); 387 dev_state = pasid_state->device_state; 388 run_inv_ctx_cb = !pasid_state->invalid; 389 390 if (run_inv_ctx_cb && dev_state->inv_ctx_cb) 391 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid); 392 393 unbind_pasid(pasid_state); 394 } 395 396 static const struct mmu_notifier_ops iommu_mn = { 397 .release = mn_release, 398 .invalidate_range = mn_invalidate_range, 399 }; 400 401 static void set_pri_tag_status(struct pasid_state *pasid_state, 402 u16 tag, int status) 403 { 404 unsigned long flags; 405 406 spin_lock_irqsave(&pasid_state->lock, flags); 407 pasid_state->pri[tag].status = status; 408 spin_unlock_irqrestore(&pasid_state->lock, flags); 409 } 410 411 static void finish_pri_tag(struct device_state *dev_state, 412 struct pasid_state *pasid_state, 413 u16 tag) 414 { 415 unsigned long flags; 416 417 spin_lock_irqsave(&pasid_state->lock, flags); 418 if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) && 419 pasid_state->pri[tag].finish) { 420 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid, 421 pasid_state->pri[tag].status, tag); 422 pasid_state->pri[tag].finish = false; 423 pasid_state->pri[tag].status = PPR_SUCCESS; 424 } 425 spin_unlock_irqrestore(&pasid_state->lock, flags); 426 } 427 428 static void handle_fault_error(struct fault *fault) 429 { 430 int status; 431 432 if (!fault->dev_state->inv_ppr_cb) { 433 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 434 return; 435 } 436 437 status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, 438 fault->pasid, 439 fault->address, 440 fault->flags); 441 switch (status) { 442 case AMD_IOMMU_INV_PRI_RSP_SUCCESS: 443 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); 444 break; 445 case AMD_IOMMU_INV_PRI_RSP_INVALID: 446 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 447 break; 448 case AMD_IOMMU_INV_PRI_RSP_FAIL: 449 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); 450 break; 451 default: 452 BUG(); 453 } 454 } 455 456 static bool access_error(struct vm_area_struct *vma, struct fault *fault) 457 { 458 unsigned long requested = 0; 459 460 if (fault->flags & PPR_FAULT_EXEC) 461 requested |= VM_EXEC; 462 463 if (fault->flags & PPR_FAULT_READ) 464 requested |= VM_READ; 465 466 if (fault->flags & PPR_FAULT_WRITE) 467 requested |= VM_WRITE; 468 469 return (requested & ~vma->vm_flags) != 0; 470 } 471 472 static void do_fault(struct work_struct *work) 473 { 474 struct fault *fault = container_of(work, struct fault, work); 475 struct vm_area_struct *vma; 476 vm_fault_t ret = VM_FAULT_ERROR; 477 unsigned int flags = 0; 478 struct mm_struct *mm; 479 u64 address; 480 481 mm = fault->state->mm; 482 address = fault->address; 483 484 if (fault->flags & PPR_FAULT_USER) 485 flags |= FAULT_FLAG_USER; 486 if (fault->flags & PPR_FAULT_WRITE) 487 flags |= FAULT_FLAG_WRITE; 488 flags |= FAULT_FLAG_REMOTE; 489 490 mmap_read_lock(mm); 491 vma = find_extend_vma(mm, address); 492 if (!vma || address < vma->vm_start) 493 /* failed to get a vma in the right range */ 494 goto out; 495 496 /* Check if we have the right permissions on the vma */ 497 if (access_error(vma, fault)) 498 goto out; 499 500 ret = handle_mm_fault(vma, address, flags, NULL); 501 out: 502 mmap_read_unlock(mm); 503 504 if (ret & VM_FAULT_ERROR) 505 /* failed to service fault */ 506 handle_fault_error(fault); 507 508 finish_pri_tag(fault->dev_state, fault->state, fault->tag); 509 510 put_pasid_state(fault->state); 511 512 kfree(fault); 513 } 514 515 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) 516 { 517 struct amd_iommu_fault *iommu_fault; 518 struct pasid_state *pasid_state; 519 struct device_state *dev_state; 520 struct pci_dev *pdev = NULL; 521 unsigned long flags; 522 struct fault *fault; 523 bool finish; 524 u16 tag, devid; 525 int ret; 526 527 iommu_fault = data; 528 tag = iommu_fault->tag & 0x1ff; 529 finish = (iommu_fault->tag >> 9) & 1; 530 531 devid = iommu_fault->device_id; 532 pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid), 533 devid & 0xff); 534 if (!pdev) 535 return -ENODEV; 536 537 ret = NOTIFY_DONE; 538 539 /* In kdump kernel pci dev is not initialized yet -> send INVALID */ 540 if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) { 541 amd_iommu_complete_ppr(pdev, iommu_fault->pasid, 542 PPR_INVALID, tag); 543 goto out; 544 } 545 546 dev_state = get_device_state(iommu_fault->device_id); 547 if (dev_state == NULL) 548 goto out; 549 550 pasid_state = get_pasid_state(dev_state, iommu_fault->pasid); 551 if (pasid_state == NULL || pasid_state->invalid) { 552 /* We know the device but not the PASID -> send INVALID */ 553 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid, 554 PPR_INVALID, tag); 555 goto out_drop_state; 556 } 557 558 spin_lock_irqsave(&pasid_state->lock, flags); 559 atomic_inc(&pasid_state->pri[tag].inflight); 560 if (finish) 561 pasid_state->pri[tag].finish = true; 562 spin_unlock_irqrestore(&pasid_state->lock, flags); 563 564 fault = kzalloc(sizeof(*fault), GFP_ATOMIC); 565 if (fault == NULL) { 566 /* We are OOM - send success and let the device re-fault */ 567 finish_pri_tag(dev_state, pasid_state, tag); 568 goto out_drop_state; 569 } 570 571 fault->dev_state = dev_state; 572 fault->address = iommu_fault->address; 573 fault->state = pasid_state; 574 fault->tag = tag; 575 fault->finish = finish; 576 fault->pasid = iommu_fault->pasid; 577 fault->flags = iommu_fault->flags; 578 INIT_WORK(&fault->work, do_fault); 579 580 queue_work(iommu_wq, &fault->work); 581 582 ret = NOTIFY_OK; 583 584 out_drop_state: 585 586 if (ret != NOTIFY_OK && pasid_state) 587 put_pasid_state(pasid_state); 588 589 put_device_state(dev_state); 590 591 out: 592 return ret; 593 } 594 595 static struct notifier_block ppr_nb = { 596 .notifier_call = ppr_notifier, 597 }; 598 599 int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid, 600 struct task_struct *task) 601 { 602 struct pasid_state *pasid_state; 603 struct device_state *dev_state; 604 struct mm_struct *mm; 605 u16 devid; 606 int ret; 607 608 might_sleep(); 609 610 if (!amd_iommu_v2_supported()) 611 return -ENODEV; 612 613 devid = device_id(pdev); 614 dev_state = get_device_state(devid); 615 616 if (dev_state == NULL) 617 return -EINVAL; 618 619 ret = -EINVAL; 620 if (pasid >= dev_state->max_pasids) 621 goto out; 622 623 ret = -ENOMEM; 624 pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL); 625 if (pasid_state == NULL) 626 goto out; 627 628 629 refcount_set(&pasid_state->count, 1); 630 init_waitqueue_head(&pasid_state->wq); 631 spin_lock_init(&pasid_state->lock); 632 633 mm = get_task_mm(task); 634 pasid_state->mm = mm; 635 pasid_state->device_state = dev_state; 636 pasid_state->pasid = pasid; 637 pasid_state->invalid = true; /* Mark as valid only if we are 638 done with setting up the pasid */ 639 pasid_state->mn.ops = &iommu_mn; 640 641 if (pasid_state->mm == NULL) 642 goto out_free; 643 644 mmu_notifier_register(&pasid_state->mn, mm); 645 646 ret = set_pasid_state(dev_state, pasid_state, pasid); 647 if (ret) 648 goto out_unregister; 649 650 ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid, 651 __pa(pasid_state->mm->pgd)); 652 if (ret) 653 goto out_clear_state; 654 655 /* Now we are ready to handle faults */ 656 pasid_state->invalid = false; 657 658 /* 659 * Drop the reference to the mm_struct here. We rely on the 660 * mmu_notifier release call-back to inform us when the mm 661 * is going away. 662 */ 663 mmput(mm); 664 665 return 0; 666 667 out_clear_state: 668 clear_pasid_state(dev_state, pasid); 669 670 out_unregister: 671 mmu_notifier_unregister(&pasid_state->mn, mm); 672 mmput(mm); 673 674 out_free: 675 free_pasid_state(pasid_state); 676 677 out: 678 put_device_state(dev_state); 679 680 return ret; 681 } 682 EXPORT_SYMBOL(amd_iommu_bind_pasid); 683 684 void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid) 685 { 686 struct pasid_state *pasid_state; 687 struct device_state *dev_state; 688 u16 devid; 689 690 might_sleep(); 691 692 if (!amd_iommu_v2_supported()) 693 return; 694 695 devid = device_id(pdev); 696 dev_state = get_device_state(devid); 697 if (dev_state == NULL) 698 return; 699 700 if (pasid >= dev_state->max_pasids) 701 goto out; 702 703 pasid_state = get_pasid_state(dev_state, pasid); 704 if (pasid_state == NULL) 705 goto out; 706 /* 707 * Drop reference taken here. We are safe because we still hold 708 * the reference taken in the amd_iommu_bind_pasid function. 709 */ 710 put_pasid_state(pasid_state); 711 712 /* Clear the pasid state so that the pasid can be re-used */ 713 clear_pasid_state(dev_state, pasid_state->pasid); 714 715 /* 716 * Call mmu_notifier_unregister to drop our reference 717 * to pasid_state->mm 718 */ 719 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 720 721 put_pasid_state_wait(pasid_state); /* Reference taken in 722 amd_iommu_bind_pasid */ 723 out: 724 /* Drop reference taken in this function */ 725 put_device_state(dev_state); 726 727 /* Drop reference taken in amd_iommu_bind_pasid */ 728 put_device_state(dev_state); 729 } 730 EXPORT_SYMBOL(amd_iommu_unbind_pasid); 731 732 int amd_iommu_init_device(struct pci_dev *pdev, int pasids) 733 { 734 struct device_state *dev_state; 735 struct iommu_group *group; 736 unsigned long flags; 737 int ret, tmp; 738 u16 devid; 739 740 might_sleep(); 741 742 /* 743 * When memory encryption is active the device is likely not in a 744 * direct-mapped domain. Forbid using IOMMUv2 functionality for now. 745 */ 746 if (cc_platform_has(CC_ATTR_MEM_ENCRYPT)) 747 return -ENODEV; 748 749 if (!amd_iommu_v2_supported()) 750 return -ENODEV; 751 752 if (pasids <= 0 || pasids > (PASID_MASK + 1)) 753 return -EINVAL; 754 755 devid = device_id(pdev); 756 757 dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL); 758 if (dev_state == NULL) 759 return -ENOMEM; 760 761 spin_lock_init(&dev_state->lock); 762 init_waitqueue_head(&dev_state->wq); 763 dev_state->pdev = pdev; 764 dev_state->devid = devid; 765 766 tmp = pasids; 767 for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9) 768 dev_state->pasid_levels += 1; 769 770 atomic_set(&dev_state->count, 1); 771 dev_state->max_pasids = pasids; 772 773 ret = -ENOMEM; 774 dev_state->states = (void *)get_zeroed_page(GFP_KERNEL); 775 if (dev_state->states == NULL) 776 goto out_free_dev_state; 777 778 dev_state->domain = iommu_domain_alloc(&pci_bus_type); 779 if (dev_state->domain == NULL) 780 goto out_free_states; 781 782 amd_iommu_domain_direct_map(dev_state->domain); 783 784 ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids); 785 if (ret) 786 goto out_free_domain; 787 788 group = iommu_group_get(&pdev->dev); 789 if (!group) { 790 ret = -EINVAL; 791 goto out_free_domain; 792 } 793 794 ret = iommu_attach_group(dev_state->domain, group); 795 if (ret != 0) 796 goto out_drop_group; 797 798 iommu_group_put(group); 799 800 spin_lock_irqsave(&state_lock, flags); 801 802 if (__get_device_state(devid) != NULL) { 803 spin_unlock_irqrestore(&state_lock, flags); 804 ret = -EBUSY; 805 goto out_free_domain; 806 } 807 808 list_add_tail(&dev_state->list, &state_list); 809 810 spin_unlock_irqrestore(&state_lock, flags); 811 812 return 0; 813 814 out_drop_group: 815 iommu_group_put(group); 816 817 out_free_domain: 818 iommu_domain_free(dev_state->domain); 819 820 out_free_states: 821 free_page((unsigned long)dev_state->states); 822 823 out_free_dev_state: 824 kfree(dev_state); 825 826 return ret; 827 } 828 EXPORT_SYMBOL(amd_iommu_init_device); 829 830 void amd_iommu_free_device(struct pci_dev *pdev) 831 { 832 struct device_state *dev_state; 833 unsigned long flags; 834 u16 devid; 835 836 if (!amd_iommu_v2_supported()) 837 return; 838 839 devid = device_id(pdev); 840 841 spin_lock_irqsave(&state_lock, flags); 842 843 dev_state = __get_device_state(devid); 844 if (dev_state == NULL) { 845 spin_unlock_irqrestore(&state_lock, flags); 846 return; 847 } 848 849 list_del(&dev_state->list); 850 851 spin_unlock_irqrestore(&state_lock, flags); 852 853 /* Get rid of any remaining pasid states */ 854 free_pasid_states(dev_state); 855 856 put_device_state(dev_state); 857 /* 858 * Wait until the last reference is dropped before freeing 859 * the device state. 860 */ 861 wait_event(dev_state->wq, !atomic_read(&dev_state->count)); 862 free_device_state(dev_state); 863 } 864 EXPORT_SYMBOL(amd_iommu_free_device); 865 866 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev, 867 amd_iommu_invalid_ppr_cb cb) 868 { 869 struct device_state *dev_state; 870 unsigned long flags; 871 u16 devid; 872 int ret; 873 874 if (!amd_iommu_v2_supported()) 875 return -ENODEV; 876 877 devid = device_id(pdev); 878 879 spin_lock_irqsave(&state_lock, flags); 880 881 ret = -EINVAL; 882 dev_state = __get_device_state(devid); 883 if (dev_state == NULL) 884 goto out_unlock; 885 886 dev_state->inv_ppr_cb = cb; 887 888 ret = 0; 889 890 out_unlock: 891 spin_unlock_irqrestore(&state_lock, flags); 892 893 return ret; 894 } 895 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb); 896 897 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev, 898 amd_iommu_invalidate_ctx cb) 899 { 900 struct device_state *dev_state; 901 unsigned long flags; 902 u16 devid; 903 int ret; 904 905 if (!amd_iommu_v2_supported()) 906 return -ENODEV; 907 908 devid = device_id(pdev); 909 910 spin_lock_irqsave(&state_lock, flags); 911 912 ret = -EINVAL; 913 dev_state = __get_device_state(devid); 914 if (dev_state == NULL) 915 goto out_unlock; 916 917 dev_state->inv_ctx_cb = cb; 918 919 ret = 0; 920 921 out_unlock: 922 spin_unlock_irqrestore(&state_lock, flags); 923 924 return ret; 925 } 926 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb); 927 928 static int __init amd_iommu_v2_init(void) 929 { 930 int ret; 931 932 if (!amd_iommu_v2_supported()) { 933 pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n"); 934 /* 935 * Load anyway to provide the symbols to other modules 936 * which may use AMD IOMMUv2 optionally. 937 */ 938 return 0; 939 } 940 941 ret = -ENOMEM; 942 iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0); 943 if (iommu_wq == NULL) 944 goto out; 945 946 amd_iommu_register_ppr_notifier(&ppr_nb); 947 948 pr_info("AMD IOMMUv2 loaded and initialized\n"); 949 950 return 0; 951 952 out: 953 return ret; 954 } 955 956 static void __exit amd_iommu_v2_exit(void) 957 { 958 struct device_state *dev_state; 959 int i; 960 961 if (!amd_iommu_v2_supported()) 962 return; 963 964 amd_iommu_unregister_ppr_notifier(&ppr_nb); 965 966 flush_workqueue(iommu_wq); 967 968 /* 969 * The loop below might call flush_workqueue(), so call 970 * destroy_workqueue() after it 971 */ 972 for (i = 0; i < MAX_DEVICES; ++i) { 973 dev_state = get_device_state(i); 974 975 if (dev_state == NULL) 976 continue; 977 978 WARN_ON_ONCE(1); 979 980 put_device_state(dev_state); 981 amd_iommu_free_device(dev_state->pdev); 982 } 983 984 destroy_workqueue(iommu_wq); 985 } 986 987 module_init(amd_iommu_v2_init); 988 module_exit(amd_iommu_v2_exit); 989