summaryrefslogtreecommitdiff
path: root/drivers/iommu/intel-svm.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/intel-svm.c')
-rw-r--r--drivers/iommu/intel-svm.c43
1 files changed, 31 insertions, 12 deletions
diff --git a/drivers/iommu/intel-svm.c b/drivers/iommu/intel-svm.c
index 89d4d47d0ab3..817be769e94f 100644
--- a/drivers/iommu/intel-svm.c
+++ b/drivers/iommu/intel-svm.c
@@ -269,11 +269,10 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
struct intel_iommu *iommu = intel_svm_device_to_iommu(dev);
struct intel_svm_dev *sdev;
struct intel_svm *svm = NULL;
+ struct mm_struct *mm = NULL;
int pasid_max;
int ret;
- BUG_ON(pasid && !current->mm);
-
if (WARN_ON(!iommu))
return -EINVAL;
@@ -284,12 +283,20 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
} else
pasid_max = 1 << 20;
+ if ((flags & SVM_FLAG_SUPERVISOR_MODE)) {
+ if (!ecap_srs(iommu->ecap))
+ return -EINVAL;
+ } else if (pasid) {
+ mm = get_task_mm(current);
+ BUG_ON(!mm);
+ }
+
mutex_lock(&pasid_mutex);
if (pasid && !(flags & SVM_FLAG_PRIVATE_PASID)) {
int i;
idr_for_each_entry(&iommu->pasid_idr, svm, i) {
- if (svm->mm != current->mm ||
+ if (svm->mm != mm ||
(svm->flags & SVM_FLAG_PRIVATE_PASID))
continue;
@@ -355,17 +362,22 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
}
svm->pasid = ret;
svm->notifier.ops = &intel_mmuops;
- svm->mm = get_task_mm(current);
+ svm->mm = mm;
svm->flags = flags;
INIT_LIST_HEAD_RCU(&svm->devs);
ret = -ENOMEM;
- if (!svm->mm || (ret = mmu_notifier_register(&svm->notifier, svm->mm))) {
- idr_remove(&svm->iommu->pasid_idr, svm->pasid);
- kfree(svm);
- kfree(sdev);
- goto out;
- }
- iommu->pasid_table[svm->pasid].val = (u64)__pa(svm->mm->pgd) | 1;
+ if (mm) {
+ ret = mmu_notifier_register(&svm->notifier, mm);
+ if (ret) {
+ idr_remove(&svm->iommu->pasid_idr, svm->pasid);
+ kfree(svm);
+ kfree(sdev);
+ goto out;
+ }
+ iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
+ mm = NULL;
+ } else
+ iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
wmb();
}
list_add_rcu(&sdev->list, &svm->devs);
@@ -375,6 +387,8 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
ret = 0;
out:
mutex_unlock(&pasid_mutex);
+ if (mm)
+ mmput(mm);
return ret;
}
EXPORT_SYMBOL_GPL(intel_svm_bind_mm);
@@ -416,7 +430,8 @@ int intel_svm_unbind_mm(struct device *dev, int pasid)
mmu_notifier_unregister(&svm->notifier, svm->mm);
idr_remove(&svm->iommu->pasid_idr, svm->pasid);
- mmput(svm->mm);
+ if (svm->mm)
+ mmput(svm->mm);
/* We mandate that no page faults may be outstanding
* for the PASID when intel_svm_unbind_mm() is called.
* If that is not obeyed, subtle errors will happen.
@@ -500,6 +515,10 @@ static irqreturn_t prq_event_thread(int irq, void *d)
}
result = QI_RESP_INVALID;
+ /* Since we're using init_mm.pgd directly, we should never take
+ * any faults on kernel addresses. */
+ if (!svm->mm)
+ goto bad_req;
down_read(&svm->mm->mmap_sem);
vma = find_extend_vma(svm->mm, address);
if (!vma || address < vma->vm_start)