diff options
-rw-r--r-- | drivers/vfio/vfio.c | 62 |
1 files changed, 46 insertions, 16 deletions
diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c index 073788e50e4b..ac7423bfaa7d 100644 --- a/drivers/vfio/vfio.c +++ b/drivers/vfio/vfio.c @@ -704,9 +704,13 @@ EXPORT_SYMBOL_GPL(vfio_del_group_dev); static long vfio_ioctl_check_extension(struct vfio_container *container, unsigned long arg) { - struct vfio_iommu_driver *driver = container->iommu_driver; + struct vfio_iommu_driver *driver; long ret = 0; + down_read(&container->group_lock); + + driver = container->iommu_driver; + switch (arg) { /* No base extensions yet */ default: @@ -736,6 +740,8 @@ static long vfio_ioctl_check_extension(struct vfio_container *container, VFIO_CHECK_EXTENSION, arg); } + up_read(&container->group_lock); + return ret; } @@ -844,9 +850,6 @@ static long vfio_fops_unl_ioctl(struct file *filep, if (!container) return ret; - driver = container->iommu_driver; - data = container->iommu_data; - switch (cmd) { case VFIO_GET_API_VERSION: ret = VFIO_API_VERSION; @@ -858,8 +861,15 @@ static long vfio_fops_unl_ioctl(struct file *filep, ret = vfio_ioctl_set_iommu(container, arg); break; default: + down_read(&container->group_lock); + + driver = container->iommu_driver; + data = container->iommu_data; + if (driver) /* passthrough all unrecognized ioctls */ ret = driver->ops->ioctl(data, cmd, arg); + + up_read(&container->group_lock); } return ret; @@ -910,35 +920,55 @@ static ssize_t vfio_fops_read(struct file *filep, char __user *buf, size_t count, loff_t *ppos) { struct vfio_container *container = filep->private_data; - struct vfio_iommu_driver *driver = container->iommu_driver; + struct vfio_iommu_driver *driver; + ssize_t ret = -EINVAL; - if (unlikely(!driver || !driver->ops->read)) - return -EINVAL; + down_read(&container->group_lock); + + driver = container->iommu_driver; + if (likely(driver && driver->ops->read)) + ret = driver->ops->read(container->iommu_data, + buf, count, ppos); - return driver->ops->read(container->iommu_data, buf, count, ppos); + up_read(&container->group_lock); + + return ret; } static ssize_t vfio_fops_write(struct file *filep, const char __user *buf, size_t count, loff_t *ppos) { struct vfio_container *container = filep->private_data; - struct vfio_iommu_driver *driver = container->iommu_driver; + struct vfio_iommu_driver *driver; + ssize_t ret = -EINVAL; - if (unlikely(!driver || !driver->ops->write)) - return -EINVAL; + down_read(&container->group_lock); + + driver = container->iommu_driver; + if (likely(driver && driver->ops->write)) + ret = driver->ops->write(container->iommu_data, + buf, count, ppos); + + up_read(&container->group_lock); - return driver->ops->write(container->iommu_data, buf, count, ppos); + return ret; } static int vfio_fops_mmap(struct file *filep, struct vm_area_struct *vma) { struct vfio_container *container = filep->private_data; - struct vfio_iommu_driver *driver = container->iommu_driver; + struct vfio_iommu_driver *driver; + int ret = -EINVAL; - if (unlikely(!driver || !driver->ops->mmap)) - return -EINVAL; + down_read(&container->group_lock); - return driver->ops->mmap(container->iommu_data, vma); + driver = container->iommu_driver; + if (likely(driver && driver->ops->mmap)) + ret = driver->ops->mmap(container->iommu_data, vma); + + up_read(&container->group_lock); + + return ret; } static const struct file_operations vfio_fops = { |