// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt) "AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>
#include <linux/cc_platform.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	refcount_t count;			/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u32 sbdf;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static struct device_state *__get_device_state(u32 sbdf)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->sbdf == sbdf)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u32 sbdf)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(sbdf);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}
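
/*
 * Illustrative note (added by the editor, not from the original source):
 * per-device PASID states live in a small radix-tree-like table of up to
 * three levels, consuming 9 bits of the PASID per level (512 pointers per
 * page). Assuming pasid_levels == 1 and pasid == 0x2004f, for example, the
 * level-1 slot is (0x2004f >> 9) & 0x1ff == 0x100 and the level-0 slot is
 * 0x2004f & 0x1ff == 0x4f. Intermediate table pages are allocated on demand
 * by __get_pasid_state_ptr() when 'alloc' is true.
 */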

/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		refcount_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (refcount_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	if (!refcount_dec_and_test(&pasid_state->count))
		wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/* Clear the pasid state so that the pasid can be re-used */
		clear_pasid_state(dev_state, pasid_state->pasid);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

static void mn_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
					      struct mm_struct *mm,
					      unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}
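
/*
 * Illustrative note (added by the editor, not from the original source):
 * the check above flushes a single IOTLB entry only when [start, end) lies
 * entirely within one page. If start and (end - 1) are in the same page they
 * differ only in the page-offset bits, so XORing them yields a value below
 * PAGE_SIZE; e.g. with 4K pages, start == 0x1000 and end == 0x2000 give
 * 0x1000 ^ 0x1fff == 0xfff. Any larger or page-crossing range falls back to
 * a full per-PASID TLB flush.
 */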

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release			= mn_release,
	.arch_invalidate_secondary_tlbs	= mn_arch_invalidate_secondary_tlbs,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}
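
/*
 * Illustrative note (added by the editor, not from the original source):
 * each PPR tag tracks an 'inflight' count and a 'finish' flag. ppr_notifier()
 * increments the count for every fault it queues and records whether the
 * request asked for a completion; finish_pri_tag() below sends the PPR
 * completion to the device only once the last in-flight fault for that tag
 * has been handled and the finish flag was set, then resets the tag state
 * for reuse.
 */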

static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = vma_lookup(mm, address);
	if (!vma)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid, seg_id;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf);
	devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf);
	pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(&pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->sbdf);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	pci_dev_put(pdev);
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};
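
/*
 * Illustrative usage sketch (added by the editor, not from the original
 * source; error handling omitted): a device driver is expected to call
 * amd_iommu_init_device() once per device, bind a task's address space to a
 * PASID with amd_iommu_bind_pasid(), and tear everything down in reverse
 * order:
 *
 *	ret = amd_iommu_init_device(pdev, num_pasids);
 *	ret = amd_iommu_bind_pasid(pdev, pasid, current);
 *	...			// device issues PRI/ATS traffic for 'pasid'
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */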

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u32 sbdf;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf      = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	refcount_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	ret = mmu_notifier_register(&pasid_state->mn, mm);
	if (ret)
		goto out_free;

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u32 sbdf;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);
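
/*
 * Illustrative note (added by the editor, not from the original source):
 * the pasid_levels computation in amd_iommu_init_device() below determines
 * how many table levels are needed so that 'pasids' entries fit with 512
 * (9 bits worth of) slots per level. For example, pasids == 512 yields
 * pasid_levels == 0 (a single leaf page), pasids == 65536 yields
 * pasid_levels == 1, and the architectural maximum of 2^20 PASIDs yields
 * pasid_levels == 2.
 */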

int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u32 sbdf;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	sbdf = get_pci_sbdf_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev = pdev;
	dev_state->sbdf = sbdf;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	/* See iommu_is_default_domain() */
	dev_state->domain->type = IOMMU_DOMAIN_IDENTITY;
	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(sbdf) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	put_device_state(dev_state);
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);
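
/*
 * Illustrative note (added by the editor, not from the original source):
 * a driver may register an inv_ppr_cb to decide how PPR requests that cannot
 * be resolved by handle_mm_fault() are answered. The callback returns one of
 * AMD_IOMMU_INV_PRI_RSP_SUCCESS, AMD_IOMMU_INV_PRI_RSP_INVALID or
 * AMD_IOMMU_INV_PRI_RSP_FAIL, which handle_fault_error() above translates
 * into the PPR_SUCCESS, PPR_INVALID or PPR_FAILURE completion sent back to
 * the device. Without a callback, such faults are completed with PPR_INVALID.
 */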

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state, *next;
	unsigned long flags;
	LIST_HEAD(freelist);

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	spin_lock_irqsave(&state_lock, flags);

	list_for_each_entry_safe(dev_state, next, &state_list, list) {
		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		list_del(&dev_state->list);
		list_add_tail(&dev_state->list, &freelist);
	}

	spin_unlock_irqrestore(&state_lock, flags);

	/*
	 * Since free_device_state waits on the count to be zero,
	 * we need to free dev_state outside the spinlock.
	 */
	list_for_each_entry_safe(dev_state, next, &freelist, list) {
		list_del(&dev_state->list);
		free_device_state(dev_state);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);