1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc. 4 * Author: Joerg Roedel <jroedel@suse.de> 5 */ 6 7 #define pr_fmt(fmt) "AMD-Vi: " fmt 8 9 #include <linux/mmu_notifier.h> 10 #include <linux/amd-iommu.h> 11 #include <linux/mm_types.h> 12 #include <linux/profile.h> 13 #include <linux/module.h> 14 #include <linux/sched.h> 15 #include <linux/sched/mm.h> 16 #include <linux/wait.h> 17 #include <linux/pci.h> 18 #include <linux/gfp.h> 19 20 #include "amd_iommu.h" 21 22 MODULE_LICENSE("GPL v2"); 23 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>"); 24 25 #define MAX_DEVICES 0x10000 26 #define PRI_QUEUE_SIZE 512 27 28 struct pri_queue { 29 atomic_t inflight; 30 bool finish; 31 int status; 32 }; 33 34 struct pasid_state { 35 struct list_head list; /* For global state-list */ 36 atomic_t count; /* Reference count */ 37 unsigned mmu_notifier_count; /* Counting nested mmu_notifier 38 calls */ 39 struct mm_struct *mm; /* mm_struct for the faults */ 40 struct mmu_notifier mn; /* mmu_notifier handle */ 41 struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */ 42 struct device_state *device_state; /* Link to our device_state */ 43 u32 pasid; /* PASID index */ 44 bool invalid; /* Used during setup and 45 teardown of the pasid */ 46 spinlock_t lock; /* Protect pri_queues and 47 mmu_notifer_count */ 48 wait_queue_head_t wq; /* To wait for count == 0 */ 49 }; 50 51 struct device_state { 52 struct list_head list; 53 u16 devid; 54 atomic_t count; 55 struct pci_dev *pdev; 56 struct pasid_state **states; 57 struct iommu_domain *domain; 58 int pasid_levels; 59 int max_pasids; 60 amd_iommu_invalid_ppr_cb inv_ppr_cb; 61 amd_iommu_invalidate_ctx inv_ctx_cb; 62 spinlock_t lock; 63 wait_queue_head_t wq; 64 }; 65 66 struct fault { 67 struct work_struct work; 68 struct device_state *dev_state; 69 struct pasid_state *state; 70 struct mm_struct *mm; 71 u64 address; 72 u16 devid; 73 u32 pasid; 74 u16 tag; 75 u16 finish; 76 u16 flags; 77 }; 78 79 static LIST_HEAD(state_list); 80 static DEFINE_SPINLOCK(state_lock); 81 82 static struct workqueue_struct *iommu_wq; 83 84 static void free_pasid_states(struct device_state *dev_state); 85 86 static u16 device_id(struct pci_dev *pdev) 87 { 88 u16 devid; 89 90 devid = pdev->bus->number; 91 devid = (devid << 8) | pdev->devfn; 92 93 return devid; 94 } 95 96 static struct device_state *__get_device_state(u16 devid) 97 { 98 struct device_state *dev_state; 99 100 list_for_each_entry(dev_state, &state_list, list) { 101 if (dev_state->devid == devid) 102 return dev_state; 103 } 104 105 return NULL; 106 } 107 108 static struct device_state *get_device_state(u16 devid) 109 { 110 struct device_state *dev_state; 111 unsigned long flags; 112 113 spin_lock_irqsave(&state_lock, flags); 114 dev_state = __get_device_state(devid); 115 if (dev_state != NULL) 116 atomic_inc(&dev_state->count); 117 spin_unlock_irqrestore(&state_lock, flags); 118 119 return dev_state; 120 } 121 122 static void free_device_state(struct device_state *dev_state) 123 { 124 struct iommu_group *group; 125 126 /* 127 * First detach device from domain - No more PRI requests will arrive 128 * from that device after it is unbound from the IOMMUv2 domain. 129 */ 130 group = iommu_group_get(&dev_state->pdev->dev); 131 if (WARN_ON(!group)) 132 return; 133 134 iommu_detach_group(dev_state->domain, group); 135 136 iommu_group_put(group); 137 138 /* Everything is down now, free the IOMMUv2 domain */ 139 iommu_domain_free(dev_state->domain); 140 141 /* Finally get rid of the device-state */ 142 kfree(dev_state); 143 } 144 145 static void put_device_state(struct device_state *dev_state) 146 { 147 if (atomic_dec_and_test(&dev_state->count)) 148 wake_up(&dev_state->wq); 149 } 150 151 /* Must be called under dev_state->lock */ 152 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state, 153 u32 pasid, bool alloc) 154 { 155 struct pasid_state **root, **ptr; 156 int level, index; 157 158 level = dev_state->pasid_levels; 159 root = dev_state->states; 160 161 while (true) { 162 163 index = (pasid >> (9 * level)) & 0x1ff; 164 ptr = &root[index]; 165 166 if (level == 0) 167 break; 168 169 if (*ptr == NULL) { 170 if (!alloc) 171 return NULL; 172 173 *ptr = (void *)get_zeroed_page(GFP_ATOMIC); 174 if (*ptr == NULL) 175 return NULL; 176 } 177 178 root = (struct pasid_state **)*ptr; 179 level -= 1; 180 } 181 182 return ptr; 183 } 184 185 static int set_pasid_state(struct device_state *dev_state, 186 struct pasid_state *pasid_state, 187 u32 pasid) 188 { 189 struct pasid_state **ptr; 190 unsigned long flags; 191 int ret; 192 193 spin_lock_irqsave(&dev_state->lock, flags); 194 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 195 196 ret = -ENOMEM; 197 if (ptr == NULL) 198 goto out_unlock; 199 200 ret = -ENOMEM; 201 if (*ptr != NULL) 202 goto out_unlock; 203 204 *ptr = pasid_state; 205 206 ret = 0; 207 208 out_unlock: 209 spin_unlock_irqrestore(&dev_state->lock, flags); 210 211 return ret; 212 } 213 214 static void clear_pasid_state(struct device_state *dev_state, u32 pasid) 215 { 216 struct pasid_state **ptr; 217 unsigned long flags; 218 219 spin_lock_irqsave(&dev_state->lock, flags); 220 ptr = __get_pasid_state_ptr(dev_state, pasid, true); 221 222 if (ptr == NULL) 223 goto out_unlock; 224 225 *ptr = NULL; 226 227 out_unlock: 228 spin_unlock_irqrestore(&dev_state->lock, flags); 229 } 230 231 static struct pasid_state *get_pasid_state(struct device_state *dev_state, 232 u32 pasid) 233 { 234 struct pasid_state **ptr, *ret = NULL; 235 unsigned long flags; 236 237 spin_lock_irqsave(&dev_state->lock, flags); 238 ptr = __get_pasid_state_ptr(dev_state, pasid, false); 239 240 if (ptr == NULL) 241 goto out_unlock; 242 243 ret = *ptr; 244 if (ret) 245 atomic_inc(&ret->count); 246 247 out_unlock: 248 spin_unlock_irqrestore(&dev_state->lock, flags); 249 250 return ret; 251 } 252 253 static void free_pasid_state(struct pasid_state *pasid_state) 254 { 255 kfree(pasid_state); 256 } 257 258 static void put_pasid_state(struct pasid_state *pasid_state) 259 { 260 if (atomic_dec_and_test(&pasid_state->count)) 261 wake_up(&pasid_state->wq); 262 } 263 264 static void put_pasid_state_wait(struct pasid_state *pasid_state) 265 { 266 atomic_dec(&pasid_state->count); 267 wait_event(pasid_state->wq, !atomic_read(&pasid_state->count)); 268 free_pasid_state(pasid_state); 269 } 270 271 static void unbind_pasid(struct pasid_state *pasid_state) 272 { 273 struct iommu_domain *domain; 274 275 domain = pasid_state->device_state->domain; 276 277 /* 278 * Mark pasid_state as invalid, no more faults will we added to the 279 * work queue after this is visible everywhere. 280 */ 281 pasid_state->invalid = true; 282 283 /* Make sure this is visible */ 284 smp_wmb(); 285 286 /* After this the device/pasid can't access the mm anymore */ 287 amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid); 288 289 /* Make sure no more pending faults are in the queue */ 290 flush_workqueue(iommu_wq); 291 } 292 293 static void free_pasid_states_level1(struct pasid_state **tbl) 294 { 295 int i; 296 297 for (i = 0; i < 512; ++i) { 298 if (tbl[i] == NULL) 299 continue; 300 301 free_page((unsigned long)tbl[i]); 302 } 303 } 304 305 static void free_pasid_states_level2(struct pasid_state **tbl) 306 { 307 struct pasid_state **ptr; 308 int i; 309 310 for (i = 0; i < 512; ++i) { 311 if (tbl[i] == NULL) 312 continue; 313 314 ptr = (struct pasid_state **)tbl[i]; 315 free_pasid_states_level1(ptr); 316 } 317 } 318 319 static void free_pasid_states(struct device_state *dev_state) 320 { 321 struct pasid_state *pasid_state; 322 int i; 323 324 for (i = 0; i < dev_state->max_pasids; ++i) { 325 pasid_state = get_pasid_state(dev_state, i); 326 if (pasid_state == NULL) 327 continue; 328 329 put_pasid_state(pasid_state); 330 331 /* 332 * This will call the mn_release function and 333 * unbind the PASID 334 */ 335 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 336 337 put_pasid_state_wait(pasid_state); /* Reference taken in 338 amd_iommu_bind_pasid */ 339 340 /* Drop reference taken in amd_iommu_bind_pasid */ 341 put_device_state(dev_state); 342 } 343 344 if (dev_state->pasid_levels == 2) 345 free_pasid_states_level2(dev_state->states); 346 else if (dev_state->pasid_levels == 1) 347 free_pasid_states_level1(dev_state->states); 348 else 349 BUG_ON(dev_state->pasid_levels != 0); 350 351 free_page((unsigned long)dev_state->states); 352 } 353 354 static struct pasid_state *mn_to_state(struct mmu_notifier *mn) 355 { 356 return container_of(mn, struct pasid_state, mn); 357 } 358 359 static void mn_invalidate_range(struct mmu_notifier *mn, 360 struct mm_struct *mm, 361 unsigned long start, unsigned long end) 362 { 363 struct pasid_state *pasid_state; 364 struct device_state *dev_state; 365 366 pasid_state = mn_to_state(mn); 367 dev_state = pasid_state->device_state; 368 369 if ((start ^ (end - 1)) < PAGE_SIZE) 370 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, 371 start); 372 else 373 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid); 374 } 375 376 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm) 377 { 378 struct pasid_state *pasid_state; 379 struct device_state *dev_state; 380 bool run_inv_ctx_cb; 381 382 might_sleep(); 383 384 pasid_state = mn_to_state(mn); 385 dev_state = pasid_state->device_state; 386 run_inv_ctx_cb = !pasid_state->invalid; 387 388 if (run_inv_ctx_cb && dev_state->inv_ctx_cb) 389 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid); 390 391 unbind_pasid(pasid_state); 392 } 393 394 static const struct mmu_notifier_ops iommu_mn = { 395 .release = mn_release, 396 .invalidate_range = mn_invalidate_range, 397 }; 398 399 static void set_pri_tag_status(struct pasid_state *pasid_state, 400 u16 tag, int status) 401 { 402 unsigned long flags; 403 404 spin_lock_irqsave(&pasid_state->lock, flags); 405 pasid_state->pri[tag].status = status; 406 spin_unlock_irqrestore(&pasid_state->lock, flags); 407 } 408 409 static void finish_pri_tag(struct device_state *dev_state, 410 struct pasid_state *pasid_state, 411 u16 tag) 412 { 413 unsigned long flags; 414 415 spin_lock_irqsave(&pasid_state->lock, flags); 416 if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) && 417 pasid_state->pri[tag].finish) { 418 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid, 419 pasid_state->pri[tag].status, tag); 420 pasid_state->pri[tag].finish = false; 421 pasid_state->pri[tag].status = PPR_SUCCESS; 422 } 423 spin_unlock_irqrestore(&pasid_state->lock, flags); 424 } 425 426 static void handle_fault_error(struct fault *fault) 427 { 428 int status; 429 430 if (!fault->dev_state->inv_ppr_cb) { 431 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 432 return; 433 } 434 435 status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, 436 fault->pasid, 437 fault->address, 438 fault->flags); 439 switch (status) { 440 case AMD_IOMMU_INV_PRI_RSP_SUCCESS: 441 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); 442 break; 443 case AMD_IOMMU_INV_PRI_RSP_INVALID: 444 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); 445 break; 446 case AMD_IOMMU_INV_PRI_RSP_FAIL: 447 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); 448 break; 449 default: 450 BUG(); 451 } 452 } 453 454 static bool access_error(struct vm_area_struct *vma, struct fault *fault) 455 { 456 unsigned long requested = 0; 457 458 if (fault->flags & PPR_FAULT_EXEC) 459 requested |= VM_EXEC; 460 461 if (fault->flags & PPR_FAULT_READ) 462 requested |= VM_READ; 463 464 if (fault->flags & PPR_FAULT_WRITE) 465 requested |= VM_WRITE; 466 467 return (requested & ~vma->vm_flags) != 0; 468 } 469 470 static void do_fault(struct work_struct *work) 471 { 472 struct fault *fault = container_of(work, struct fault, work); 473 struct vm_area_struct *vma; 474 vm_fault_t ret = VM_FAULT_ERROR; 475 unsigned int flags = 0; 476 struct mm_struct *mm; 477 u64 address; 478 479 mm = fault->state->mm; 480 address = fault->address; 481 482 if (fault->flags & PPR_FAULT_USER) 483 flags |= FAULT_FLAG_USER; 484 if (fault->flags & PPR_FAULT_WRITE) 485 flags |= FAULT_FLAG_WRITE; 486 flags |= FAULT_FLAG_REMOTE; 487 488 mmap_read_lock(mm); 489 vma = find_extend_vma(mm, address); 490 if (!vma || address < vma->vm_start) 491 /* failed to get a vma in the right range */ 492 goto out; 493 494 /* Check if we have the right permissions on the vma */ 495 if (access_error(vma, fault)) 496 goto out; 497 498 ret = handle_mm_fault(vma, address, flags, NULL); 499 out: 500 mmap_read_unlock(mm); 501 502 if (ret & VM_FAULT_ERROR) 503 /* failed to service fault */ 504 handle_fault_error(fault); 505 506 finish_pri_tag(fault->dev_state, fault->state, fault->tag); 507 508 put_pasid_state(fault->state); 509 510 kfree(fault); 511 } 512 513 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) 514 { 515 struct amd_iommu_fault *iommu_fault; 516 struct pasid_state *pasid_state; 517 struct device_state *dev_state; 518 struct pci_dev *pdev = NULL; 519 unsigned long flags; 520 struct fault *fault; 521 bool finish; 522 u16 tag, devid; 523 int ret; 524 525 iommu_fault = data; 526 tag = iommu_fault->tag & 0x1ff; 527 finish = (iommu_fault->tag >> 9) & 1; 528 529 devid = iommu_fault->device_id; 530 pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid), 531 devid & 0xff); 532 if (!pdev) 533 return -ENODEV; 534 535 ret = NOTIFY_DONE; 536 537 /* In kdump kernel pci dev is not initialized yet -> send INVALID */ 538 if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) { 539 amd_iommu_complete_ppr(pdev, iommu_fault->pasid, 540 PPR_INVALID, tag); 541 goto out; 542 } 543 544 dev_state = get_device_state(iommu_fault->device_id); 545 if (dev_state == NULL) 546 goto out; 547 548 pasid_state = get_pasid_state(dev_state, iommu_fault->pasid); 549 if (pasid_state == NULL || pasid_state->invalid) { 550 /* We know the device but not the PASID -> send INVALID */ 551 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid, 552 PPR_INVALID, tag); 553 goto out_drop_state; 554 } 555 556 spin_lock_irqsave(&pasid_state->lock, flags); 557 atomic_inc(&pasid_state->pri[tag].inflight); 558 if (finish) 559 pasid_state->pri[tag].finish = true; 560 spin_unlock_irqrestore(&pasid_state->lock, flags); 561 562 fault = kzalloc(sizeof(*fault), GFP_ATOMIC); 563 if (fault == NULL) { 564 /* We are OOM - send success and let the device re-fault */ 565 finish_pri_tag(dev_state, pasid_state, tag); 566 goto out_drop_state; 567 } 568 569 fault->dev_state = dev_state; 570 fault->address = iommu_fault->address; 571 fault->state = pasid_state; 572 fault->tag = tag; 573 fault->finish = finish; 574 fault->pasid = iommu_fault->pasid; 575 fault->flags = iommu_fault->flags; 576 INIT_WORK(&fault->work, do_fault); 577 578 queue_work(iommu_wq, &fault->work); 579 580 ret = NOTIFY_OK; 581 582 out_drop_state: 583 584 if (ret != NOTIFY_OK && pasid_state) 585 put_pasid_state(pasid_state); 586 587 put_device_state(dev_state); 588 589 out: 590 return ret; 591 } 592 593 static struct notifier_block ppr_nb = { 594 .notifier_call = ppr_notifier, 595 }; 596 597 int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid, 598 struct task_struct *task) 599 { 600 struct pasid_state *pasid_state; 601 struct device_state *dev_state; 602 struct mm_struct *mm; 603 u16 devid; 604 int ret; 605 606 might_sleep(); 607 608 if (!amd_iommu_v2_supported()) 609 return -ENODEV; 610 611 devid = device_id(pdev); 612 dev_state = get_device_state(devid); 613 614 if (dev_state == NULL) 615 return -EINVAL; 616 617 ret = -EINVAL; 618 if (pasid >= dev_state->max_pasids) 619 goto out; 620 621 ret = -ENOMEM; 622 pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL); 623 if (pasid_state == NULL) 624 goto out; 625 626 627 atomic_set(&pasid_state->count, 1); 628 init_waitqueue_head(&pasid_state->wq); 629 spin_lock_init(&pasid_state->lock); 630 631 mm = get_task_mm(task); 632 pasid_state->mm = mm; 633 pasid_state->device_state = dev_state; 634 pasid_state->pasid = pasid; 635 pasid_state->invalid = true; /* Mark as valid only if we are 636 done with setting up the pasid */ 637 pasid_state->mn.ops = &iommu_mn; 638 639 if (pasid_state->mm == NULL) 640 goto out_free; 641 642 mmu_notifier_register(&pasid_state->mn, mm); 643 644 ret = set_pasid_state(dev_state, pasid_state, pasid); 645 if (ret) 646 goto out_unregister; 647 648 ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid, 649 __pa(pasid_state->mm->pgd)); 650 if (ret) 651 goto out_clear_state; 652 653 /* Now we are ready to handle faults */ 654 pasid_state->invalid = false; 655 656 /* 657 * Drop the reference to the mm_struct here. We rely on the 658 * mmu_notifier release call-back to inform us when the mm 659 * is going away. 660 */ 661 mmput(mm); 662 663 return 0; 664 665 out_clear_state: 666 clear_pasid_state(dev_state, pasid); 667 668 out_unregister: 669 mmu_notifier_unregister(&pasid_state->mn, mm); 670 mmput(mm); 671 672 out_free: 673 free_pasid_state(pasid_state); 674 675 out: 676 put_device_state(dev_state); 677 678 return ret; 679 } 680 EXPORT_SYMBOL(amd_iommu_bind_pasid); 681 682 void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid) 683 { 684 struct pasid_state *pasid_state; 685 struct device_state *dev_state; 686 u16 devid; 687 688 might_sleep(); 689 690 if (!amd_iommu_v2_supported()) 691 return; 692 693 devid = device_id(pdev); 694 dev_state = get_device_state(devid); 695 if (dev_state == NULL) 696 return; 697 698 if (pasid >= dev_state->max_pasids) 699 goto out; 700 701 pasid_state = get_pasid_state(dev_state, pasid); 702 if (pasid_state == NULL) 703 goto out; 704 /* 705 * Drop reference taken here. We are safe because we still hold 706 * the reference taken in the amd_iommu_bind_pasid function. 707 */ 708 put_pasid_state(pasid_state); 709 710 /* Clear the pasid state so that the pasid can be re-used */ 711 clear_pasid_state(dev_state, pasid_state->pasid); 712 713 /* 714 * Call mmu_notifier_unregister to drop our reference 715 * to pasid_state->mm 716 */ 717 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); 718 719 put_pasid_state_wait(pasid_state); /* Reference taken in 720 amd_iommu_bind_pasid */ 721 out: 722 /* Drop reference taken in this function */ 723 put_device_state(dev_state); 724 725 /* Drop reference taken in amd_iommu_bind_pasid */ 726 put_device_state(dev_state); 727 } 728 EXPORT_SYMBOL(amd_iommu_unbind_pasid); 729 730 int amd_iommu_init_device(struct pci_dev *pdev, int pasids) 731 { 732 struct device_state *dev_state; 733 struct iommu_group *group; 734 unsigned long flags; 735 int ret, tmp; 736 u16 devid; 737 738 might_sleep(); 739 740 /* 741 * When memory encryption is active the device is likely not in a 742 * direct-mapped domain. Forbid using IOMMUv2 functionality for now. 743 */ 744 if (mem_encrypt_active()) 745 return -ENODEV; 746 747 if (!amd_iommu_v2_supported()) 748 return -ENODEV; 749 750 if (pasids <= 0 || pasids > (PASID_MASK + 1)) 751 return -EINVAL; 752 753 devid = device_id(pdev); 754 755 dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL); 756 if (dev_state == NULL) 757 return -ENOMEM; 758 759 spin_lock_init(&dev_state->lock); 760 init_waitqueue_head(&dev_state->wq); 761 dev_state->pdev = pdev; 762 dev_state->devid = devid; 763 764 tmp = pasids; 765 for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9) 766 dev_state->pasid_levels += 1; 767 768 atomic_set(&dev_state->count, 1); 769 dev_state->max_pasids = pasids; 770 771 ret = -ENOMEM; 772 dev_state->states = (void *)get_zeroed_page(GFP_KERNEL); 773 if (dev_state->states == NULL) 774 goto out_free_dev_state; 775 776 dev_state->domain = iommu_domain_alloc(&pci_bus_type); 777 if (dev_state->domain == NULL) 778 goto out_free_states; 779 780 amd_iommu_domain_direct_map(dev_state->domain); 781 782 ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids); 783 if (ret) 784 goto out_free_domain; 785 786 group = iommu_group_get(&pdev->dev); 787 if (!group) { 788 ret = -EINVAL; 789 goto out_free_domain; 790 } 791 792 ret = iommu_attach_group(dev_state->domain, group); 793 if (ret != 0) 794 goto out_drop_group; 795 796 iommu_group_put(group); 797 798 spin_lock_irqsave(&state_lock, flags); 799 800 if (__get_device_state(devid) != NULL) { 801 spin_unlock_irqrestore(&state_lock, flags); 802 ret = -EBUSY; 803 goto out_free_domain; 804 } 805 806 list_add_tail(&dev_state->list, &state_list); 807 808 spin_unlock_irqrestore(&state_lock, flags); 809 810 return 0; 811 812 out_drop_group: 813 iommu_group_put(group); 814 815 out_free_domain: 816 iommu_domain_free(dev_state->domain); 817 818 out_free_states: 819 free_page((unsigned long)dev_state->states); 820 821 out_free_dev_state: 822 kfree(dev_state); 823 824 return ret; 825 } 826 EXPORT_SYMBOL(amd_iommu_init_device); 827 828 void amd_iommu_free_device(struct pci_dev *pdev) 829 { 830 struct device_state *dev_state; 831 unsigned long flags; 832 u16 devid; 833 834 if (!amd_iommu_v2_supported()) 835 return; 836 837 devid = device_id(pdev); 838 839 spin_lock_irqsave(&state_lock, flags); 840 841 dev_state = __get_device_state(devid); 842 if (dev_state == NULL) { 843 spin_unlock_irqrestore(&state_lock, flags); 844 return; 845 } 846 847 list_del(&dev_state->list); 848 849 spin_unlock_irqrestore(&state_lock, flags); 850 851 /* Get rid of any remaining pasid states */ 852 free_pasid_states(dev_state); 853 854 put_device_state(dev_state); 855 /* 856 * Wait until the last reference is dropped before freeing 857 * the device state. 858 */ 859 wait_event(dev_state->wq, !atomic_read(&dev_state->count)); 860 free_device_state(dev_state); 861 } 862 EXPORT_SYMBOL(amd_iommu_free_device); 863 864 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev, 865 amd_iommu_invalid_ppr_cb cb) 866 { 867 struct device_state *dev_state; 868 unsigned long flags; 869 u16 devid; 870 int ret; 871 872 if (!amd_iommu_v2_supported()) 873 return -ENODEV; 874 875 devid = device_id(pdev); 876 877 spin_lock_irqsave(&state_lock, flags); 878 879 ret = -EINVAL; 880 dev_state = __get_device_state(devid); 881 if (dev_state == NULL) 882 goto out_unlock; 883 884 dev_state->inv_ppr_cb = cb; 885 886 ret = 0; 887 888 out_unlock: 889 spin_unlock_irqrestore(&state_lock, flags); 890 891 return ret; 892 } 893 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb); 894 895 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev, 896 amd_iommu_invalidate_ctx cb) 897 { 898 struct device_state *dev_state; 899 unsigned long flags; 900 u16 devid; 901 int ret; 902 903 if (!amd_iommu_v2_supported()) 904 return -ENODEV; 905 906 devid = device_id(pdev); 907 908 spin_lock_irqsave(&state_lock, flags); 909 910 ret = -EINVAL; 911 dev_state = __get_device_state(devid); 912 if (dev_state == NULL) 913 goto out_unlock; 914 915 dev_state->inv_ctx_cb = cb; 916 917 ret = 0; 918 919 out_unlock: 920 spin_unlock_irqrestore(&state_lock, flags); 921 922 return ret; 923 } 924 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb); 925 926 static int __init amd_iommu_v2_init(void) 927 { 928 int ret; 929 930 pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n"); 931 932 if (!amd_iommu_v2_supported()) { 933 pr_info("AMD IOMMUv2 functionality not available on this system\n"); 934 /* 935 * Load anyway to provide the symbols to other modules 936 * which may use AMD IOMMUv2 optionally. 937 */ 938 return 0; 939 } 940 941 ret = -ENOMEM; 942 iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0); 943 if (iommu_wq == NULL) 944 goto out; 945 946 amd_iommu_register_ppr_notifier(&ppr_nb); 947 948 return 0; 949 950 out: 951 return ret; 952 } 953 954 static void __exit amd_iommu_v2_exit(void) 955 { 956 struct device_state *dev_state; 957 int i; 958 959 if (!amd_iommu_v2_supported()) 960 return; 961 962 amd_iommu_unregister_ppr_notifier(&ppr_nb); 963 964 flush_workqueue(iommu_wq); 965 966 /* 967 * The loop below might call flush_workqueue(), so call 968 * destroy_workqueue() after it 969 */ 970 for (i = 0; i < MAX_DEVICES; ++i) { 971 dev_state = get_device_state(i); 972 973 if (dev_state == NULL) 974 continue; 975 976 WARN_ON_ONCE(1); 977 978 put_device_state(dev_state); 979 amd_iommu_free_device(dev_state->pdev); 980 } 981 982 destroy_workqueue(iommu_wq); 983 } 984 985 module_init(amd_iommu_v2_init); 986 module_exit(amd_iommu_v2_exit); 987