1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc. 4 * Author: Joerg Roedel <jroedel@suse.de> 5 */ 6 7 #define pr_fmt(fmt) "AMD-Vi: " fmt 8 9 #include <linux/refcount.h> 10 #include <linux/mmu_notifier.h> 11 #include <linux/amd-iommu.h> 12 #include <linux/mm_types.h> 13 #include <linux/profile.h> 14 #include <linux/module.h> 15 #include <linux/sched.h> 16 #include <linux/sched/mm.h> 17 #include <linux/wait.h> 18 #include <linux/pci.h> 19 #include <linux/gfp.h> 20 #include <linux/cc_platform.h> 21 22 #include "amd_iommu.h" 23 24 MODULE_LICENSE("GPL v2"); 25 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>"); 26 27 #define PRI_QUEUE_SIZE 512 28 29 struct pri_queue { 30 atomic_t inflight; 31 bool finish; 32 int status; 33 }; 34 35 struct pasid_state { 36 struct list_head list; /* For global state-list */ 37 refcount_t count; /* Reference count */ 38 unsigned mmu_notifier_count; /* Counting nested mmu_notifier 39 calls */ 40 struct mm_struct *mm; /* mm_struct for the faults */ 41 struct mmu_notifier mn; /* mmu_notifier handle */ 42 struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */ 43 struct device_state *device_state; /* Link to our device_state */ 44 u32 pasid; /* PASID index */ 45 bool invalid; /* Used during setup and 46 teardown of the pasid */ 47 spinlock_t lock; /* Protect pri_queues and 48 mmu_notifer_count */ 49 wait_queue_head_t wq; /* To wait for count == 0 */ 50 }; 51 52 struct device_state { 53 struct list_head list; 54 u32 sbdf; 55 atomic_t count; 56 struct pci_dev *pdev; 57 struct pasid_state **states; 58 struct iommu_domain *domain; 59 int pasid_levels; 60 int max_pasids; 61 amd_iommu_invalid_ppr_cb inv_ppr_cb; 62 amd_iommu_invalidate_ctx inv_ctx_cb; 63 spinlock_t lock; 64 wait_queue_head_t wq; 65 }; 66 67 struct fault { 68 struct work_struct work; 69 struct device_state *dev_state; 70 struct pasid_state *state; 71 struct mm_struct *mm; 72 u64 address; 73 u32 pasid; 74 u16 tag; 75 u16 finish; 76 u16 flags; 77 }; 78 79 static LIST_HEAD(state_list); 80 static DEFINE_SPINLOCK(state_lock); 81 82 static struct workqueue_struct *iommu_wq; 83 84 static void free_pasid_states(struct device_state *dev_state); 85 86 static struct device_state *__get_device_state(u32 sbdf) 87 { 88 struct device_state *dev_state; 89 90 list_for_each_entry(dev_state, &state_list, list) { 91 if (dev_state->sbdf == sbdf) 92 return dev_state; 93 } 94 95 return NULL; 96 } 97 98 static struct device_state *get_device_state(u32 sbdf) 99 { 100 struct device_state *dev_state; 101 unsigned long flags; 102 103 spin_lock_irqsave(&state_lock, flags); 104 dev_state = __get_device_state(sbdf); 105 if (dev_state != NULL) 106 atomic_inc(&dev_state->count); 107 spin_unlock_irqrestore(&state_lock, flags); 108 109 return dev_state; 110 } 111 112 static void free_device_state(struct device_state *dev_state) 113 { 114 struct iommu_group *group; 115 116 /* Get rid of any remaining pasid states */ 117 free_pasid_states(dev_state); 118 119 /* 120 * Wait until the last reference is dropped before freeing 121 * the device state. 122 */ 123 wait_event(dev_state->wq, !atomic_read(&dev_state->count)); 124 125 /* 126 * First detach device from domain - No more PRI requests will arrive 127 * from that device after it is unbound from the IOMMUv2 domain. 128 */ 129 group = iommu_group_get(&dev_state->pdev->dev); 130 if (WARN_ON(!group)) 131 return; 132 133 iommu_detach_group(dev_state->domain, group); 134 135 iommu_group_put(group); 136 137 /* Everything is down now, free the IOMMUv2 domain */ 138 iommu_domain_free(dev_state->domain); 139 140 /* Finally get rid of the device-state */ 141 kfree(dev_state); 142 } 143 144 static void put_device_state(struct device_state *dev_state) 145 { 146 if (atomic_dec_and_test(&dev_state->count)) 147 wake_up(&dev_state->wq); 148 } 149 150 /* Must be called under dev_state->lock */ 151 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state, 152 u32 pasid, bool alloc) 153 { 154 struct pasid_state **root, **ptr; 155 int level, index; 156 157 level = dev_state->pasid_levels; 158 root = dev_state->states; 159 160 while (true) { 161 162 index = (pasid >> (9 * level)) & 0x1ff; 163 ptr = &root[index]; 164 165 if (level == 0) 166 break; 167 168 if (*ptr == NULL) { 169 if (!alloc) 170 return NULL; 171 172 *ptr = (void *)get_zeroed_page(GFP_ATOMIC); 173 if (*ptr == NULL) 174 return NULL; 175 } 176 177 root = (struct pasid_state **)*ptr; 178 level -= 1; 179 } 180 181 return ptr; 182 } 183 184 static int set_pasid_state(struct device_state *dev_state, 185 struct pasid_state *pasid_state, 186 u32 pasid) 187 { 188 struct pasid_state **ptr; 189 unsigned long flags; 190 int ret; 191 192 spin_lock_irqsave(&dev_state->lock, flags); 193 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 194 195 ret = -ENOMEM; 196 if (ptr == NULL) 197 goto out_unlock; 198 199 ret = -ENOMEM; 200 if (*ptr != NULL) 201 goto out_unlock; 202 203 *ptr = pasid_state; 204 205 ret = 0; 206 207 out_unlock: 208 spin_unlock_irqrestore(&dev_state->lock, flags); 209 210 return ret; 211 } 212 213 static void clear_pasid_state(struct device_state *dev_state, u32 pasid) 214 { 215 struct pasid_state **ptr; 216 unsigned long flags; 217 218 spin_lock_irqsave(&dev_state->lock, flags); 219 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 220 221 if (ptr == NULL) 222 goto out_unlock; 223 224 *ptr = NULL; 225 226 out_unlock: 227 spin_unlock_irqrestore(&dev_state->lock, flags); 228 } 229 230 static struct pasid_state *get_pasid_state(struct device_state *dev_state, 231 u32 pasid) 232 { 233 struct pasid_state **ptr, *ret = NULL; 234 unsigned long flags; 235 236 spin_lock_irqsave(&dev_state->lock, flags); 237 ptr = __get_pasid_state_ptr(dev_state, pasid, false); 238 239 if (ptr == NULL) 240 goto out_unlock; 241 242 ret = *ptr; 243 if (ret) 244 refcount_inc(&ret->count); 245 246 out_unlock: 247 spin_unlock_irqrestore(&dev_state->lock, flags); 248 249 return ret; 250 } 251 252 static void free_pasid_state(struct pasid_state *pasid_state) 253 { 254 kfree(pasid_state); 255 } 256 257 static void put_pasid_state(struct pasid_state *pasid_state) 258 { 259 if (refcount_dec_and_test(&pasid_state->count)) 260 wake_up(&pasid_state->wq); 261 } 262 263 static void put_pasid_state_wait(struct pasid_state *pasid_state) 264 { 265 refcount_dec(&pasid_state->count); 266 wait_event(pasid_state->wq, !refcount_read(&pasid_state->count)); 267 free_pasid_state(pasid_state); 268 } 269 270 static void unbind_pasid(struct pasid_state *pasid_state) 271 { 272 struct iommu_domain *domain; 273 274 domain = pasid_state->device_state->domain; 275 276 /* 277 * Mark pasid_state as invalid, no more faults will we added to the 278 * work queue after this is visible everywhere. 279 */ 280 pasid_state->invalid = true; 281 282 /* Make sure this is visible */ 283 smp_wmb(); 284 285 /* After this the device/pasid can't access the mm anymore */ 286 amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid); 287 288 /* Make sure no more pending faults are in the queue */ 289 flush_workqueue(iommu_wq); 290 } 291 292 static void free_pasid_states_level1(struct pasid_state **tbl) 293 { 294 int i; 295 296 for (i = 0; i < 512; ++i) { 297 if (tbl[i] == NULL) 298 continue; 299 300 free_page((unsigned long)tbl[i]); 301 } 302 } 303 304 static void free_pasid_states_level2(struct pasid_state **tbl) 305 { 306 struct pasid_state **ptr; 307 int i; 308 309 for (i = 0; i < 512; ++i) { 310 if (tbl[i] == NULL) 311 continue; 312 313 ptr = (struct pasid_state **)tbl[i]; 314 free_pasid_states_level1(ptr); 315 } 316 } 317 318 static void free_pasid_states(struct device_state *dev_state) 319 { 320 struct pasid_state *pasid_state; 321 int i; 322 323 for (i = 0; i < dev_state->max_pasids; ++i) { 324 pasid_state = get_pasid_state(dev_state, i); 325 if (pasid_state == NULL) 326 continue; 327 328 put_pasid_state(pasid_state); 329 330 /* 331 * This will call the mn_release function and 332 * unbind the PASID 333 */ 334 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 335 336 put_pasid_state_wait(pasid_state); /* Reference taken in 337 amd_iommu_bind_pasid */ 338 339 /* Drop reference taken in amd_iommu_bind_pasid */ 340 put_device_state(dev_state); 341 } 342 343 if (dev_state->pasid_levels == 2) 344 free_pasid_states_level2(dev_state->states); 345 else if (dev_state->pasid_levels == 1) 346 free_pasid_states_level1(dev_state->states); 347 else 348 BUG_ON(dev_state->pasid_levels != 0); 349 350 free_page((unsigned long)dev_state->states); 351 } 352 353 static struct pasid_state *mn_to_state(struct mmu_notifier *mn) 354 { 355 return container_of(mn, struct pasid_state, mn); 356 } 357 358 static void mn_invalidate_range(struct mmu_notifier *mn, 359 struct mm_struct *mm, 360 unsigned long start, unsigned long end) 361 { 362 struct pasid_state *pasid_state; 363 struct device_state *dev_state; 364 365 pasid_state = mn_to_state(mn); 366 dev_state = pasid_state->device_state; 367 368 if ((start ^ (end - 1)) < PAGE_SIZE) 369 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, 370 start); 371 else 372 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid); 373 } 374 375 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm) 376 { 377 struct pasid_state *pasid_state; 378 struct device_state *dev_state; 379 bool run_inv_ctx_cb; 380 381 might_sleep(); 382 383 pasid_state = mn_to_state(mn); 384 dev_state = pasid_state->device_state; 385 run_inv_ctx_cb = !pasid_state->invalid; 386 387 if (run_inv_ctx_cb && dev_state->inv_ctx_cb) 388 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid); 389 390 unbind_pasid(pasid_state); 391 } 392 393 static const struct mmu_notifier_ops iommu_mn = { 394 .release = mn_release, 395 .invalidate_range = mn_invalidate_range, 396 }; 397 398 static void set_pri_tag_status(struct pasid_state *pasid_state, 399 u16 tag, int status) 400 { 401 unsigned long flags; 402 403 spin_lock_irqsave(&pasid_state->lock, flags); 404 pasid_state->pri[tag].status = status; 405 spin_unlock_irqrestore(&pasid_state->lock, flags); 406 } 407 408 static void finish_pri_tag(struct device_state *dev_state, 409 struct pasid_state *pasid_state, 410 u16 tag) 411 { 412 unsigned long flags; 413 414 spin_lock_irqsave(&pasid_state->lock, flags); 415 if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) && 416 pasid_state->pri[tag].finish) { 417 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid, 418 pasid_state->pri[tag].status, tag); 419 pasid_state->pri[tag].finish = false; 420 pasid_state->pri[tag].status = PPR_SUCCESS; 421 } 422 spin_unlock_irqrestore(&pasid_state->lock, flags); 423 } 424 425 static void handle_fault_error(struct fault *fault) 426 { 427 int status; 428 429 if (!fault->dev_state->inv_ppr_cb) { 430 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 431 return; 432 } 433 434 status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, 435 fault->pasid, 436 fault->address, 437 fault->flags); 438 switch (status) { 439 case AMD_IOMMU_INV_PRI_RSP_SUCCESS: 440 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); 441 break; 442 case AMD_IOMMU_INV_PRI_RSP_INVALID: 443 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 444 break; 445 case AMD_IOMMU_INV_PRI_RSP_FAIL: 446 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); 447 break; 448 default: 449 BUG(); 450 } 451 } 452 453 static bool access_error(struct vm_area_struct *vma, struct fault *fault) 454 { 455 unsigned long requested = 0; 456 457 if (fault->flags & PPR_FAULT_EXEC) 458 requested |= VM_EXEC; 459 460 if (fault->flags & PPR_FAULT_READ) 461 requested |= VM_READ; 462 463 if (fault->flags & PPR_FAULT_WRITE) 464 requested |= VM_WRITE; 465 466 return (requested & ~vma->vm_flags) != 0; 467 } 468 469 static void do_fault(struct work_struct *work) 470 { 471 struct fault *fault = container_of(work, struct fault, work); 472 struct vm_area_struct *vma; 473 vm_fault_t ret = VM_FAULT_ERROR; 474 unsigned int flags = 0; 475 struct mm_struct *mm; 476 u64 address; 477 478 mm = fault->state->mm; 479 address = fault->address; 480 481 if (fault->flags & PPR_FAULT_USER) 482 flags |= FAULT_FLAG_USER; 483 if (fault->flags & PPR_FAULT_WRITE) 484 flags |= FAULT_FLAG_WRITE; 485 flags |= FAULT_FLAG_REMOTE; 486 487 mmap_read_lock(mm); 488 vma = find_extend_vma(mm, address); 489 if (!vma || address < vma->vm_start) 490 /* failed to get a vma in the right range */ 491 goto out; 492 493 /* Check if we have the right permissions on the vma */ 494 if (access_error(vma, fault)) 495 goto out; 496 497 ret = handle_mm_fault(vma, address, flags, NULL); 498 out: 499 mmap_read_unlock(mm); 500 501 if (ret & VM_FAULT_ERROR) 502 /* failed to service fault */ 503 handle_fault_error(fault); 504 505 finish_pri_tag(fault->dev_state, fault->state, fault->tag); 506 507 put_pasid_state(fault->state); 508 509 kfree(fault); 510 } 511 512 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) 513 { 514 struct amd_iommu_fault *iommu_fault; 515 struct pasid_state *pasid_state; 516 struct device_state *dev_state; 517 struct pci_dev *pdev = NULL; 518 unsigned long flags; 519 struct fault *fault; 520 bool finish; 521 u16 tag, devid, seg_id; 522 int ret; 523 524 iommu_fault = data; 525 tag = iommu_fault->tag & 0x1ff; 526 finish = (iommu_fault->tag >> 9) & 1; 527 528 seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf); 529 devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf); 530 pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid), 531 devid & 0xff); 532 if (!pdev) 533 return -ENODEV; 534 535 ret = NOTIFY_DONE; 536 537 /* In kdump kernel pci dev is not initialized yet -> send INVALID */ 538 if (amd_iommu_is_attach_deferred(&pdev->dev)) { 539 amd_iommu_complete_ppr(pdev, iommu_fault->pasid, 540 PPR_INVALID, tag); 541 goto out; 542 } 543 544 dev_state = get_device_state(iommu_fault->sbdf); 545 if (dev_state == NULL) 546 goto out; 547 548 pasid_state = get_pasid_state(dev_state, iommu_fault->pasid); 549 if (pasid_state == NULL || pasid_state->invalid) { 550 /* We know the device but not the PASID -> send INVALID */ 551 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid, 552 PPR_INVALID, tag); 553 goto out_drop_state; 554 } 555 556 spin_lock_irqsave(&pasid_state->lock, flags); 557 atomic_inc(&pasid_state->pri[tag].inflight); 558 if (finish) 559 pasid_state->pri[tag].finish = true; 560 spin_unlock_irqrestore(&pasid_state->lock, flags); 561 562 fault = kzalloc(sizeof(*fault), GFP_ATOMIC); 563 if (fault == NULL) { 564 /* We are OOM - send success and let the device re-fault */ 565 finish_pri_tag(dev_state, pasid_state, tag); 566 goto out_drop_state; 567 } 568 569 fault->dev_state = dev_state; 570 fault->address = iommu_fault->address; 571 fault->state = pasid_state; 572 fault->tag = tag; 573 fault->finish = finish; 574 fault->pasid = iommu_fault->pasid; 575 fault->flags = iommu_fault->flags; 576 INIT_WORK(&fault->work, do_fault); 577 578 queue_work(iommu_wq, &fault->work); 579 580 ret = NOTIFY_OK; 581 582 out_drop_state: 583 584 if (ret != NOTIFY_OK && pasid_state) 585 put_pasid_state(pasid_state); 586 587 put_device_state(dev_state); 588 589 out: 590 return ret; 591 } 592 593 static struct notifier_block ppr_nb = { 594 .notifier_call = ppr_notifier, 595 }; 596 597 int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid, 598 struct task_struct *task) 599 { 600 struct pasid_state *pasid_state; 601 struct device_state *dev_state; 602 struct mm_struct *mm; 603 u32 sbdf; 604 int ret; 605 606 might_sleep(); 607 608 if (!amd_iommu_v2_supported()) 609 return -ENODEV; 610 611 sbdf = get_pci_sbdf_id(pdev); 612 dev_state = get_device_state(sbdf); 613 614 if (dev_state == NULL) 615 return -EINVAL; 616 617 ret = -EINVAL; 618 if (pasid >= dev_state->max_pasids) 619 goto out; 620 621 ret = -ENOMEM; 622 pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL); 623 if (pasid_state == NULL) 624 goto out; 625 626 627 refcount_set(&pasid_state->count, 1); 628 init_waitqueue_head(&pasid_state->wq); 629 spin_lock_init(&pasid_state->lock); 630 631 mm = get_task_mm(task); 632 pasid_state->mm = mm; 633 pasid_state->device_state = dev_state; 634 pasid_state->pasid = pasid; 635 pasid_state->invalid = true; /* Mark as valid only if we are 636 done with setting up the pasid */ 637 pasid_state->mn.ops = &iommu_mn; 638 639 if (pasid_state->mm == NULL) 640 goto out_free; 641 642 mmu_notifier_register(&pasid_state->mn, mm); 643 644 ret = set_pasid_state(dev_state, pasid_state, pasid); 645 if (ret) 646 goto out_unregister; 647 648 ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid, 649 __pa(pasid_state->mm->pgd)); 650 if (ret) 651 goto out_clear_state; 652 653 /* Now we are ready to handle faults */ 654 pasid_state->invalid = false; 655 656 /* 657 * Drop the reference to the mm_struct here. We rely on the 658 * mmu_notifier release call-back to inform us when the mm 659 * is going away. 660 */ 661 mmput(mm); 662 663 return 0; 664 665 out_clear_state: 666 clear_pasid_state(dev_state, pasid); 667 668 out_unregister: 669 mmu_notifier_unregister(&pasid_state->mn, mm); 670 mmput(mm); 671 672 out_free: 673 free_pasid_state(pasid_state); 674 675 out: 676 put_device_state(dev_state); 677 678 return ret; 679 } 680 EXPORT_SYMBOL(amd_iommu_bind_pasid); 681 682 void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid) 683 { 684 struct pasid_state *pasid_state; 685 struct device_state *dev_state; 686 u32 sbdf; 687 688 might_sleep(); 689 690 if (!amd_iommu_v2_supported()) 691 return; 692 693 sbdf = get_pci_sbdf_id(pdev); 694 dev_state = get_device_state(sbdf); 695 if (dev_state == NULL) 696 return; 697 698 if (pasid >= dev_state->max_pasids) 699 goto out; 700 701 pasid_state = get_pasid_state(dev_state, pasid); 702 if (pasid_state == NULL) 703 goto out; 704 /* 705 * Drop reference taken here. We are safe because we still hold 706 * the reference taken in the amd_iommu_bind_pasid function. 707 */ 708 put_pasid_state(pasid_state); 709 710 /* Clear the pasid state so that the pasid can be re-used */ 711 clear_pasid_state(dev_state, pasid_state->pasid); 712 713 /* 714 * Call mmu_notifier_unregister to drop our reference 715 * to pasid_state->mm 716 */ 717 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 718 719 put_pasid_state_wait(pasid_state); /* Reference taken in 720 amd_iommu_bind_pasid */ 721 out: 722 /* Drop reference taken in this function */ 723 put_device_state(dev_state); 724 725 /* Drop reference taken in amd_iommu_bind_pasid */ 726 put_device_state(dev_state); 727 } 728 EXPORT_SYMBOL(amd_iommu_unbind_pasid); 729 730 int amd_iommu_init_device(struct pci_dev *pdev, int pasids) 731 { 732 struct device_state *dev_state; 733 struct iommu_group *group; 734 unsigned long flags; 735 int ret, tmp; 736 u32 sbdf; 737 738 might_sleep(); 739 740 /* 741 * When memory encryption is active the device is likely not in a 742 * direct-mapped domain. Forbid using IOMMUv2 functionality for now. 743 */ 744 if (cc_platform_has(CC_ATTR_MEM_ENCRYPT)) 745 return -ENODEV; 746 747 if (!amd_iommu_v2_supported()) 748 return -ENODEV; 749 750 if (pasids <= 0 || pasids > (PASID_MASK + 1)) 751 return -EINVAL; 752 753 sbdf = get_pci_sbdf_id(pdev); 754 755 dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL); 756 if (dev_state == NULL) 757 return -ENOMEM; 758 759 spin_lock_init(&dev_state->lock); 760 init_waitqueue_head(&dev_state->wq); 761 dev_state->pdev = pdev; 762 dev_state->sbdf = sbdf; 763 764 tmp = pasids; 765 for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9) 766 dev_state->pasid_levels += 1; 767 768 atomic_set(&dev_state->count, 1); 769 dev_state->max_pasids = pasids; 770 771 ret = -ENOMEM; 772 dev_state->states = (void *)get_zeroed_page(GFP_KERNEL); 773 if (dev_state->states == NULL) 774 goto out_free_dev_state; 775 776 dev_state->domain = iommu_domain_alloc(&pci_bus_type); 777 if (dev_state->domain == NULL) 778 goto out_free_states; 779 780 /* See iommu_is_default_domain() */ 781 dev_state->domain->type = IOMMU_DOMAIN_IDENTITY; 782 amd_iommu_domain_direct_map(dev_state->domain); 783 784 ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids); 785 if (ret) 786 goto out_free_domain; 787 788 group = iommu_group_get(&pdev->dev); 789 if (!group) { 790 ret = -EINVAL; 791 goto out_free_domain; 792 } 793 794 ret = iommu_attach_group(dev_state->domain, group); 795 if (ret != 0) 796 goto out_drop_group; 797 798 iommu_group_put(group); 799 800 spin_lock_irqsave(&state_lock, flags); 801 802 if (__get_device_state(sbdf) != NULL) { 803 spin_unlock_irqrestore(&state_lock, flags); 804 ret = -EBUSY; 805 goto out_free_domain; 806 } 807 808 list_add_tail(&dev_state->list, &state_list); 809 810 spin_unlock_irqrestore(&state_lock, flags); 811 812 return 0; 813 814 out_drop_group: 815 iommu_group_put(group); 816 817 out_free_domain: 818 iommu_domain_free(dev_state->domain); 819 820 out_free_states: 821 free_page((unsigned long)dev_state->states); 822 823 out_free_dev_state: 824 kfree(dev_state); 825 826 return ret; 827 } 828 EXPORT_SYMBOL(amd_iommu_init_device); 829 830 void amd_iommu_free_device(struct pci_dev *pdev) 831 { 832 struct device_state *dev_state; 833 unsigned long flags; 834 u32 sbdf; 835 836 if (!amd_iommu_v2_supported()) 837 return; 838 839 sbdf = get_pci_sbdf_id(pdev); 840 841 spin_lock_irqsave(&state_lock, flags); 842 843 dev_state = __get_device_state(sbdf); 844 if (dev_state == NULL) { 845 spin_unlock_irqrestore(&state_lock, flags); 846 return; 847 } 848 849 list_del(&dev_state->list); 850 851 spin_unlock_irqrestore(&state_lock, flags); 852 853 put_device_state(dev_state); 854 free_device_state(dev_state); 855 } 856 EXPORT_SYMBOL(amd_iommu_free_device); 857 858 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev, 859 amd_iommu_invalid_ppr_cb cb) 860 { 861 struct device_state *dev_state; 862 unsigned long flags; 863 u32 sbdf; 864 int ret; 865 866 if (!amd_iommu_v2_supported()) 867 return -ENODEV; 868 869 sbdf = get_pci_sbdf_id(pdev); 870 871 spin_lock_irqsave(&state_lock, flags); 872 873 ret = -EINVAL; 874 dev_state = __get_device_state(sbdf); 875 if (dev_state == NULL) 876 goto out_unlock; 877 878 dev_state->inv_ppr_cb = cb; 879 880 ret = 0; 881 882 out_unlock: 883 spin_unlock_irqrestore(&state_lock, flags); 884 885 return ret; 886 } 887 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb); 888 889 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev, 890 amd_iommu_invalidate_ctx cb) 891 { 892 struct device_state *dev_state; 893 unsigned long flags; 894 u32 sbdf; 895 int ret; 896 897 if (!amd_iommu_v2_supported()) 898 return -ENODEV; 899 900 sbdf = get_pci_sbdf_id(pdev); 901 902 spin_lock_irqsave(&state_lock, flags); 903 904 ret = -EINVAL; 905 dev_state = __get_device_state(sbdf); 906 if (dev_state == NULL) 907 goto out_unlock; 908 909 dev_state->inv_ctx_cb = cb; 910 911 ret = 0; 912 913 out_unlock: 914 spin_unlock_irqrestore(&state_lock, flags); 915 916 return ret; 917 } 918 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb); 919 920 static int __init amd_iommu_v2_init(void) 921 { 922 int ret; 923 924 if (!amd_iommu_v2_supported()) { 925 pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n"); 926 /* 927 * Load anyway to provide the symbols to other modules 928 * which may use AMD IOMMUv2 optionally. 929 */ 930 return 0; 931 } 932 933 ret = -ENOMEM; 934 iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0); 935 if (iommu_wq == NULL) 936 goto out; 937 938 amd_iommu_register_ppr_notifier(&ppr_nb); 939 940 pr_info("AMD IOMMUv2 loaded and initialized\n"); 941 942 return 0; 943 944 out: 945 return ret; 946 } 947 948 static void __exit amd_iommu_v2_exit(void) 949 { 950 struct device_state *dev_state, *next; 951 unsigned long flags; 952 LIST_HEAD(freelist); 953 954 if (!amd_iommu_v2_supported()) 955 return; 956 957 amd_iommu_unregister_ppr_notifier(&ppr_nb); 958 959 flush_workqueue(iommu_wq); 960 961 /* 962 * The loop below might call flush_workqueue(), so call 963 * destroy_workqueue() after it 964 */ 965 spin_lock_irqsave(&state_lock, flags); 966 967 list_for_each_entry_safe(dev_state, next, &state_list, list) { 968 WARN_ON_ONCE(1); 969 970 put_device_state(dev_state); 971 list_del(&dev_state->list); 972 list_add_tail(&dev_state->list, &freelist); 973 } 974 975 spin_unlock_irqrestore(&state_lock, flags); 976 977 /* 978 * Since free_device_state waits on the count to be zero, 979 * we need to free dev_state outside the spinlock. 980 */ 981 list_for_each_entry_safe(dev_state, next, &freelist, list) { 982 list_del(&dev_state->list); 983 free_device_state(dev_state); 984 } 985 986 destroy_workqueue(iommu_wq); 987 } 988 989 module_init(amd_iommu_v2_init); 990 module_exit(amd_iommu_v2_exit); 991