// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)	"AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>
#include <linux/cc_platform.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	refcount_t count;			/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u16 devid;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->devid == devid)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		refcount_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (refcount_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	refcount_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range	= mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
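	/*
	 * Only send the PPR completion once the last in-flight fault for
	 * this tag has been handled and the request group was flagged as
	 * finished.
	 */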
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	devid = iommu_fault->device_id;
	pdev  = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
					    devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(&pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

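	/*
	 * Take a reference on the pasid_state; it is dropped in do_fault()
	 * after the fault has been handled, or under out_drop_state below
	 * if no work was queued.
	 */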
	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	refcount_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->devid = devid;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(devid) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(devid);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	put_device_state(dev_state);
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

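	/* __get_device_state() must be called with state_lock held */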
	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state, *next;
	unsigned long flags;
	LIST_HEAD(freelist);

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	spin_lock_irqsave(&state_lock, flags);

	list_for_each_entry_safe(dev_state, next, &state_list, list) {
		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		list_del(&dev_state->list);
		list_add_tail(&dev_state->list, &freelist);
	}

	spin_unlock_irqrestore(&state_lock, flags);

	/*
	 * Since free_device_state waits on the count to be zero,
	 * we need to free dev_state outside the spinlock.
	 */
	list_for_each_entry_safe(dev_state, next, &freelist, list) {
		list_del(&dev_state->list);
		free_device_state(dev_state);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);
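
/*
 * Rough usage sketch of the exported API (illustrative only, not taken from
 * an in-tree user; error handling omitted, my_inv_ppr_cb and my_inv_ctx_cb
 * are hypothetical driver callbacks):
 *
 *	amd_iommu_init_device(pdev, num_pasids);
 *	amd_iommu_set_invalid_ppr_cb(pdev, my_inv_ppr_cb);	// optional
 *	amd_iommu_set_invalidate_ctx_cb(pdev, my_inv_ctx_cb);	// optional
 *	amd_iommu_bind_pasid(pdev, pasid, current);
 *	... device performs PASID-tagged DMA, faults arrive via PPR ...
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */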