diff options
Diffstat (limited to 'drivers/iommu/intel-svm.c')
-rw-r--r-- | drivers/iommu/intel-svm.c | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/drivers/iommu/intel-svm.c b/drivers/iommu/intel-svm.c index 89d4d47d0ab3..817be769e94f 100644 --- a/drivers/iommu/intel-svm.c +++ b/drivers/iommu/intel-svm.c @@ -269,11 +269,10 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_ struct intel_iommu *iommu = intel_svm_device_to_iommu(dev); struct intel_svm_dev *sdev; struct intel_svm *svm = NULL; + struct mm_struct *mm = NULL; int pasid_max; int ret; - BUG_ON(pasid && !current->mm); - if (WARN_ON(!iommu)) return -EINVAL; @@ -284,12 +283,20 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_ } else pasid_max = 1 << 20; + if ((flags & SVM_FLAG_SUPERVISOR_MODE)) { + if (!ecap_srs(iommu->ecap)) + return -EINVAL; + } else if (pasid) { + mm = get_task_mm(current); + BUG_ON(!mm); + } + mutex_lock(&pasid_mutex); if (pasid && !(flags & SVM_FLAG_PRIVATE_PASID)) { int i; idr_for_each_entry(&iommu->pasid_idr, svm, i) { - if (svm->mm != current->mm || + if (svm->mm != mm || (svm->flags & SVM_FLAG_PRIVATE_PASID)) continue; @@ -355,17 +362,22 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_ } svm->pasid = ret; svm->notifier.ops = &intel_mmuops; - svm->mm = get_task_mm(current); + svm->mm = mm; svm->flags = flags; INIT_LIST_HEAD_RCU(&svm->devs); ret = -ENOMEM; - if (!svm->mm || (ret = mmu_notifier_register(&svm->notifier, svm->mm))) { - idr_remove(&svm->iommu->pasid_idr, svm->pasid); - kfree(svm); - kfree(sdev); - goto out; - } - iommu->pasid_table[svm->pasid].val = (u64)__pa(svm->mm->pgd) | 1; + if (mm) { + ret = mmu_notifier_register(&svm->notifier, mm); + if (ret) { + idr_remove(&svm->iommu->pasid_idr, svm->pasid); + kfree(svm); + kfree(sdev); + goto out; + } + iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1; + mm = NULL; + } else + iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11); wmb(); } list_add_rcu(&sdev->list, &svm->devs); @@ -375,6 +387,8 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_ ret = 0; out: mutex_unlock(&pasid_mutex); + if (mm) + mmput(mm); return ret; } EXPORT_SYMBOL_GPL(intel_svm_bind_mm); @@ -416,7 +430,8 @@ int intel_svm_unbind_mm(struct device *dev, int pasid) mmu_notifier_unregister(&svm->notifier, svm->mm); idr_remove(&svm->iommu->pasid_idr, svm->pasid); - mmput(svm->mm); + if (svm->mm) + mmput(svm->mm); /* We mandate that no page faults may be outstanding * for the PASID when intel_svm_unbind_mm() is called. * If that is not obeyed, subtle errors will happen. @@ -500,6 +515,10 @@ static irqreturn_t prq_event_thread(int irq, void *d) } result = QI_RESP_INVALID; + /* Since we're using init_mm.pgd directly, we should never take + * any faults on kernel addresses. */ + if (!svm->mm) + goto bad_req; down_read(&svm->mm->mmap_sem); vma = find_extend_vma(svm->mm, address); if (!vma || address < vma->vm_start) |