summaryrefslogtreecommitdiff
path: root/net/vmw_vsock
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r--net/vmw_vsock/af_vsock.c61
-rw-r--r--net/vmw_vsock/hyperv_transport.c35
-rw-r--r--net/vmw_vsock/virtio_transport.c7
-rw-r--r--net/vmw_vsock/virtio_transport_common.c21
-rw-r--r--net/vmw_vsock/vsock_bpf.c2
5 files changed, 89 insertions, 37 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 2f7d94d682cb..44037b066a5f 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -545,9 +545,13 @@ static void vsock_deassign_transport(struct vsock_sock *vsk)
* The vsk->remote_addr is used to decide which transport to use:
* - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
* g2h is not loaded, will use local transport;
- * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
- * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
- * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ * - remote CID <= VMADDR_CID_HOST or remote flags field includes
+ * VMADDR_FLAG_TO_HOST, will use guest->host transport;
+ * - remote CID > VMADDR_CID_HOST and h2g is loaded and h2g claims that CID,
+ * will use host->guest transport;
 * - h2g is not loaded, or h2g does not claim that CID, and g2h claims the CID
 *   via has_remote_cid, will use guest->host transport (when g2h_fallback=1);
 * - anything else uses h2g, or returns -ENODEV if no h2g is available
*/
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
{
@@ -581,11 +585,21 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
case SOCK_SEQPACKET:
if (vsock_use_local_transport(remote_cid))
new_transport = transport_local;
- else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
+ else if (remote_cid <= VMADDR_CID_HOST ||
(remote_flags & VMADDR_FLAG_TO_HOST))
new_transport = transport_g2h;
- else
+ else if (transport_h2g &&
+ (!transport_h2g->has_remote_cid ||
+ transport_h2g->has_remote_cid(vsk, remote_cid)))
+ new_transport = transport_h2g;
+ else if (sock_net(sk)->vsock.g2h_fallback &&
+ transport_g2h && transport_g2h->has_remote_cid &&
+ transport_g2h->has_remote_cid(vsk, remote_cid)) {
+ vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+ new_transport = transport_g2h;
+ } else {
new_transport = transport_h2g;
+ }
break;
default:
ret = -ESOCKTNOSUPPORT;
@@ -1502,7 +1516,7 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
prot = READ_ONCE(sk->sk_prot);
if (prot != &vsock_proto)
- return prot->recvmsg(sk, msg, len, flags, NULL);
+ return prot->recvmsg(sk, msg, len, flags);
#endif
return __vsock_dgram_recvmsg(sock, msg, len, flags);
@@ -1850,10 +1864,10 @@ static int vsock_accept(struct socket *sock, struct socket *newsock,
* created upon connection establishment.
*/
timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK);
- prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
while ((connected = vsock_dequeue_accept(listener)) == NULL &&
- listener->sk_err == 0) {
+ listener->sk_err == 0 && timeout != 0) {
+ prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
release_sock(listener);
timeout = schedule_timeout(timeout);
finish_wait(sk_sleep(listener), &wait);
@@ -1862,17 +1876,14 @@ static int vsock_accept(struct socket *sock, struct socket *newsock,
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
goto out;
- } else if (timeout == 0) {
- err = -EAGAIN;
- goto out;
}
-
- prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
}
- finish_wait(sk_sleep(listener), &wait);
- if (listener->sk_err)
+ if (listener->sk_err) {
err = -listener->sk_err;
+ } else if (!connected) {
+ err = -EAGAIN;
+ }
if (connected) {
sk_acceptq_removed(listener);
@@ -1951,12 +1962,12 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk,
const struct vsock_transport *transport,
u64 val)
{
- if (val > vsk->buffer_max_size)
- val = vsk->buffer_max_size;
-
if (val < vsk->buffer_min_size)
val = vsk->buffer_min_size;
+ if (val > vsk->buffer_max_size)
+ val = vsk->buffer_max_size;
+
if (val != vsk->buffer_size &&
transport && transport->notify_buffer_size)
transport->notify_buffer_size(vsk, &val);
@@ -2575,7 +2586,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
prot = READ_ONCE(sk->sk_prot);
if (prot != &vsock_proto)
- return prot->recvmsg(sk, msg, len, flags, NULL);
+ return prot->recvmsg(sk, msg, len, flags);
#endif
return __vsock_connectible_recvmsg(sock, msg, len, flags);
@@ -2879,6 +2890,15 @@ static struct ctl_table vsock_table[] = {
.mode = 0644,
.proc_handler = vsock_net_child_mode_string
},
+ {
+ .procname = "g2h_fallback",
+ .data = &init_net.vsock.g2h_fallback,
+ .maxlen = sizeof(int),
+ .mode = 0644,
+ .proc_handler = proc_dointvec_minmax,
+ .extra1 = SYSCTL_ZERO,
+ .extra2 = SYSCTL_ONE,
+ },
};
static int __net_init vsock_sysctl_register(struct net *net)
@@ -2894,6 +2914,7 @@ static int __net_init vsock_sysctl_register(struct net *net)
table[0].data = &net->vsock.mode;
table[1].data = &net->vsock.child_ns_mode;
+ table[2].data = &net->vsock.g2h_fallback;
}
net->vsock.sysctl_hdr = register_net_sysctl_sz(net, "net/vsock", table,
@@ -2928,6 +2949,8 @@ static void vsock_net_init(struct net *net)
net->vsock.mode = vsock_net_child_mode(current->nsproxy->net_ns);
net->vsock.child_ns_mode = net->vsock.mode;
+ net->vsock.child_ns_mode_locked = 0;
+ net->vsock.g2h_fallback = 1;
}
static __net_init int vsock_sysctl_init_net(struct net *net)
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 069386a74557..7a8963595bf9 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -196,7 +196,7 @@ static int hvs_channel_readable_payload(struct vmbus_channel *chan)
if (readable > HVS_PKT_LEN(0)) {
/* At least we have 1 byte to read. We don't need to return
- * the exact readable bytes: see vsock_stream_recvmsg() ->
+ * the exact readable bytes: see vsock_connectible_recvmsg() ->
* vsock_stream_has_data().
*/
return 1;
@@ -375,10 +375,10 @@ static void hvs_open_connection(struct vmbus_channel *chan)
} else {
sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
- sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE);
+ sndbuf = VMBUS_RING_SIZE(sndbuf);
rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
- rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE);
+ rcvbuf = VMBUS_RING_SIZE(rcvbuf);
}
chan->max_pkt_size = HVS_MAX_PKT_SIZE;
@@ -694,7 +694,6 @@ out:
static s64 hvs_stream_has_data(struct vsock_sock *vsk)
{
struct hvsock *hvs = vsk->trans;
- bool need_refill;
s64 ret;
if (hvs->recv_data_len > 0)
@@ -702,9 +701,31 @@ static s64 hvs_stream_has_data(struct vsock_sock *vsk)
switch (hvs_channel_readable_payload(hvs->chan)) {
case 1:
- need_refill = !hvs->recv_desc;
- if (!need_refill)
- return -EIO;
+ if (hvs->recv_desc) {
+ /* Here hvs->recv_data_len is 0, so hvs->recv_desc must
+ * be NULL unless it points to the 0-byte-payload FIN
+ * packet or a malformed/short packet: see
+ * hvs_update_recv_data().
+ *
+ * If hvs->recv_desc points to the FIN packet, here all
+ * the payload has been dequeued and the peer_shutdown
+ * flag is set, but hvs_channel_readable_payload() still
+ * returns 1, because the VMBus ringbuffer's read_index
+ * is not updated for the FIN packet:
+ * hvs_stream_dequeue() -> hv_pkt_iter_next() updates
+ * the cached priv_read_index but has no opportunity to
+ * update the read_index in hv_pkt_iter_close() as
+ * hvs_stream_has_data() returns 0 for the FIN packet,
+ * so it won't get dequeued.
+ *
+ * In case hvs->recv_desc points to a malformed/short
+ * packet, return -EIO.
+ */
+ if (!(vsk->peer_shutdown & SEND_SHUTDOWN))
+ return -EIO;
+
+ return 0;
+ }
hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
if (!hvs->recv_desc)
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 77fe5b7b066c..57f2d6ec3ffc 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -547,11 +547,18 @@ bool virtio_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port)
static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk,
u32 remote_cid);
+static bool virtio_transport_has_remote_cid(struct vsock_sock *vsk, u32 cid)
+{
+ /* The CID could be implemented by the host. Always assume it is. */
+ return true;
+}
+
static struct virtio_transport virtio_transport = {
.transport = {
.module = THIS_MODULE,
.get_local_cid = virtio_transport_get_local_cid,
+ .has_remote_cid = virtio_transport_has_remote_cid,
.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 8a9fb23c6e85..416d533f493d 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -60,8 +60,6 @@ static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops,
return false;
/* Check that transport can send data in zerocopy mode. */
- t_ops = virtio_transport_get_ops(info->vsk);
-
if (t_ops->can_msgzerocopy) {
int pages_to_send = iov_iter_npages(iov_iter, MAX_SKB_FRAGS);
@@ -75,6 +73,7 @@ static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops,
static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk,
struct sk_buff *skb,
struct msghdr *msg,
+ size_t pkt_len,
bool zerocopy)
{
struct ubuf_info *uarg;
@@ -83,12 +82,10 @@ static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk,
uarg = msg->msg_ubuf;
net_zcopy_get(uarg);
} else {
- struct iov_iter *iter = &msg->msg_iter;
struct ubuf_info_msgzc *uarg_zc;
uarg = msg_zerocopy_realloc(sk_vsock(vsk),
- iter->count,
- NULL, false);
+ pkt_len, NULL, false);
if (!uarg)
return -1;
@@ -400,11 +397,17 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
* each iteration. If this is last skb for this buffer
* and MSG_ZEROCOPY mode is in use - we must allocate
* completion for the current syscall.
+ *
+ * Pass pkt_len because msg iter is already consumed
+ * by virtio_transport_fill_skb(), so iter->count
+ * can not be used for RLIMIT_MEMLOCK pinned-pages
+ * accounting done by msg_zerocopy_realloc().
*/
if (info->msg && info->msg->msg_flags & MSG_ZEROCOPY &&
skb_len == rest_len && info->op == VIRTIO_VSOCK_OP_RW) {
if (virtio_transport_init_zcopy_skb(vsk, skb,
info->msg,
+ pkt_len,
can_zcopy)) {
kfree_skb(skb);
ret = -ENOMEM;
@@ -547,9 +550,8 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk,
skb_queue_walk(&vvs->rx_queue, skb) {
size_t bytes;
- bytes = len - total;
- if (bytes > skb->len)
- bytes = skb->len;
+ bytes = min_t(size_t, len - total,
+ skb->len - VIRTIO_VSOCK_SKB_CB(skb)->offset);
spin_unlock_bh(&vvs->rx_lock);
@@ -1560,8 +1562,6 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
return -ENOMEM;
}
- sk_acceptq_added(sk);
-
lock_sock_nested(child, SINGLE_DEPTH_NESTING);
child->sk_state = TCP_ESTABLISHED;
@@ -1583,6 +1583,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
return ret;
}
+ sk_acceptq_added(sk);
if (virtio_transport_space_update(child, skb))
child->sk_write_space(child);
diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c
index 07b96d56f3a5..9049d2648646 100644
--- a/net/vmw_vsock/vsock_bpf.c
+++ b/net/vmw_vsock/vsock_bpf.c
@@ -74,7 +74,7 @@ static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int
}
static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
- size_t len, int flags, int *addr_len)
+ size_t len, int flags)
{
struct sk_psock *psock;
struct vsock_sock *vsk;