diff options
Diffstat (limited to 'net/vmw_vsock')
| -rw-r--r-- | net/vmw_vsock/af_vsock.c | 61 | ||||
| -rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 35 | ||||
| -rw-r--r-- | net/vmw_vsock/virtio_transport.c | 7 | ||||
| -rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 21 | ||||
| -rw-r--r-- | net/vmw_vsock/vsock_bpf.c | 2 |
5 files changed, 89 insertions, 37 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 2f7d94d682cb..44037b066a5f 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -545,9 +545,13 @@ static void vsock_deassign_transport(struct vsock_sock *vsk) * The vsk->remote_addr is used to decide which transport to use: * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if * g2h is not loaded, will use local transport; - * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field - * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; - * - remote CID > VMADDR_CID_HOST will use host->guest transport; + * - remote CID <= VMADDR_CID_HOST or remote flags field includes + * VMADDR_FLAG_TO_HOST, will use guest->host transport; + * - remote CID > VMADDR_CID_HOST and h2g is loaded and h2g claims that CID, + * will use host->guest transport; + * - h2g not loaded or h2g does not claim that CID and g2h claims the CID via + * has_remote_cid, will use guest->host transport (when g2h_fallback=1) + * - anything else goes to h2g or returns -ENODEV if no h2g is available */ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) { @@ -581,11 +585,21 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) case SOCK_SEQPACKET: if (vsock_use_local_transport(remote_cid)) new_transport = transport_local; - else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g || + else if (remote_cid <= VMADDR_CID_HOST || (remote_flags & VMADDR_FLAG_TO_HOST)) new_transport = transport_g2h; - else + else if (transport_h2g && + (!transport_h2g->has_remote_cid || + transport_h2g->has_remote_cid(vsk, remote_cid))) + new_transport = transport_h2g; + else if (sock_net(sk)->vsock.g2h_fallback && + transport_g2h && transport_g2h->has_remote_cid && + transport_g2h->has_remote_cid(vsk, remote_cid)) { + vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + new_transport = transport_g2h; + } else { new_transport = transport_h2g; + } break; default: ret = -ESOCKTNOSUPPORT; @@ -1502,7 +1516,7 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, prot = READ_ONCE(sk->sk_prot); if (prot != &vsock_proto) - return prot->recvmsg(sk, msg, len, flags, NULL); + return prot->recvmsg(sk, msg, len, flags); #endif return __vsock_dgram_recvmsg(sock, msg, len, flags); @@ -1850,10 +1864,10 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, * created upon connection establishment. */ timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK); - prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); while ((connected = vsock_dequeue_accept(listener)) == NULL && - listener->sk_err == 0) { + listener->sk_err == 0 && timeout != 0) { + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); release_sock(listener); timeout = schedule_timeout(timeout); finish_wait(sk_sleep(listener), &wait); @@ -1862,17 +1876,14 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, if (signal_pending(current)) { err = sock_intr_errno(timeout); goto out; - } else if (timeout == 0) { - err = -EAGAIN; - goto out; } - - prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); } - finish_wait(sk_sleep(listener), &wait); - if (listener->sk_err) + if (listener->sk_err) { err = -listener->sk_err; + } else if (!connected) { + err = -EAGAIN; + } if (connected) { sk_acceptq_removed(listener); @@ -1951,12 +1962,12 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk, const struct vsock_transport *transport, u64 val) { - if (val > vsk->buffer_max_size) - val = vsk->buffer_max_size; - if (val < vsk->buffer_min_size) val = vsk->buffer_min_size; + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + if (val != vsk->buffer_size && transport && transport->notify_buffer_size) transport->notify_buffer_size(vsk, &val); @@ -2575,7 +2586,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, prot = READ_ONCE(sk->sk_prot); if (prot != &vsock_proto) - return prot->recvmsg(sk, msg, len, flags, NULL); + return prot->recvmsg(sk, msg, len, flags); #endif return __vsock_connectible_recvmsg(sock, msg, len, flags); @@ -2879,6 +2890,15 @@ static struct ctl_table vsock_table[] = { .mode = 0644, .proc_handler = vsock_net_child_mode_string }, + { + .procname = "g2h_fallback", + .data = &init_net.vsock.g2h_fallback, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, }; static int __net_init vsock_sysctl_register(struct net *net) @@ -2894,6 +2914,7 @@ static int __net_init vsock_sysctl_register(struct net *net) table[0].data = &net->vsock.mode; table[1].data = &net->vsock.child_ns_mode; + table[2].data = &net->vsock.g2h_fallback; } net->vsock.sysctl_hdr = register_net_sysctl_sz(net, "net/vsock", table, @@ -2928,6 +2949,8 @@ static void vsock_net_init(struct net *net) net->vsock.mode = vsock_net_child_mode(current->nsproxy->net_ns); net->vsock.child_ns_mode = net->vsock.mode; + net->vsock.child_ns_mode_locked = 0; + net->vsock.g2h_fallback = 1; } static __net_init int vsock_sysctl_init_net(struct net *net) diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 069386a74557..7a8963595bf9 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -196,7 +196,7 @@ static int hvs_channel_readable_payload(struct vmbus_channel *chan) if (readable > HVS_PKT_LEN(0)) { /* At least we have 1 byte to read. We don't need to return - * the exact readable bytes: see vsock_stream_recvmsg() -> + * the exact readable bytes: see vsock_connectible_recvmsg() -> * vsock_stream_has_data(). */ return 1; @@ -375,10 +375,10 @@ static void hvs_open_connection(struct vmbus_channel *chan) } else { sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE); sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE); - sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE); + sndbuf = VMBUS_RING_SIZE(sndbuf); rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE); rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE); - rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE); + rcvbuf = VMBUS_RING_SIZE(rcvbuf); } chan->max_pkt_size = HVS_MAX_PKT_SIZE; @@ -694,7 +694,6 @@ out: static s64 hvs_stream_has_data(struct vsock_sock *vsk) { struct hvsock *hvs = vsk->trans; - bool need_refill; s64 ret; if (hvs->recv_data_len > 0) @@ -702,9 +701,31 @@ static s64 hvs_stream_has_data(struct vsock_sock *vsk) switch (hvs_channel_readable_payload(hvs->chan)) { case 1: - need_refill = !hvs->recv_desc; - if (!need_refill) - return -EIO; + if (hvs->recv_desc) { + /* Here hvs->recv_data_len is 0, so hvs->recv_desc must + * be NULL unless it points to the 0-byte-payload FIN + * packet or a malformed/short packet: see + * hvs_update_recv_data(). + * + * If hvs->recv_desc points to the FIN packet, here all + * the payload has been dequeued and the peer_shutdown + * flag is set, but hvs_channel_readable_payload() still + * returns 1, because the VMBus ringbuffer's read_index + * is not updated for the FIN packet: + * hvs_stream_dequeue() -> hv_pkt_iter_next() updates + * the cached priv_read_index but has no opportunity to + * update the read_index in hv_pkt_iter_close() as + * hvs_stream_has_data() returns 0 for the FIN packet, + * so it won't get dequeued. + * + * In case hvs->recv_desc points to a malformed/short + * packet, return -EIO. + */ + if (!(vsk->peer_shutdown & SEND_SHUTDOWN)) + return -EIO; + + return 0; + } hvs->recv_desc = hv_pkt_iter_first(hvs->chan); if (!hvs->recv_desc) diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 77fe5b7b066c..57f2d6ec3ffc 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -547,11 +547,18 @@ bool virtio_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port) static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid); +static bool virtio_transport_has_remote_cid(struct vsock_sock *vsk, u32 cid) +{ + /* The CID could be implemented by the host. Always assume it is. */ + return true; +} + static struct virtio_transport virtio_transport = { .transport = { .module = THIS_MODULE, .get_local_cid = virtio_transport_get_local_cid, + .has_remote_cid = virtio_transport_has_remote_cid, .init = virtio_transport_do_socket_init, .destruct = virtio_transport_destruct, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 8a9fb23c6e85..416d533f493d 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -60,8 +60,6 @@ static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops, return false; /* Check that transport can send data in zerocopy mode. */ - t_ops = virtio_transport_get_ops(info->vsk); - if (t_ops->can_msgzerocopy) { int pages_to_send = iov_iter_npages(iov_iter, MAX_SKB_FRAGS); @@ -75,6 +73,7 @@ static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops, static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk, struct sk_buff *skb, struct msghdr *msg, + size_t pkt_len, bool zerocopy) { struct ubuf_info *uarg; @@ -83,12 +82,10 @@ static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk, uarg = msg->msg_ubuf; net_zcopy_get(uarg); } else { - struct iov_iter *iter = &msg->msg_iter; struct ubuf_info_msgzc *uarg_zc; uarg = msg_zerocopy_realloc(sk_vsock(vsk), - iter->count, - NULL, false); + pkt_len, NULL, false); if (!uarg) return -1; @@ -400,11 +397,17 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, * each iteration. If this is last skb for this buffer * and MSG_ZEROCOPY mode is in use - we must allocate * completion for the current syscall. + * + * Pass pkt_len because msg iter is already consumed + * by virtio_transport_fill_skb(), so iter->count + * can not be used for RLIMIT_MEMLOCK pinned-pages + * accounting done by msg_zerocopy_realloc(). */ if (info->msg && info->msg->msg_flags & MSG_ZEROCOPY && skb_len == rest_len && info->op == VIRTIO_VSOCK_OP_RW) { if (virtio_transport_init_zcopy_skb(vsk, skb, info->msg, + pkt_len, can_zcopy)) { kfree_skb(skb); ret = -ENOMEM; @@ -547,9 +550,8 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk, skb_queue_walk(&vvs->rx_queue, skb) { size_t bytes; - bytes = len - total; - if (bytes > skb->len) - bytes = skb->len; + bytes = min_t(size_t, len - total, + skb->len - VIRTIO_VSOCK_SKB_CB(skb)->offset); spin_unlock_bh(&vvs->rx_lock); @@ -1560,8 +1562,6 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, return -ENOMEM; } - sk_acceptq_added(sk); - lock_sock_nested(child, SINGLE_DEPTH_NESTING); child->sk_state = TCP_ESTABLISHED; @@ -1583,6 +1583,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, return ret; } + sk_acceptq_added(sk); if (virtio_transport_space_update(child, skb)) child->sk_write_space(child); diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c index 07b96d56f3a5..9049d2648646 100644 --- a/net/vmw_vsock/vsock_bpf.c +++ b/net/vmw_vsock/vsock_bpf.c @@ -74,7 +74,7 @@ static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int } static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, - size_t len, int flags, int *addr_len) + size_t len, int flags) { struct sk_psock *psock; struct vsock_sock *vsk; |
