// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)	"AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES		0x10000
#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	refcount_t count;			/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u16 devid;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u16 devid;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->devid == devid)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		refcount_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (refcount_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	refcount_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range	= mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	devid = iommu_fault->device_id;
	pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

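	/*
	 * get_pasid_state() takes a reference on the pasid_state; it is
	 * dropped in do_fault() once the queued fault has been handled,
	 * or on the error paths below.
	 */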
	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	refcount_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (mem_encrypt_active())
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->devid = devid;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(devid) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(devid);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	put_device_state(dev_state);
	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	int i;

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	for (i = 0; i < MAX_DEVICES; ++i) {
		dev_state = get_device_state(i);

		if (dev_state == NULL)
			continue;

		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);