1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc. 4 * Author: Joerg Roedel <jroedel@suse.de> 5 */ 6 7 #define pr_fmt(fmt) "AMD-Vi: " fmt 8 9 #include <linux/refcount.h> 10 #include <linux/mmu_notifier.h> 11 #include <linux/amd-iommu.h> 12 #include <linux/mm_types.h> 13 #include <linux/profile.h> 14 #include <linux/module.h> 15 #include <linux/sched.h> 16 #include <linux/sched/mm.h> 17 #include <linux/wait.h> 18 #include <linux/pci.h> 19 #include <linux/gfp.h> 20 #include <linux/cc_platform.h> 21 22 #include "amd_iommu.h" 23 24 MODULE_LICENSE("GPL v2"); 25 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>"); 26 27 #define PRI_QUEUE_SIZE 512 28 29 struct pri_queue { 30 atomic_t inflight; 31 bool finish; 32 int status; 33 }; 34 35 struct pasid_state { 36 struct list_head list; /* For global state-list */ 37 refcount_t count; /* Reference count */ 38 unsigned mmu_notifier_count; /* Counting nested mmu_notifier 39 calls */ 40 struct mm_struct *mm; /* mm_struct for the faults */ 41 struct mmu_notifier mn; /* mmu_notifier handle */ 42 struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */ 43 struct device_state *device_state; /* Link to our device_state */ 44 u32 pasid; /* PASID index */ 45 bool invalid; /* Used during setup and 46 teardown of the pasid */ 47 spinlock_t lock; /* Protect pri_queues and 48 mmu_notifer_count */ 49 wait_queue_head_t wq; /* To wait for count == 0 */ 50 }; 51 52 struct device_state { 53 struct list_head list; 54 u32 sbdf; 55 atomic_t count; 56 struct pci_dev *pdev; 57 struct pasid_state **states; 58 struct iommu_domain *domain; 59 int pasid_levels; 60 int max_pasids; 61 amd_iommu_invalid_ppr_cb inv_ppr_cb; 62 amd_iommu_invalidate_ctx inv_ctx_cb; 63 spinlock_t lock; 64 wait_queue_head_t wq; 65 }; 66 67 struct fault { 68 struct work_struct work; 69 struct device_state *dev_state; 70 struct pasid_state *state; 71 struct mm_struct *mm; 72 u64 address; 73 u32 pasid; 74 u16 tag; 75 u16 finish; 76 u16 flags; 77 }; 78 79 static LIST_HEAD(state_list); 80 static DEFINE_SPINLOCK(state_lock); 81 82 static struct workqueue_struct *iommu_wq; 83 84 static void free_pasid_states(struct device_state *dev_state); 85 86 static struct device_state *__get_device_state(u32 sbdf) 87 { 88 struct device_state *dev_state; 89 90 list_for_each_entry(dev_state, &state_list, list) { 91 if (dev_state->sbdf == sbdf) 92 return dev_state; 93 } 94 95 return NULL; 96 } 97 98 static struct device_state *get_device_state(u32 sbdf) 99 { 100 struct device_state *dev_state; 101 unsigned long flags; 102 103 spin_lock_irqsave(&state_lock, flags); 104 dev_state = __get_device_state(sbdf); 105 if (dev_state != NULL) 106 atomic_inc(&dev_state->count); 107 spin_unlock_irqrestore(&state_lock, flags); 108 109 return dev_state; 110 } 111 112 static void free_device_state(struct device_state *dev_state) 113 { 114 struct iommu_group *group; 115 116 /* Get rid of any remaining pasid states */ 117 free_pasid_states(dev_state); 118 119 /* 120 * Wait until the last reference is dropped before freeing 121 * the device state. 122 */ 123 wait_event(dev_state->wq, !atomic_read(&dev_state->count)); 124 125 /* 126 * First detach device from domain - No more PRI requests will arrive 127 * from that device after it is unbound from the IOMMUv2 domain. 128 */ 129 group = iommu_group_get(&dev_state->pdev->dev); 130 if (WARN_ON(!group)) 131 return; 132 133 iommu_detach_group(dev_state->domain, group); 134 135 iommu_group_put(group); 136 137 /* Everything is down now, free the IOMMUv2 domain */ 138 iommu_domain_free(dev_state->domain); 139 140 /* Finally get rid of the device-state */ 141 kfree(dev_state); 142 } 143 144 static void put_device_state(struct device_state *dev_state) 145 { 146 if (atomic_dec_and_test(&dev_state->count)) 147 wake_up(&dev_state->wq); 148 } 149 150 /* Must be called under dev_state->lock */ 151 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state, 152 u32 pasid, bool alloc) 153 { 154 struct pasid_state **root, **ptr; 155 int level, index; 156 157 level = dev_state->pasid_levels; 158 root = dev_state->states; 159 160 while (true) { 161 162 index = (pasid >> (9 * level)) & 0x1ff; 163 ptr = &root[index]; 164 165 if (level == 0) 166 break; 167 168 if (*ptr == NULL) { 169 if (!alloc) 170 return NULL; 171 172 *ptr = (void *)get_zeroed_page(GFP_ATOMIC); 173 if (*ptr == NULL) 174 return NULL; 175 } 176 177 root = (struct pasid_state **)*ptr; 178 level -= 1; 179 } 180 181 return ptr; 182 } 183 184 static int set_pasid_state(struct device_state *dev_state, 185 struct pasid_state *pasid_state, 186 u32 pasid) 187 { 188 struct pasid_state **ptr; 189 unsigned long flags; 190 int ret; 191 192 spin_lock_irqsave(&dev_state->lock, flags); 193 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 194 195 ret = -ENOMEM; 196 if (ptr == NULL) 197 goto out_unlock; 198 199 ret = -ENOMEM; 200 if (*ptr != NULL) 201 goto out_unlock; 202 203 *ptr = pasid_state; 204 205 ret = 0; 206 207 out_unlock: 208 spin_unlock_irqrestore(&dev_state->lock, flags); 209 210 return ret; 211 } 212 213 static void clear_pasid_state(struct device_state *dev_state, u32 pasid) 214 { 215 struct pasid_state **ptr; 216 unsigned long flags; 217 218 spin_lock_irqsave(&dev_state->lock, flags); 219 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 220 221 if (ptr == NULL) 222 goto out_unlock; 223 224 *ptr = NULL; 225 226 out_unlock: 227 spin_unlock_irqrestore(&dev_state->lock, flags); 228 } 229 230 static struct pasid_state *get_pasid_state(struct device_state *dev_state, 231 u32 pasid) 232 { 233 struct pasid_state **ptr, *ret = NULL; 234 unsigned long flags; 235 236 spin_lock_irqsave(&dev_state->lock, flags); 237 ptr = __get_pasid_state_ptr(dev_state, pasid, false); 238 239 if (ptr == NULL) 240 goto out_unlock; 241 242 ret = *ptr; 243 if (ret) 244 refcount_inc(&ret->count); 245 246 out_unlock: 247 spin_unlock_irqrestore(&dev_state->lock, flags); 248 249 return ret; 250 } 251 252 static void free_pasid_state(struct pasid_state *pasid_state) 253 { 254 kfree(pasid_state); 255 } 256 257 static void put_pasid_state(struct pasid_state *pasid_state) 258 { 259 if (refcount_dec_and_test(&pasid_state->count)) 260 wake_up(&pasid_state->wq); 261 } 262 263 static void put_pasid_state_wait(struct pasid_state *pasid_state) 264 { 265 refcount_dec(&pasid_state->count); 266 wait_event(pasid_state->wq, !refcount_read(&pasid_state->count)); 267 free_pasid_state(pasid_state); 268 } 269 270 static void unbind_pasid(struct pasid_state *pasid_state) 271 { 272 struct iommu_domain *domain; 273 274 domain = pasid_state->device_state->domain; 275 276 /* 277 * Mark pasid_state as invalid, no more faults will we added to the 278 * work queue after this is visible everywhere. 279 */ 280 pasid_state->invalid = true; 281 282 /* Make sure this is visible */ 283 smp_wmb(); 284 285 /* After this the device/pasid can't access the mm anymore */ 286 amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid); 287 288 /* Make sure no more pending faults are in the queue */ 289 flush_workqueue(iommu_wq); 290 } 291 292 static void free_pasid_states_level1(struct pasid_state **tbl) 293 { 294 int i; 295 296 for (i = 0; i < 512; ++i) { 297 if (tbl[i] == NULL) 298 continue; 299 300 free_page((unsigned long)tbl[i]); 301 } 302 } 303 304 static void free_pasid_states_level2(struct pasid_state **tbl) 305 { 306 struct pasid_state **ptr; 307 int i; 308 309 for (i = 0; i < 512; ++i) { 310 if (tbl[i] == NULL) 311 continue; 312 313 ptr = (struct pasid_state **)tbl[i]; 314 free_pasid_states_level1(ptr); 315 } 316 } 317 318 static void free_pasid_states(struct device_state *dev_state) 319 { 320 struct pasid_state *pasid_state; 321 int i; 322 323 for (i = 0; i < dev_state->max_pasids; ++i) { 324 pasid_state = get_pasid_state(dev_state, i); 325 if (pasid_state == NULL) 326 continue; 327 328 put_pasid_state(pasid_state); 329 330 /* 331 * This will call the mn_release function and 332 * unbind the PASID 333 */ 334 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 335 336 put_pasid_state_wait(pasid_state); /* Reference taken in 337 amd_iommu_bind_pasid */ 338 339 /* Drop reference taken in amd_iommu_bind_pasid */ 340 put_device_state(dev_state); 341 } 342 343 if (dev_state->pasid_levels == 2) 344 free_pasid_states_level2(dev_state->states); 345 else if (dev_state->pasid_levels == 1) 346 free_pasid_states_level1(dev_state->states); 347 else 348 BUG_ON(dev_state->pasid_levels != 0); 349 350 free_page((unsigned long)dev_state->states); 351 } 352 353 static struct pasid_state *mn_to_state(struct mmu_notifier *mn) 354 { 355 return container_of(mn, struct pasid_state, mn); 356 } 357 358 static void mn_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn, 359 struct mm_struct *mm, 360 unsigned long start, unsigned long end) 361 { 362 struct pasid_state *pasid_state; 363 struct device_state *dev_state; 364 365 pasid_state = mn_to_state(mn); 366 dev_state = pasid_state->device_state; 367 368 if ((start ^ (end - 1)) < PAGE_SIZE) 369 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, 370 start); 371 else 372 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid); 373 } 374 375 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm) 376 { 377 struct pasid_state *pasid_state; 378 struct device_state *dev_state; 379 bool run_inv_ctx_cb; 380 381 might_sleep(); 382 383 pasid_state = mn_to_state(mn); 384 dev_state = pasid_state->device_state; 385 run_inv_ctx_cb = !pasid_state->invalid; 386 387 if (run_inv_ctx_cb && dev_state->inv_ctx_cb) 388 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid); 389 390 unbind_pasid(pasid_state); 391 } 392 393 static const struct mmu_notifier_ops iommu_mn = { 394 .release = mn_release, 395 .arch_invalidate_secondary_tlbs = mn_arch_invalidate_secondary_tlbs, 396 }; 397 398 static void set_pri_tag_status(struct pasid_state *pasid_state, 399 u16 tag, int status) 400 { 401 unsigned long flags; 402 403 spin_lock_irqsave(&pasid_state->lock, flags); 404 pasid_state->pri[tag].status = status; 405 spin_unlock_irqrestore(&pasid_state->lock, flags); 406 } 407 408 static void finish_pri_tag(struct device_state *dev_state, 409 struct pasid_state *pasid_state, 410 u16 tag) 411 { 412 unsigned long flags; 413 414 spin_lock_irqsave(&pasid_state->lock, flags); 415 if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) && 416 pasid_state->pri[tag].finish) { 417 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid, 418 pasid_state->pri[tag].status, tag); 419 pasid_state->pri[tag].finish = false; 420 pasid_state->pri[tag].status = PPR_SUCCESS; 421 } 422 spin_unlock_irqrestore(&pasid_state->lock, flags); 423 } 424 425 static void handle_fault_error(struct fault *fault) 426 { 427 int status; 428 429 if (!fault->dev_state->inv_ppr_cb) { 430 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 431 return; 432 } 433 434 status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, 435 fault->pasid, 436 fault->address, 437 fault->flags); 438 switch (status) { 439 case AMD_IOMMU_INV_PRI_RSP_SUCCESS: 440 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); 441 break; 442 case AMD_IOMMU_INV_PRI_RSP_INVALID: 443 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 444 break; 445 case AMD_IOMMU_INV_PRI_RSP_FAIL: 446 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); 447 break; 448 default: 449 BUG(); 450 } 451 } 452 453 static bool access_error(struct vm_area_struct *vma, struct fault *fault) 454 { 455 unsigned long requested = 0; 456 457 if (fault->flags & PPR_FAULT_EXEC) 458 requested |= VM_EXEC; 459 460 if (fault->flags & PPR_FAULT_READ) 461 requested |= VM_READ; 462 463 if (fault->flags & PPR_FAULT_WRITE) 464 requested |= VM_WRITE; 465 466 return (requested & ~vma->vm_flags) != 0; 467 } 468 469 static void do_fault(struct work_struct *work) 470 { 471 struct fault *fault = container_of(work, struct fault, work); 472 struct vm_area_struct *vma; 473 vm_fault_t ret = VM_FAULT_ERROR; 474 unsigned int flags = 0; 475 struct mm_struct *mm; 476 u64 address; 477 478 mm = fault->state->mm; 479 address = fault->address; 480 481 if (fault->flags & PPR_FAULT_USER) 482 flags |= FAULT_FLAG_USER; 483 if (fault->flags & PPR_FAULT_WRITE) 484 flags |= FAULT_FLAG_WRITE; 485 flags |= FAULT_FLAG_REMOTE; 486 487 mmap_read_lock(mm); 488 vma = vma_lookup(mm, address); 489 if (!vma) 490 /* failed to get a vma in the right range */ 491 goto out; 492 493 /* Check if we have the right permissions on the vma */ 494 if (access_error(vma, fault)) 495 goto out; 496 497 ret = handle_mm_fault(vma, address, flags, NULL); 498 out: 499 mmap_read_unlock(mm); 500 501 if (ret & VM_FAULT_ERROR) 502 /* failed to service fault */ 503 handle_fault_error(fault); 504 505 finish_pri_tag(fault->dev_state, fault->state, fault->tag); 506 507 put_pasid_state(fault->state); 508 509 kfree(fault); 510 } 511 512 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) 513 { 514 struct amd_iommu_fault *iommu_fault; 515 struct pasid_state *pasid_state; 516 struct device_state *dev_state; 517 struct pci_dev *pdev = NULL; 518 unsigned long flags; 519 struct fault *fault; 520 bool finish; 521 u16 tag, devid, seg_id; 522 int ret; 523 524 iommu_fault = data; 525 tag = iommu_fault->tag & 0x1ff; 526 finish = (iommu_fault->tag >> 9) & 1; 527 528 seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf); 529 devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf); 530 pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid), 531 devid & 0xff); 532 if (!pdev) 533 return -ENODEV; 534 535 ret = NOTIFY_DONE; 536 537 /* In kdump kernel pci dev is not initialized yet -> send INVALID */ 538 if (amd_iommu_is_attach_deferred(&pdev->dev)) { 539 amd_iommu_complete_ppr(pdev, iommu_fault->pasid, 540 PPR_INVALID, tag); 541 goto out; 542 } 543 544 dev_state = get_device_state(iommu_fault->sbdf); 545 if (dev_state == NULL) 546 goto out; 547 548 pasid_state = get_pasid_state(dev_state, iommu_fault->pasid); 549 if (pasid_state == NULL || pasid_state->invalid) { 550 /* We know the device but not the PASID -> send INVALID */ 551 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid, 552 PPR_INVALID, tag); 553 goto out_drop_state; 554 } 555 556 spin_lock_irqsave(&pasid_state->lock, flags); 557 atomic_inc(&pasid_state->pri[tag].inflight); 558 if (finish) 559 pasid_state->pri[tag].finish = true; 560 spin_unlock_irqrestore(&pasid_state->lock, flags); 561 562 fault = kzalloc(sizeof(*fault), GFP_ATOMIC); 563 if (fault == NULL) { 564 /* We are OOM - send success and let the device re-fault */ 565 finish_pri_tag(dev_state, pasid_state, tag); 566 goto out_drop_state; 567 } 568 569 fault->dev_state = dev_state; 570 fault->address = iommu_fault->address; 571 fault->state = pasid_state; 572 fault->tag = tag; 573 fault->finish = finish; 574 fault->pasid = iommu_fault->pasid; 575 fault->flags = iommu_fault->flags; 576 INIT_WORK(&fault->work, do_fault); 577 578 queue_work(iommu_wq, &fault->work); 579 580 ret = NOTIFY_OK; 581 582 out_drop_state: 583 584 if (ret != NOTIFY_OK && pasid_state) 585 put_pasid_state(pasid_state); 586 587 put_device_state(dev_state); 588 589 out: 590 pci_dev_put(pdev); 591 return ret; 592 } 593 594 static struct notifier_block ppr_nb = { 595 .notifier_call = ppr_notifier, 596 }; 597 598 int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid, 599 struct task_struct *task) 600 { 601 struct pasid_state *pasid_state; 602 struct device_state *dev_state; 603 struct mm_struct *mm; 604 u32 sbdf; 605 int ret; 606 607 might_sleep(); 608 609 if (!amd_iommu_v2_supported()) 610 return -ENODEV; 611 612 sbdf = get_pci_sbdf_id(pdev); 613 dev_state = get_device_state(sbdf); 614 615 if (dev_state == NULL) 616 return -EINVAL; 617 618 ret = -EINVAL; 619 if (pasid >= dev_state->max_pasids) 620 goto out; 621 622 ret = -ENOMEM; 623 pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL); 624 if (pasid_state == NULL) 625 goto out; 626 627 628 refcount_set(&pasid_state->count, 1); 629 init_waitqueue_head(&pasid_state->wq); 630 spin_lock_init(&pasid_state->lock); 631 632 mm = get_task_mm(task); 633 pasid_state->mm = mm; 634 pasid_state->device_state = dev_state; 635 pasid_state->pasid = pasid; 636 pasid_state->invalid = true; /* Mark as valid only if we are 637 done with setting up the pasid */ 638 pasid_state->mn.ops = &iommu_mn; 639 640 if (pasid_state->mm == NULL) 641 goto out_free; 642 643 ret = mmu_notifier_register(&pasid_state->mn, mm); 644 if (ret) 645 goto out_free; 646 647 ret = set_pasid_state(dev_state, pasid_state, pasid); 648 if (ret) 649 goto out_unregister; 650 651 ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid, 652 __pa(pasid_state->mm->pgd)); 653 if (ret) 654 goto out_clear_state; 655 656 /* Now we are ready to handle faults */ 657 pasid_state->invalid = false; 658 659 /* 660 * Drop the reference to the mm_struct here. We rely on the 661 * mmu_notifier release call-back to inform us when the mm 662 * is going away. 663 */ 664 mmput(mm); 665 666 return 0; 667 668 out_clear_state: 669 clear_pasid_state(dev_state, pasid); 670 671 out_unregister: 672 mmu_notifier_unregister(&pasid_state->mn, mm); 673 mmput(mm); 674 675 out_free: 676 free_pasid_state(pasid_state); 677 678 out: 679 put_device_state(dev_state); 680 681 return ret; 682 } 683 EXPORT_SYMBOL(amd_iommu_bind_pasid); 684 685 void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid) 686 { 687 struct pasid_state *pasid_state; 688 struct device_state *dev_state; 689 u32 sbdf; 690 691 might_sleep(); 692 693 if (!amd_iommu_v2_supported()) 694 return; 695 696 sbdf = get_pci_sbdf_id(pdev); 697 dev_state = get_device_state(sbdf); 698 if (dev_state == NULL) 699 return; 700 701 if (pasid >= dev_state->max_pasids) 702 goto out; 703 704 pasid_state = get_pasid_state(dev_state, pasid); 705 if (pasid_state == NULL) 706 goto out; 707 /* 708 * Drop reference taken here. We are safe because we still hold 709 * the reference taken in the amd_iommu_bind_pasid function. 710 */ 711 put_pasid_state(pasid_state); 712 713 /* Clear the pasid state so that the pasid can be re-used */ 714 clear_pasid_state(dev_state, pasid_state->pasid); 715 716 /* 717 * Call mmu_notifier_unregister to drop our reference 718 * to pasid_state->mm 719 */ 720 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 721 722 put_pasid_state_wait(pasid_state); /* Reference taken in 723 amd_iommu_bind_pasid */ 724 out: 725 /* Drop reference taken in this function */ 726 put_device_state(dev_state); 727 728 /* Drop reference taken in amd_iommu_bind_pasid */ 729 put_device_state(dev_state); 730 } 731 EXPORT_SYMBOL(amd_iommu_unbind_pasid); 732 733 int amd_iommu_init_device(struct pci_dev *pdev, int pasids) 734 { 735 struct device_state *dev_state; 736 struct iommu_group *group; 737 unsigned long flags; 738 int ret, tmp; 739 u32 sbdf; 740 741 might_sleep(); 742 743 /* 744 * When memory encryption is active the device is likely not in a 745 * direct-mapped domain. Forbid using IOMMUv2 functionality for now. 746 */ 747 if (cc_platform_has(CC_ATTR_MEM_ENCRYPT)) 748 return -ENODEV; 749 750 if (!amd_iommu_v2_supported()) 751 return -ENODEV; 752 753 if (pasids <= 0 || pasids > (PASID_MASK + 1)) 754 return -EINVAL; 755 756 sbdf = get_pci_sbdf_id(pdev); 757 758 dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL); 759 if (dev_state == NULL) 760 return -ENOMEM; 761 762 spin_lock_init(&dev_state->lock); 763 init_waitqueue_head(&dev_state->wq); 764 dev_state->pdev = pdev; 765 dev_state->sbdf = sbdf; 766 767 tmp = pasids; 768 for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9) 769 dev_state->pasid_levels += 1; 770 771 atomic_set(&dev_state->count, 1); 772 dev_state->max_pasids = pasids; 773 774 ret = -ENOMEM; 775 dev_state->states = (void *)get_zeroed_page(GFP_KERNEL); 776 if (dev_state->states == NULL) 777 goto out_free_dev_state; 778 779 dev_state->domain = iommu_domain_alloc(&pci_bus_type); 780 if (dev_state->domain == NULL) 781 goto out_free_states; 782 783 /* See iommu_is_default_domain() */ 784 dev_state->domain->type = IOMMU_DOMAIN_IDENTITY; 785 amd_iommu_domain_direct_map(dev_state->domain); 786 787 ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids); 788 if (ret) 789 goto out_free_domain; 790 791 group = iommu_group_get(&pdev->dev); 792 if (!group) { 793 ret = -EINVAL; 794 goto out_free_domain; 795 } 796 797 ret = iommu_attach_group(dev_state->domain, group); 798 if (ret != 0) 799 goto out_drop_group; 800 801 iommu_group_put(group); 802 803 spin_lock_irqsave(&state_lock, flags); 804 805 if (__get_device_state(sbdf) != NULL) { 806 spin_unlock_irqrestore(&state_lock, flags); 807 ret = -EBUSY; 808 goto out_free_domain; 809 } 810 811 list_add_tail(&dev_state->list, &state_list); 812 813 spin_unlock_irqrestore(&state_lock, flags); 814 815 return 0; 816 817 out_drop_group: 818 iommu_group_put(group); 819 820 out_free_domain: 821 iommu_domain_free(dev_state->domain); 822 823 out_free_states: 824 free_page((unsigned long)dev_state->states); 825 826 out_free_dev_state: 827 kfree(dev_state); 828 829 return ret; 830 } 831 EXPORT_SYMBOL(amd_iommu_init_device); 832 833 void amd_iommu_free_device(struct pci_dev *pdev) 834 { 835 struct device_state *dev_state; 836 unsigned long flags; 837 u32 sbdf; 838 839 if (!amd_iommu_v2_supported()) 840 return; 841 842 sbdf = get_pci_sbdf_id(pdev); 843 844 spin_lock_irqsave(&state_lock, flags); 845 846 dev_state = __get_device_state(sbdf); 847 if (dev_state == NULL) { 848 spin_unlock_irqrestore(&state_lock, flags); 849 return; 850 } 851 852 list_del(&dev_state->list); 853 854 spin_unlock_irqrestore(&state_lock, flags); 855 856 put_device_state(dev_state); 857 free_device_state(dev_state); 858 } 859 EXPORT_SYMBOL(amd_iommu_free_device); 860 861 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev, 862 amd_iommu_invalid_ppr_cb cb) 863 { 864 struct device_state *dev_state; 865 unsigned long flags; 866 u32 sbdf; 867 int ret; 868 869 if (!amd_iommu_v2_supported()) 870 return -ENODEV; 871 872 sbdf = get_pci_sbdf_id(pdev); 873 874 spin_lock_irqsave(&state_lock, flags); 875 876 ret = -EINVAL; 877 dev_state = __get_device_state(sbdf); 878 if (dev_state == NULL) 879 goto out_unlock; 880 881 dev_state->inv_ppr_cb = cb; 882 883 ret = 0; 884 885 out_unlock: 886 spin_unlock_irqrestore(&state_lock, flags); 887 888 return ret; 889 } 890 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb); 891 892 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev, 893 amd_iommu_invalidate_ctx cb) 894 { 895 struct device_state *dev_state; 896 unsigned long flags; 897 u32 sbdf; 898 int ret; 899 900 if (!amd_iommu_v2_supported()) 901 return -ENODEV; 902 903 sbdf = get_pci_sbdf_id(pdev); 904 905 spin_lock_irqsave(&state_lock, flags); 906 907 ret = -EINVAL; 908 dev_state = __get_device_state(sbdf); 909 if (dev_state == NULL) 910 goto out_unlock; 911 912 dev_state->inv_ctx_cb = cb; 913 914 ret = 0; 915 916 out_unlock: 917 spin_unlock_irqrestore(&state_lock, flags); 918 919 return ret; 920 } 921 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb); 922 923 static int __init amd_iommu_v2_init(void) 924 { 925 int ret; 926 927 if (!amd_iommu_v2_supported()) { 928 pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n"); 929 /* 930 * Load anyway to provide the symbols to other modules 931 * which may use AMD IOMMUv2 optionally. 932 */ 933 return 0; 934 } 935 936 ret = -ENOMEM; 937 iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0); 938 if (iommu_wq == NULL) 939 goto out; 940 941 amd_iommu_register_ppr_notifier(&ppr_nb); 942 943 pr_info("AMD IOMMUv2 loaded and initialized\n"); 944 945 return 0; 946 947 out: 948 return ret; 949 } 950 951 static void __exit amd_iommu_v2_exit(void) 952 { 953 struct device_state *dev_state, *next; 954 unsigned long flags; 955 LIST_HEAD(freelist); 956 957 if (!amd_iommu_v2_supported()) 958 return; 959 960 amd_iommu_unregister_ppr_notifier(&ppr_nb); 961 962 flush_workqueue(iommu_wq); 963 964 /* 965 * The loop below might call flush_workqueue(), so call 966 * destroy_workqueue() after it 967 */ 968 spin_lock_irqsave(&state_lock, flags); 969 970 list_for_each_entry_safe(dev_state, next, &state_list, list) { 971 WARN_ON_ONCE(1); 972 973 put_device_state(dev_state); 974 list_del(&dev_state->list); 975 list_add_tail(&dev_state->list, &freelist); 976 } 977 978 spin_unlock_irqrestore(&state_lock, flags); 979 980 /* 981 * Since free_device_state waits on the count to be zero, 982 * we need to free dev_state outside the spinlock. 983 */ 984 list_for_each_entry_safe(dev_state, next, &freelist, list) { 985 list_del(&dev_state->list); 986 free_device_state(dev_state); 987 } 988 989 destroy_workqueue(iommu_wq); 990 } 991 992 module_init(amd_iommu_v2_init); 993 module_exit(amd_iommu_v2_exit); 994