diff options
Diffstat (limited to 'net')
| -rw-r--r-- | net/Kconfig | 1 | ||||
| -rw-r--r-- | net/Makefile | 1 | ||||
| -rw-r--r-- | net/core/dev.c | 32 | ||||
| -rw-r--r-- | net/core/gro.c | 2 | ||||
| -rw-r--r-- | net/core/skbuff.c | 4 | ||||
| -rw-r--r-- | net/ipv4/af_inet.c | 2 | ||||
| -rw-r--r-- | net/ipv4/inet_timewait_sock.c | 5 | ||||
| -rw-r--r-- | net/ipv4/ip_output.c | 5 | ||||
| -rw-r--r-- | net/ipv4/tcp.c | 2 | ||||
| -rw-r--r-- | net/ipv4/tcp_ipv4.c | 18 | ||||
| -rw-r--r-- | net/ipv4/tcp_minisocks.c | 20 | ||||
| -rw-r--r-- | net/ipv4/tcp_output.c | 17 | ||||
| -rw-r--r-- | net/ipv6/ipv6_sockglue.c | 6 | ||||
| -rw-r--r-- | net/ipv6/tcp_ipv6.c | 17 | ||||
| -rw-r--r-- | net/psp/Kconfig | 15 | ||||
| -rw-r--r-- | net/psp/Makefile | 5 | ||||
| -rw-r--r-- | net/psp/psp-nl-gen.c | 119 | ||||
| -rw-r--r-- | net/psp/psp-nl-gen.h | 39 | ||||
| -rw-r--r-- | net/psp/psp.h | 54 | ||||
| -rw-r--r-- | net/psp/psp_main.c | 321 | ||||
| -rw-r--r-- | net/psp/psp_nl.c | 505 | ||||
| -rw-r--r-- | net/psp/psp_sock.c | 295 |
22 files changed, 1471 insertions, 14 deletions
diff --git a/net/Kconfig b/net/Kconfig index d5865cf19799..4b563aea4c23 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -82,6 +82,7 @@ config NET_CRC32C menu "Networking options" source "net/packet/Kconfig" +source "net/psp/Kconfig" source "net/unix/Kconfig" source "net/tls/Kconfig" source "net/xfrm/Kconfig" diff --git a/net/Makefile b/net/Makefile index aac960c41db6..90e3d72bf58b 100644 --- a/net/Makefile +++ b/net/Makefile @@ -18,6 +18,7 @@ obj-$(CONFIG_INET) += ipv4/ obj-$(CONFIG_TLS) += tls/ obj-$(CONFIG_XFRM) += xfrm/ obj-$(CONFIG_UNIX) += unix/ +obj-$(CONFIG_INET_PSP) += psp/ obj-y += ipv6/ obj-$(CONFIG_PACKET) += packet/ obj-$(CONFIG_NET_KEY) += key/ diff --git a/net/core/dev.c b/net/core/dev.c index 2522d9d8f0e4..5e22d062bac5 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -3907,6 +3907,38 @@ sw_checksum: } EXPORT_SYMBOL(skb_csum_hwoffload_help); +/* Checks if this SKB belongs to an HW offloaded socket + * and whether any SW fallbacks are required based on dev. + * Check decrypted mark in case skb_orphan() cleared socket. + */ +static struct sk_buff *sk_validate_xmit_skb(struct sk_buff *skb, + struct net_device *dev) +{ +#ifdef CONFIG_SOCK_VALIDATE_XMIT + struct sk_buff *(*sk_validate)(struct sock *sk, struct net_device *dev, + struct sk_buff *skb); + struct sock *sk = skb->sk; + + sk_validate = NULL; + if (sk) { + if (sk_fullsock(sk)) + sk_validate = sk->sk_validate_xmit_skb; + else if (sk_is_inet(sk) && sk->sk_state == TCP_TIME_WAIT) + sk_validate = inet_twsk(sk)->tw_validate_xmit_skb; + } + + if (sk_validate) { + skb = sk_validate(sk, dev, skb); + } else if (unlikely(skb_is_decrypted(skb))) { + pr_warn_ratelimited("unencrypted skb with no associated socket - dropping\n"); + kfree_skb(skb); + skb = NULL; + } +#endif + + return skb; +} + static struct sk_buff *validate_xmit_unreadable_skb(struct sk_buff *skb, struct net_device *dev) { diff --git a/net/core/gro.c b/net/core/gro.c index b350e5b69549..5ba4504cfd28 100644 --- a/net/core/gro.c +++ b/net/core/gro.c @@ -1,4 +1,5 @@ // SPDX-License-Identifier: GPL-2.0-or-later +#include <net/psp.h> #include <net/gro.h> #include <net/dst_metadata.h> #include <net/busy_poll.h> @@ -376,6 +377,7 @@ static void gro_list_prepare(const struct list_head *head, diffs |= skb_get_nfct(p) ^ skb_get_nfct(skb); diffs |= gro_list_prepare_tc_ext(skb, p, diffs); + diffs |= __psp_skb_coalesce_diff(skb, p, diffs); } NAPI_GRO_CB(p)->same_flow = !diffs; diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 23b776cd9879..d331e607edfb 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -79,6 +79,7 @@ #include <net/mptcp.h> #include <net/mctp.h> #include <net/page_pool/helpers.h> +#include <net/psp/types.h> #include <net/dropreason.h> #include <linux/uaccess.h> @@ -5062,6 +5063,9 @@ static const u8 skb_ext_type_len[] = { #if IS_ENABLED(CONFIG_MCTP_FLOWS) [SKB_EXT_MCTP] = SKB_EXT_CHUNKSIZEOF(struct mctp_flow), #endif +#if IS_ENABLED(CONFIG_INET_PSP) + [SKB_EXT_PSP] = SKB_EXT_CHUNKSIZEOF(struct psp_skb_ext), +#endif }; static __always_inline unsigned int skb_ext_total_length(void) diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 76e38092cd8a..e298dacb4a06 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -102,6 +102,7 @@ #include <net/gro.h> #include <net/gso.h> #include <net/tcp.h> +#include <net/psp.h> #include <net/udp.h> #include <net/udplite.h> #include <net/ping.h> @@ -158,6 +159,7 @@ void inet_sock_destruct(struct sock *sk) kfree(rcu_dereference_protected(inet->inet_opt, 1)); dst_release(rcu_dereference_protected(sk->sk_dst_cache, 1)); dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1)); + psp_sk_assoc_free(sk); } EXPORT_SYMBOL(inet_sock_destruct); diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c index 5b5426b8ee92..2ca2912f61f4 100644 --- a/net/ipv4/inet_timewait_sock.c +++ b/net/ipv4/inet_timewait_sock.c @@ -16,6 +16,7 @@ #include <net/inet_timewait_sock.h> #include <net/ip.h> #include <net/tcp.h> +#include <net/psp.h> /** * inet_twsk_bind_unhash - unhash a timewait socket from bind hash @@ -211,6 +212,9 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, atomic64_set(&tw->tw_cookie, atomic64_read(&sk->sk_cookie)); twsk_net_set(tw, sock_net(sk)); timer_setup(&tw->tw_timer, tw_timer_handler, 0); +#ifdef CONFIG_SOCK_VALIDATE_XMIT + tw->tw_validate_xmit_skb = NULL; +#endif /* * Because we use RCU lookups, we should not set tw_refcnt * to a non null value before everything is setup for this @@ -219,6 +223,7 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, refcount_set(&tw->tw_refcnt, 0); __module_get(tw->tw_prot->owner); + psp_twsk_init(tw, sk); } return tw; diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c index 2b96651d719b..5ca97ede979c 100644 --- a/net/ipv4/ip_output.c +++ b/net/ipv4/ip_output.c @@ -84,6 +84,7 @@ #include <linux/netfilter_bridge.h> #include <linux/netlink.h> #include <linux/tcp.h> +#include <net/psp.h> static int ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, @@ -1665,8 +1666,10 @@ void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk, arg->csumoffset) = csum_fold(csum_add(nskb->csum, arg->csum)); nskb->ip_summed = CHECKSUM_NONE; - if (orig_sk) + if (orig_sk) { skb_set_owner_edemux(nskb, (struct sock *)orig_sk); + psp_reply_set_decrypted(nskb); + } if (transmit_time) nskb->tstamp_type = SKB_CLOCK_MONOTONIC; if (txhash) diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 5b5c655ded1d..d6d0d970e014 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -277,6 +277,7 @@ #include <net/proto_memory.h> #include <net/xfrm.h> #include <net/ip.h> +#include <net/psp.h> #include <net/sock.h> #include <net/rstreason.h> @@ -705,6 +706,7 @@ void tcp_skb_entail(struct sock *sk, struct sk_buff *skb) tcb->seq = tcb->end_seq = tp->write_seq; tcb->tcp_flags = TCPHDR_ACK; __skb_header_release(skb); + psp_enqueue_set_decrypted(sk, skb); tcp_add_write_queue_tail(sk, skb); sk_wmem_queued_add(sk, skb->truesize); sk_mem_charge(sk, skb->truesize); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 6a63be1f6461..b1fcf3e4e1ce 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -75,6 +75,7 @@ #include <net/secure_seq.h> #include <net/busy_poll.h> #include <net/rstreason.h> +#include <net/psp.h> #include <linux/inet.h> #include <linux/ipv6.h> @@ -293,9 +294,9 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) inet->inet_dport = usin->sin_port; sk_daddr_set(sk, daddr); - inet_csk(sk)->icsk_ext_hdr_len = 0; + inet_csk(sk)->icsk_ext_hdr_len = psp_sk_overhead(sk); if (inet_opt) - inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen; + inet_csk(sk)->icsk_ext_hdr_len += inet_opt->opt.optlen; tp->rx_opt.mss_clamp = TCP_MSS_DEFAULT; @@ -1907,6 +1908,10 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) enum skb_drop_reason reason; struct sock *rsk; + reason = psp_sk_rx_policy_check(sk, skb); + if (reason) + goto err_discard; + if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */ struct dst_entry *dst; @@ -1968,6 +1973,7 @@ csum_err: reason = SKB_DROP_REASON_TCP_CSUM; trace_tcp_bad_csum(skb); TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); +err_discard: TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); goto discard; } @@ -2069,7 +2075,9 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb, (TCPHDR_ECE | TCPHDR_CWR | TCPHDR_AE)) || !tcp_skb_can_collapse_rx(tail, skb) || thtail->doff != th->doff || - memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th))) + memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)) || + /* prior to PSP Rx policy check, retain exact PSP metadata */ + psp_skb_coalesce_diff(tail, skb)) goto no_coalesce; __skb_pull(skb, hdrlen); @@ -2437,6 +2445,10 @@ do_time_wait: __this_cpu_write(tcp_tw_isn, isn); goto process; } + + drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb); + if (drop_reason) + break; } /* to ACK */ fallthrough; diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index 327095ef95ef..2ec8c6f1cdcc 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -24,6 +24,7 @@ #include <net/xfrm.h> #include <net/busy_poll.h> #include <net/rstreason.h> +#include <net/psp.h> static bool tcp_in_window(u32 seq, u32 end_seq, u32 s_win, u32 e_win) { @@ -104,9 +105,16 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); u32 rcv_nxt = READ_ONCE(tcptw->tw_rcv_nxt); struct tcp_options_received tmp_opt; + enum skb_drop_reason psp_drop; bool paws_reject = false; int ts_recent_stamp; + /* Instead of dropping immediately, wait to see what value is + * returned. We will accept a non psp-encapsulated syn in the + * case where TCP_TW_SYN is returned. + */ + psp_drop = psp_twsk_rx_policy_check(tw, skb); + tmp_opt.saw_tstamp = 0; ts_recent_stamp = READ_ONCE(tcptw->tw_ts_recent_stamp); if (th->doff > (sizeof(*th) >> 2) && ts_recent_stamp) { @@ -124,6 +132,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, if (READ_ONCE(tw->tw_substate) == TCP_FIN_WAIT2) { /* Just repeat all the checks of tcp_rcv_state_process() */ + if (psp_drop) + goto out_put; + /* Out of window, send ACK */ if (paws_reject || !tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, @@ -194,6 +205,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq || th->rst))) { /* In window segment, it may be only reset or bare ack. */ + if (psp_drop) + goto out_put; + if (th->rst) { /* This is TIME_WAIT assassination, in two flavors. * Oh well... nobody has a sufficient solution to this @@ -247,6 +261,9 @@ kill: return TCP_TW_SYN; } + if (psp_drop) + goto out_put; + if (paws_reject) { *drop_reason = SKB_DROP_REASON_TCP_RFC7323_TW_PAWS; __NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWS_TW_REJECTED); @@ -265,6 +282,8 @@ kill: return tcp_timewait_check_oow_rate_limit( tw, skb, LINUX_MIB_TCPACKSKIPPEDTIMEWAIT); } + +out_put: inet_twsk_put(tw); return TCP_TW_SUCCESS; } @@ -392,6 +411,7 @@ void tcp_twsk_destructor(struct sock *sk) } #endif tcp_ao_destroy_sock(sk, true); + psp_twsk_assoc_free(inet_twsk(sk)); } void tcp_twsk_purge(struct list_head *net_exit_list) diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 388c45859469..223d7feeb19d 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -41,6 +41,7 @@ #include <net/tcp_ecn.h> #include <net/mptcp.h> #include <net/proto_memory.h> +#include <net/psp.h> #include <linux/compiler.h> #include <linux/gfp.h> @@ -358,13 +359,15 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, /* Constructs common control bits of non-data skb. If SYN/FIN is present, * auto increment end seqno. */ -static void tcp_init_nondata_skb(struct sk_buff *skb, u32 seq, u16 flags) +static void tcp_init_nondata_skb(struct sk_buff *skb, struct sock *sk, + u32 seq, u16 flags) { skb->ip_summed = CHECKSUM_PARTIAL; TCP_SKB_CB(skb)->tcp_flags = flags; tcp_skb_pcount_set(skb, 1); + psp_enqueue_set_decrypted(sk, skb); TCP_SKB_CB(skb)->seq = seq; if (flags & (TCPHDR_SYN | TCPHDR_FIN)) @@ -1656,6 +1659,7 @@ static void tcp_queue_skb(struct sock *sk, struct sk_buff *skb) /* Advance write_seq and place onto the write_queue. */ WRITE_ONCE(tp->write_seq, TCP_SKB_CB(skb)->end_seq); __skb_header_release(skb); + psp_enqueue_set_decrypted(sk, skb); tcp_add_write_queue_tail(sk, skb); sk_wmem_queued_add(sk, skb->truesize); sk_mem_charge(sk, skb->truesize); @@ -3778,7 +3782,7 @@ void tcp_send_fin(struct sock *sk) skb_reserve(skb, MAX_TCP_HEADER); sk_forced_mem_schedule(sk, skb->truesize); /* FIN eats a sequence byte, write_seq advanced by tcp_queue_skb(). */ - tcp_init_nondata_skb(skb, tp->write_seq, + tcp_init_nondata_skb(skb, sk, tp->write_seq, TCPHDR_ACK | TCPHDR_FIN); tcp_queue_skb(sk, skb); } @@ -3806,7 +3810,7 @@ void tcp_send_active_reset(struct sock *sk, gfp_t priority, /* Reserve space for headers and prepare control bits. */ skb_reserve(skb, MAX_TCP_HEADER); - tcp_init_nondata_skb(skb, tcp_acceptable_seq(sk), + tcp_init_nondata_skb(skb, sk, tcp_acceptable_seq(sk), TCPHDR_ACK | TCPHDR_RST); tcp_mstamp_refresh(tcp_sk(sk)); /* Send it off. */ @@ -4303,7 +4307,7 @@ int tcp_connect(struct sock *sk) /* SYN eats a sequence byte, write_seq updated by * tcp_connect_queue_skb(). */ - tcp_init_nondata_skb(buff, tp->write_seq, TCPHDR_SYN); + tcp_init_nondata_skb(buff, sk, tp->write_seq, TCPHDR_SYN); tcp_mstamp_refresh(tp); tp->retrans_stamp = tcp_time_stamp_ts(tp); tcp_connect_queue_skb(sk, buff); @@ -4428,7 +4432,8 @@ void __tcp_send_ack(struct sock *sk, u32 rcv_nxt, u16 flags) /* Reserve space for headers and prepare control bits. */ skb_reserve(buff, MAX_TCP_HEADER); - tcp_init_nondata_skb(buff, tcp_acceptable_seq(sk), TCPHDR_ACK | flags); + tcp_init_nondata_skb(buff, sk, + tcp_acceptable_seq(sk), TCPHDR_ACK | flags); /* We do not want pure acks influencing TCP Small Queues or fq/pacing * too much. @@ -4474,7 +4479,7 @@ static int tcp_xmit_probe_skb(struct sock *sk, int urgent, int mib) * end to send an ack. Don't queue or clone SKB, just * send it. */ - tcp_init_nondata_skb(skb, tp->snd_una - !urgent, TCPHDR_ACK); + tcp_init_nondata_skb(skb, sk, tp->snd_una - !urgent, TCPHDR_ACK); NET_INC_STATS(sock_net(sk), mib); return tcp_transmit_skb(sk, skb, 0, (__force gfp_t)0); } diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c index e66ec623972e..a61e742794f9 100644 --- a/net/ipv6/ipv6_sockglue.c +++ b/net/ipv6/ipv6_sockglue.c @@ -49,6 +49,7 @@ #include <net/xfrm.h> #include <net/compat.h> #include <net/seg6.h> +#include <net/psp.h> #include <linux/uaccess.h> @@ -107,7 +108,10 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk, !((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) && inet_sk(sk)->inet_daddr != LOOPBACK4_IPV6) { struct inet_connection_sock *icsk = inet_csk(sk); - icsk->icsk_ext_hdr_len = opt->opt_flen + opt->opt_nflen; + + icsk->icsk_ext_hdr_len = + psp_sk_overhead(sk) + + opt->opt_flen + opt->opt_nflen; icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); } } diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index c7271f6359db..d1e5b2a186fb 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -62,6 +62,7 @@ #include <net/hotdata.h> #include <net/busy_poll.h> #include <net/rstreason.h> +#include <net/psp.h> #include <linux/proc_fs.h> #include <linux/seq_file.h> @@ -301,10 +302,10 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, sk->sk_gso_type = SKB_GSO_TCPV6; ip6_dst_store(sk, dst, false, false); - icsk->icsk_ext_hdr_len = 0; + icsk->icsk_ext_hdr_len = psp_sk_overhead(sk); if (opt) - icsk->icsk_ext_hdr_len = opt->opt_flen + - opt->opt_nflen; + icsk->icsk_ext_hdr_len += opt->opt_flen + + opt->opt_nflen; tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr); @@ -973,6 +974,7 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32 if (sk) { /* unconstify the socket only to attach it to buff with care. */ skb_set_owner_edemux(buff, (struct sock *)sk); + psp_reply_set_decrypted(buff); if (sk->sk_state == TCP_TIME_WAIT) mark = inet_twsk(sk)->tw_mark; @@ -1605,6 +1607,10 @@ int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb) if (skb->protocol == htons(ETH_P_IP)) return tcp_v4_do_rcv(sk, skb); + reason = psp_sk_rx_policy_check(sk, skb); + if (reason) + goto err_discard; + /* * socket locking is here for SMP purposes as backlog rcv * is currently called with bh processing disabled. @@ -1684,6 +1690,7 @@ csum_err: reason = SKB_DROP_REASON_TCP_CSUM; trace_tcp_bad_csum(skb); TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); +err_discard: TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); goto discard; @@ -1988,6 +1995,10 @@ do_time_wait: __this_cpu_write(tcp_tw_isn, isn); goto process; } + + drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb); + if (drop_reason) + break; } /* to ACK */ fallthrough; diff --git a/net/psp/Kconfig b/net/psp/Kconfig new file mode 100644 index 000000000000..a7d24691a7e1 --- /dev/null +++ b/net/psp/Kconfig @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# PSP configuration +# +config INET_PSP + bool "PSP Security Protocol support" + depends on INET + select SKB_DECRYPTED + select SOCK_VALIDATE_XMIT + help + Enable kernel support for the PSP protocol. + For more information see: + https://raw.githubusercontent.com/google/psp/main/doc/PSP_Arch_Spec.pdf + + If unsure, say N. diff --git a/net/psp/Makefile b/net/psp/Makefile new file mode 100644 index 000000000000..eb5ff3c5bfb2 --- /dev/null +++ b/net/psp/Makefile @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: GPL-2.0-only + +obj-$(CONFIG_INET_PSP) += psp.o + +psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c new file mode 100644 index 000000000000..9fdd6f831803 --- /dev/null +++ b/net/psp/psp-nl-gen.c @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/psp.yaml */ +/* YNL-GEN kernel source */ + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include "psp-nl-gen.h" + +#include <uapi/linux/psp.h> + +/* Common nested types */ +const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = { + [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, }, + [PSP_A_KEYS_SPI] = { .type = NLA_U32, }, +}; + +/* PSP_CMD_DEV_GET - do */ +static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = { + [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), +}; + +/* PSP_CMD_DEV_SET - do */ +static const struct nla_policy psp_dev_set_nl_policy[PSP_A_DEV_PSP_VERSIONS_ENA + 1] = { + [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_DEV_PSP_VERSIONS_ENA] = NLA_POLICY_MASK(NLA_U32, 0xf), +}; + +/* PSP_CMD_KEY_ROTATE - do */ +static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = { + [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), +}; + +/* PSP_CMD_RX_ASSOC - do */ +static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, +}; + +/* PSP_CMD_TX_ASSOC - do */ +static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), + [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy), + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, +}; + +/* Ops table for psp */ +static const struct genl_split_ops psp_nl_ops[] = { + { + .cmd = PSP_CMD_DEV_GET, + .pre_doit = psp_device_get_locked, + .doit = psp_nl_dev_get_doit, + .post_doit = psp_device_unlock, + .policy = psp_dev_get_nl_policy, + .maxattr = PSP_A_DEV_ID, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = PSP_CMD_DEV_GET, + .dumpit = psp_nl_dev_get_dumpit, + .flags = GENL_CMD_CAP_DUMP, + }, + { + .cmd = PSP_CMD_DEV_SET, + .pre_doit = psp_device_get_locked, + .doit = psp_nl_dev_set_doit, + .post_doit = psp_device_unlock, + .policy = psp_dev_set_nl_policy, + .maxattr = PSP_A_DEV_PSP_VERSIONS_ENA, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = PSP_CMD_KEY_ROTATE, + .pre_doit = psp_device_get_locked, + .doit = psp_nl_key_rotate_doit, + .post_doit = psp_device_unlock, + .policy = psp_key_rotate_nl_policy, + .maxattr = PSP_A_DEV_ID, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = PSP_CMD_RX_ASSOC, + .pre_doit = psp_assoc_device_get_locked, + .doit = psp_nl_rx_assoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_rx_assoc_nl_policy, + .maxattr = PSP_A_ASSOC_SOCK_FD, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = PSP_CMD_TX_ASSOC, + .pre_doit = psp_assoc_device_get_locked, + .doit = psp_nl_tx_assoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_tx_assoc_nl_policy, + .maxattr = PSP_A_ASSOC_SOCK_FD, + .flags = GENL_CMD_CAP_DO, + }, +}; + +static const struct genl_multicast_group psp_nl_mcgrps[] = { + [PSP_NLGRP_MGMT] = { "mgmt", }, + [PSP_NLGRP_USE] = { "use", }, +}; + +struct genl_family psp_nl_family __ro_after_init = { + .name = PSP_FAMILY_NAME, + .version = PSP_FAMILY_VERSION, + .netnsok = true, + .parallel_ops = true, + .module = THIS_MODULE, + .split_ops = psp_nl_ops, + .n_split_ops = ARRAY_SIZE(psp_nl_ops), + .mcgrps = psp_nl_mcgrps, + .n_mcgrps = ARRAY_SIZE(psp_nl_mcgrps), +}; diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h new file mode 100644 index 000000000000..25268ed11fb5 --- /dev/null +++ b/net/psp/psp-nl-gen.h @@ -0,0 +1,39 @@ +/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */ +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/psp.yaml */ +/* YNL-GEN kernel header */ + +#ifndef _LINUX_PSP_GEN_H +#define _LINUX_PSP_GEN_H + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include <uapi/linux/psp.h> + +/* Common nested types */ +extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1]; + +int psp_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info); +int psp_assoc_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info); +void +psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info); + +int psp_nl_dev_get_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb); +int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info); + +enum { + PSP_NLGRP_MGMT, + PSP_NLGRP_USE, +}; + +extern struct genl_family psp_nl_family; + +#endif /* _LINUX_PSP_GEN_H */ diff --git a/net/psp/psp.h b/net/psp/psp.h new file mode 100644 index 000000000000..0f34e1a23fdd --- /dev/null +++ b/net/psp/psp.h @@ -0,0 +1,54 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ + +#ifndef __PSP_PSP_H +#define __PSP_PSP_H + +#include <linux/list.h> +#include <linux/lockdep.h> +#include <linux/mutex.h> +#include <net/netns/generic.h> +#include <net/psp.h> +#include <net/sock.h> + +extern struct xarray psp_devs; +extern struct mutex psp_devs_lock; + +void psp_dev_destroy(struct psp_dev *psd); +int psp_dev_check_access(struct psp_dev *psd, struct net *net); + +void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd); + +struct psp_assoc *psp_assoc_create(struct psp_dev *psd); +struct psp_dev *psp_dev_get_for_sock(struct sock *sk); +void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas); +int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, + struct psp_key_parsed *key, + struct netlink_ext_ack *extack); +int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, + u32 version, struct psp_key_parsed *key, + struct netlink_ext_ack *extack); +void psp_assocs_key_rotated(struct psp_dev *psd); + +static inline void psp_dev_get(struct psp_dev *psd) +{ + refcount_inc(&psd->refcnt); +} + +static inline bool psp_dev_tryget(struct psp_dev *psd) +{ + return refcount_inc_not_zero(&psd->refcnt); +} + +static inline void psp_dev_put(struct psp_dev *psd) +{ + if (refcount_dec_and_test(&psd->refcnt)) + psp_dev_destroy(psd); +} + +static inline bool psp_dev_is_registered(struct psp_dev *psd) +{ + lockdep_assert_held(&psd->lock); + return !!psd->ops; +} + +#endif /* __PSP_PSP_H */ diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c new file mode 100644 index 000000000000..b4b756f87382 --- /dev/null +++ b/net/psp/psp_main.c @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include <linux/bitfield.h> +#include <linux/list.h> +#include <linux/netdevice.h> +#include <linux/xarray.h> +#include <net/net_namespace.h> +#include <net/psp.h> +#include <net/udp.h> + +#include "psp.h" +#include "psp-nl-gen.h" + +DEFINE_XARRAY_ALLOC1(psp_devs); +struct mutex psp_devs_lock; + +/** + * DOC: PSP locking + * + * psp_devs_lock protects the psp_devs xarray. + * Ordering is take the psp_devs_lock and then the instance lock. + * Each instance is protected by RCU, and has a refcount. + * When driver unregisters the instance gets flushed, but struct sticks around. + */ + +/** + * psp_dev_check_access() - check if user in a given net ns can access PSP dev + * @psd: PSP device structure user is trying to access + * @net: net namespace user is in + * + * Return: 0 if PSP device should be visible in @net, errno otherwise. + */ +int psp_dev_check_access(struct psp_dev *psd, struct net *net) +{ + if (dev_net(psd->main_netdev) == net) + return 0; + return -ENOENT; +} + +/** + * psp_dev_create() - create and register PSP device + * @netdev: main netdevice + * @psd_ops: driver callbacks + * @psd_caps: device capabilities + * @priv_ptr: back-pointer to driver private data + * + * Return: pointer to allocated PSP device, or ERR_PTR. + */ +struct psp_dev * +psp_dev_create(struct net_device *netdev, + struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps, + void *priv_ptr) +{ + struct psp_dev *psd; + static u32 last_id; + int err; + + if (WARN_ON(!psd_caps->versions || + !psd_ops->set_config || + !psd_ops->key_rotate || + !psd_ops->rx_spi_alloc || + !psd_ops->tx_key_add || + !psd_ops->tx_key_del)) + return ERR_PTR(-EINVAL); + + psd = kzalloc(sizeof(*psd), GFP_KERNEL); + if (!psd) + return ERR_PTR(-ENOMEM); + + psd->main_netdev = netdev; + psd->ops = psd_ops; + psd->caps = psd_caps; + psd->drv_priv = priv_ptr; + + mutex_init(&psd->lock); + INIT_LIST_HEAD(&psd->active_assocs); + INIT_LIST_HEAD(&psd->prev_assocs); + INIT_LIST_HEAD(&psd->stale_assocs); + refcount_set(&psd->refcnt, 1); + + mutex_lock(&psp_devs_lock); + err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b, + &last_id, GFP_KERNEL); + if (err) { + mutex_unlock(&psp_devs_lock); + kfree(psd); + return ERR_PTR(err); + } + mutex_lock(&psd->lock); + mutex_unlock(&psp_devs_lock); + + psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF); + + rcu_assign_pointer(netdev->psp_dev, psd); + + mutex_unlock(&psd->lock); + + return psd; +} +EXPORT_SYMBOL(psp_dev_create); + +void psp_dev_destroy(struct psp_dev *psd) +{ + mutex_lock(&psp_devs_lock); + xa_erase(&psp_devs, psd->id); + mutex_unlock(&psp_devs_lock); + + mutex_destroy(&psd->lock); + kfree_rcu(psd, rcu); +} + +/** + * psp_dev_unregister() - unregister PSP device + * @psd: PSP device structure + */ +void psp_dev_unregister(struct psp_dev *psd) +{ + struct psp_assoc *pas, *next; + + mutex_lock(&psp_devs_lock); + mutex_lock(&psd->lock); + + psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF); + + /* Wait until psp_dev_destroy() to call xa_erase() to prevent a + * different psd from being added to the xarray with this id, while + * there are still references to this psd being held. + */ + xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); + mutex_unlock(&psp_devs_lock); + + list_splice_init(&psd->active_assocs, &psd->prev_assocs); + list_splice_init(&psd->prev_assocs, &psd->stale_assocs); + list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list) + psp_dev_tx_key_del(psd, pas); + + rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); + + psd->ops = NULL; + psd->drv_priv = NULL; + + mutex_unlock(&psd->lock); + + psp_dev_put(psd); +} +EXPORT_SYMBOL(psp_dev_unregister); + +unsigned int psp_key_size(u32 version) +{ + switch (version) { + case PSP_VERSION_HDR0_AES_GCM_128: + case PSP_VERSION_HDR0_AES_GMAC_128: + return 16; + case PSP_VERSION_HDR0_AES_GCM_256: + case PSP_VERSION_HDR0_AES_GMAC_256: + return 32; + default: + return 0; + } +} +EXPORT_SYMBOL(psp_key_size); + +static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi, + u8 ver, unsigned int udp_len, __be16 sport) +{ + struct udphdr *uh = udp_hdr(skb); + struct psphdr *psph = (struct psphdr *)(uh + 1); + + uh->dest = htons(PSP_DEFAULT_UDP_PORT); + uh->source = udp_flow_src_port(net, skb, 0, 0, false); + uh->check = 0; + uh->len = htons(udp_len); + + psph->nexthdr = IPPROTO_TCP; + psph->hdrlen = PSP_HDRLEN_NOOPT; + psph->crypt_offset = 0; + psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) | + FIELD_PREP(PSPHDR_VERFL_ONE, 1); + psph->spi = spi; + memset(&psph->iv, 0, sizeof(psph->iv)); +} + +/* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling + * them in. + */ +bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, + u8 ver, __be16 sport) +{ + u32 network_len = skb_network_header_len(skb); + u32 ethr_len = skb_mac_header_len(skb); + u32 bufflen = ethr_len + network_len; + + if (skb_cow_head(skb, PSP_ENCAP_HLEN)) + return false; + + skb_push(skb, PSP_ENCAP_HLEN); + skb->mac_header -= PSP_ENCAP_HLEN; + skb->network_header -= PSP_ENCAP_HLEN; + skb->transport_header -= PSP_ENCAP_HLEN; + memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen); + + if (skb->protocol == htons(ETH_P_IP)) { + ip_hdr(skb)->protocol = IPPROTO_UDP; + be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN); + ip_hdr(skb)->check = 0; + ip_hdr(skb)->check = + ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl); + } else if (skb->protocol == htons(ETH_P_IPV6)) { + ipv6_hdr(skb)->nexthdr = IPPROTO_UDP; + be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN); + } else { + return false; + } + + skb_set_inner_ipproto(skb, IPPROTO_TCP); + skb_set_inner_transport_header(skb, skb_transport_offset(skb) + + PSP_ENCAP_HLEN); + skb->encapsulation = 1; + psp_write_headers(net, skb, spi, ver, + skb->len - skb_transport_offset(skb), sport); + + return true; +} +EXPORT_SYMBOL(psp_dev_encapsulate); + +/* Receive handler for PSP packets. + * + * Presently it accepts only already-authenticated packets and does not + * support optional fields, such as virtualization cookies. The caller should + * ensure that skb->data is pointing to the mac header, and that skb->mac_len + * is set. + */ +int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) +{ + int l2_hlen = 0, l3_hlen, encap; + struct psp_skb_ext *pse; + struct psphdr *psph; + struct ethhdr *eth; + struct udphdr *uh; + __be16 proto; + bool is_udp; + + eth = (struct ethhdr *)skb->data; + proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen); + if (proto == htons(ETH_P_IP)) + l3_hlen = sizeof(struct iphdr); + else if (proto == htons(ETH_P_IPV6)) + l3_hlen = sizeof(struct ipv6hdr); + else + return -EINVAL; + + if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))) + return -EINVAL; + + if (proto == htons(ETH_P_IP)) { + struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); + + is_udp = iph->protocol == IPPROTO_UDP; + l3_hlen = iph->ihl * 4; + if (l3_hlen != sizeof(struct iphdr) && + !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)) + return -EINVAL; + } else { + struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); + + is_udp = ipv6h->nexthdr == IPPROTO_UDP; + } + + if (unlikely(!is_udp)) + return -EINVAL; + + uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen); + if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT))) + return -EINVAL; + + pse = skb_ext_add(skb, SKB_EXT_PSP); + if (!pse) + return -EINVAL; + + psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen + + sizeof(struct udphdr)); + pse->spi = psph->spi; + pse->dev_id = dev_id; + pse->generation = generation; + pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl); + + encap = PSP_ENCAP_HLEN; + encap += strip_icv ? PSP_TRL_SIZE : 0; + + if (proto == htons(ETH_P_IP)) { + struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen); + + iph->protocol = psph->nexthdr; + iph->tot_len = htons(ntohs(iph->tot_len) - encap); + iph->check = 0; + iph->check = ip_fast_csum((u8 *)iph, iph->ihl); + } else { + struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen); + + ipv6h->nexthdr = psph->nexthdr; + ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap); + } + + memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen); + skb_pull(skb, PSP_ENCAP_HLEN); + + if (strip_icv) + pskb_trim(skb, skb->len - PSP_TRL_SIZE); + + return 0; +} +EXPORT_SYMBOL(psp_dev_rcv); + +static int __init psp_init(void) +{ + mutex_init(&psp_devs_lock); + + return genl_register_family(&psp_nl_family); +} + +subsys_initcall(psp_init); diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c new file mode 100644 index 000000000000..8aaca62744c3 --- /dev/null +++ b/net/psp/psp_nl.c @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include <linux/skbuff.h> +#include <linux/xarray.h> +#include <net/genetlink.h> +#include <net/psp.h> +#include <net/sock.h> + +#include "psp-nl-gen.h" +#include "psp.h" + +/* Netlink helpers */ + +static struct sk_buff *psp_nl_reply_new(struct genl_info *info) +{ + struct sk_buff *rsp; + void *hdr; + + rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!rsp) + return NULL; + + hdr = genlmsg_iput(rsp, info); + if (!hdr) { + nlmsg_free(rsp); + return NULL; + } + + return rsp; +} + +static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) +{ + /* Note that this *only* works with a single message per skb! */ + nlmsg_end(rsp, (struct nlmsghdr *)rsp->data); + + return genlmsg_reply(rsp, info); +} + +/* Device stuff */ + +static struct psp_dev * +psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) +{ + struct psp_dev *psd; + int err; + + mutex_lock(&psp_devs_lock); + psd = xa_load(&psp_devs, nla_get_u32(dev_id)); + if (!psd) { + mutex_unlock(&psp_devs_lock); + return ERR_PTR(-ENODEV); + } + + mutex_lock(&psd->lock); + mutex_unlock(&psp_devs_lock); + + err = psp_dev_check_access(psd, net); + if (err) { + mutex_unlock(&psd->lock); + return ERR_PTR(err); + } + + return psd; +} + +int psp_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID)) + return -EINVAL; + + info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), + info->attrs[PSP_A_DEV_ID]); + return PTR_ERR_OR_ZERO(info->user_ptr[0]); +} + +void +psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct socket *socket = info->user_ptr[1]; + struct psp_dev *psd = info->user_ptr[0]; + + mutex_unlock(&psd->lock); + if (socket) + sockfd_put(socket); +} + +static int +psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, + const struct genl_info *info) +{ + void *hdr; + + hdr = genlmsg_iput(rsp, info); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || + nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) || + nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) || + nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) + goto err_cancel_msg; + + genlmsg_end(rsp, hdr); + return 0; + +err_cancel_msg: + genlmsg_cancel(rsp, hdr); + return -EMSGSIZE; +} + +void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) +{ + struct genl_info info; + struct sk_buff *ntf; + + if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev), + PSP_NLGRP_MGMT)) + return; + + ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!ntf) + return; + + genl_info_init_ntf(&info, &psp_nl_family, cmd); + if (psp_nl_dev_fill(psd, ntf, &info)) { + nlmsg_free(ntf); + return; + } + + genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, + 0, PSP_NLGRP_MGMT, GFP_KERNEL); +} + +int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info) +{ + struct psp_dev *psd = info->user_ptr[0]; + struct sk_buff *rsp; + int err; + + rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!rsp) + return -ENOMEM; + + err = psp_nl_dev_fill(psd, rsp, info); + if (err) + goto err_free_msg; + + return genlmsg_reply(rsp, info); + +err_free_msg: + nlmsg_free(rsp); + return err; +} + +static int +psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, + struct psp_dev *psd) +{ + if (psp_dev_check_access(psd, sock_net(rsp->sk))) + return 0; + + return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb)); +} + +int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb) +{ + struct psp_dev *psd; + int err = 0; + + mutex_lock(&psp_devs_lock); + xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) { + mutex_lock(&psd->lock); + err = psp_nl_dev_get_dumpit_one(rsp, cb, psd); + mutex_unlock(&psd->lock); + if (err) + break; + } + mutex_unlock(&psp_devs_lock); + + return err; +} + +int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct psp_dev *psd = info->user_ptr[0]; + struct psp_dev_config new_config; + struct sk_buff *rsp; + int err; + + memcpy(&new_config, &psd->config, sizeof(new_config)); + + if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) { + new_config.versions = + nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]); + if (new_config.versions & ~psd->caps->versions) { + NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device"); + return -EINVAL; + } + } else { + NL_SET_ERR_MSG(info->extack, "No settings present"); + return -EINVAL; + } + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + if (memcmp(&new_config, &psd->config, sizeof(new_config))) { + err = psd->ops->set_config(psd, &new_config, info->extack); + if (err) + goto err_free_rsp; + + memcpy(&psd->config, &new_config, sizeof(new_config)); + } + + psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); + + return psp_nl_reply_send(rsp, info); + +err_free_rsp: + nlmsg_free(rsp); + return err; +} + +int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct psp_dev *psd = info->user_ptr[0]; + struct genl_info ntf_info; + struct sk_buff *ntf, *rsp; + u8 prev_gen; + int err; + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF); + ntf = psp_nl_reply_new(&ntf_info); + if (!ntf) { + err = -ENOMEM; + goto err_free_rsp; + } + + if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || + nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) { + err = -EMSGSIZE; + goto err_free_ntf; + } + + /* suggest the next gen number, driver can override */ + prev_gen = psd->generation; + psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK; + + err = psd->ops->key_rotate(psd, info->extack); + if (err) + goto err_free_ntf; + + WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) || + psd->generation & ~PSP_GEN_VALID_MASK); + + psp_assocs_key_rotated(psd); + + nlmsg_end(ntf, (struct nlmsghdr *)ntf->data); + genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, + 0, PSP_NLGRP_USE, GFP_KERNEL); + return psp_nl_reply_send(rsp, info); + +err_free_ntf: + nlmsg_free(ntf); +err_free_rsp: + nlmsg_free(rsp); + return err; +} + +/* Key etc. */ + +int psp_assoc_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket; + struct psp_dev *psd; + struct nlattr *id; + int fd, err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) + return -EINVAL; + + fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); + socket = sockfd_lookup(fd, &err); + if (!socket) + return err; + + if (!sk_is_tcp(socket->sk)) { + NL_SET_ERR_MSG_ATTR(info->extack, + info->attrs[PSP_A_ASSOC_SOCK_FD], + "Unsupported socket family and type"); + err = -EOPNOTSUPP; + goto err_sock_put; + } + + psd = psp_dev_get_for_sock(socket->sk); + if (psd) { + err = psp_dev_check_access(psd, genl_info_net(info)); + if (err) { + psp_dev_put(psd); + psd = NULL; + } + } + + if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { + err = -EINVAL; + goto err_sock_put; + } + + id = info->attrs[PSP_A_ASSOC_DEV_ID]; + if (psd) { + mutex_lock(&psd->lock); + if (id && psd->id != nla_get_u32(id)) { + mutex_unlock(&psd->lock); + NL_SET_ERR_MSG_ATTR(info->extack, id, + "Device id vs socket mismatch"); + err = -EINVAL; + goto err_psd_put; + } + + psp_dev_put(psd); + } else { + psd = psp_device_get_and_lock(genl_info_net(info), id); + if (IS_ERR(psd)) { + err = PTR_ERR(psd); + goto err_sock_put; + } + } + + info->user_ptr[0] = psd; + info->user_ptr[1] = socket; + + return 0; + +err_psd_put: + psp_dev_put(psd); +err_sock_put: + sockfd_put(socket); + return err; +} + +static int +psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, + unsigned int key_sz) +{ + struct nlattr *nest = info->attrs[attr]; + struct nlattr *tb[PSP_A_KEYS_SPI + 1]; + u32 spi; + int err; + + err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, + psp_keys_nl_policy, info->extack); + if (err) + return err; + + if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || + NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) + return -EINVAL; + + if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], + "incorrect key length"); + return -EINVAL; + } + + spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); + if (!(spi & PSP_SPI_KEY_ID)) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], + "invalid SPI: lower 31b must be non-zero"); + return -EINVAL; + } + + key->spi = cpu_to_be32(spi); + memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); + + return 0; +} + +static int +psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, + struct psp_key_parsed *key) +{ + int key_sz = psp_key_size(version); + void *nest; + + nest = nla_nest_start(skb, attr); + + if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || + nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { + nla_nest_cancel(skb, nest); + return -EMSGSIZE; + } + + nla_nest_end(skb, nest); + + return 0; +} + +int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket = info->user_ptr[1]; + struct psp_dev *psd = info->user_ptr[0]; + struct psp_key_parsed key; + struct psp_assoc *pas; + struct sk_buff *rsp; + u32 version; + int err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) + return -EINVAL; + + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); + if (!(psd->caps->versions & (1 << version))) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); + return -EOPNOTSUPP; + } + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + pas = psp_assoc_create(psd); + if (!pas) { + err = -ENOMEM; + goto err_free_rsp; + } + pas->version = version; + + err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); + if (err) + goto err_free_pas; + + if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || + psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { + err = -EMSGSIZE; + goto err_free_pas; + } + + err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); + if (err) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); + goto err_free_pas; + } + psp_assoc_put(pas); + + return psp_nl_reply_send(rsp, info); + +err_free_pas: + psp_assoc_put(pas); +err_free_rsp: + nlmsg_free(rsp); + return err; +} + +int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket = info->user_ptr[1]; + struct psp_dev *psd = info->user_ptr[0]; + struct psp_key_parsed key; + struct sk_buff *rsp; + unsigned int key_sz; + u32 version; + int err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || + GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) + return -EINVAL; + + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); + if (!(psd->caps->versions & (1 << version))) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); + return -EOPNOTSUPP; + } + + key_sz = psp_key_size(version); + if (!key_sz) + return -EINVAL; + + err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); + if (err < 0) + return err; + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, + info->extack); + if (err) + goto err_free_msg; + + return psp_nl_reply_send(rsp, info); + +err_free_msg: + nlmsg_free(rsp); + return err; +} diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c new file mode 100644 index 000000000000..afa966c6b69d --- /dev/null +++ b/net/psp/psp_sock.c @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include <linux/file.h> +#include <linux/net.h> +#include <linux/rcupdate.h> +#include <linux/tcp.h> + +#include <net/ip.h> +#include <net/psp.h> +#include "psp.h" + +struct psp_dev *psp_dev_get_for_sock(struct sock *sk) +{ + struct dst_entry *dst; + struct psp_dev *psd; + + dst = sk_dst_get(sk); + if (!dst) + return NULL; + + rcu_read_lock(); + psd = rcu_dereference(dst->dev->psp_dev); + if (psd && !psp_dev_tryget(psd)) + psd = NULL; + rcu_read_unlock(); + + dst_release(dst); + + return psd; +} + +static struct sk_buff * +psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) +{ + struct psp_assoc *pas; + bool good; + + rcu_read_lock(); + pas = psp_skb_get_assoc_rcu(skb); + good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; + rcu_read_unlock(); + if (!good) { + kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT); + return NULL; + } + + return skb; +} + +struct psp_assoc *psp_assoc_create(struct psp_dev *psd) +{ + struct psp_assoc *pas; + + lockdep_assert_held(&psd->lock); + + pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), + GFP_KERNEL_ACCOUNT); + if (!pas) + return NULL; + + pas->psd = psd; + pas->dev_id = psd->id; + pas->generation = psd->generation; + psp_dev_get(psd); + refcount_set(&pas->refcnt, 1); + + list_add_tail(&pas->assocs_list, &psd->active_assocs); + + return pas; +} + +static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) +{ + struct psp_dev *psd = pas->psd; + size_t sz; + + lockdep_assert_held(&psd->lock); + + sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); + return kmemdup(pas, sz, GFP_KERNEL); +} + +static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, + struct netlink_ext_ack *extack) +{ + return psd->ops->tx_key_add(psd, pas, extack); +} + +void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) +{ + if (pas->tx.spi) + psd->ops->tx_key_del(psd, pas); + list_del(&pas->assocs_list); +} + +static void psp_assoc_free(struct work_struct *work) +{ + struct psp_assoc *pas = container_of(work, struct psp_assoc, work); + struct psp_dev *psd = pas->psd; + + mutex_lock(&psd->lock); + if (psd->ops) + psp_dev_tx_key_del(psd, pas); + mutex_unlock(&psd->lock); + psp_dev_put(psd); + kfree(pas); +} + +static void psp_assoc_free_queue(struct rcu_head *head) +{ + struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); + + INIT_WORK(&pas->work, psp_assoc_free); + schedule_work(&pas->work); +} + +/** + * psp_assoc_put() - release a reference on a PSP association + * @pas: association to release + */ +void psp_assoc_put(struct psp_assoc *pas) +{ + if (pas && refcount_dec_and_test(&pas->refcnt)) + call_rcu(&pas->rcu, psp_assoc_free_queue); +} + +void psp_sk_assoc_free(struct sock *sk) +{ + struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); + + rcu_assign_pointer(sk->psp_assoc, NULL); + psp_assoc_put(pas); +} + +int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, + struct psp_key_parsed *key, + struct netlink_ext_ack *extack) +{ + int err; + + memcpy(&pas->rx, key, sizeof(*key)); + + lock_sock(sk); + + if (psp_sk_assoc(sk)) { + NL_SET_ERR_MSG(extack, "Socket already has PSP state"); + err = -EBUSY; + goto exit_unlock; + } + + refcount_inc(&pas->refcnt); + rcu_assign_pointer(sk->psp_assoc, pas); + err = 0; + +exit_unlock: + release_sock(sk); + + return err; +} + +static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) +{ + struct psp_skb_ext *pse; + struct sk_buff *skb; + + skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { + pse = skb_ext_find(skb, SKB_EXT_PSP); + if (!psp_pse_matches_pas(pse, pas)) + return -EBUSY; + } + + skb_queue_walk(&sk->sk_receive_queue, skb) { + pse = skb_ext_find(skb, SKB_EXT_PSP); + if (!psp_pse_matches_pas(pse, pas)) + return -EBUSY; + } + return 0; +} + +int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, + u32 version, struct psp_key_parsed *key, + struct netlink_ext_ack *extack) +{ + struct inet_connection_sock *icsk; + struct psp_assoc *pas, *dummy; + int err; + + lock_sock(sk); + + pas = psp_sk_assoc(sk); + if (!pas) { + NL_SET_ERR_MSG(extack, "Socket has no Rx key"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->psd != psd) { + NL_SET_ERR_MSG(extack, "Rx key from different device"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->version != version) { + NL_SET_ERR_MSG(extack, + "PSP version mismatch with existing state"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->tx.spi) { + NL_SET_ERR_MSG(extack, "Tx key already set"); + err = -EBUSY; + goto exit_unlock; + } + + err = psp_sock_recv_queue_check(sk, pas); + if (err) { + NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); + goto exit_unlock; + } + + /* Pass a fake association to drivers to make sure they don't + * try to store pointers to it. For re-keying we'll need to + * re-allocate the assoc structures. + */ + dummy = psp_assoc_dummy(pas); + if (!dummy) { + err = -ENOMEM; + goto exit_unlock; + } + + memcpy(&dummy->tx, key, sizeof(*key)); + err = psp_dev_tx_key_add(psd, dummy, extack); + if (err) + goto exit_free_dummy; + + memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); + memcpy(&pas->tx, key, sizeof(*key)); + + WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); + tcp_write_collapse_fence(sk); + pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; + + icsk = inet_csk(sk); + icsk->icsk_ext_hdr_len += psp_sk_overhead(sk); + icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); + +exit_free_dummy: + kfree(dummy); +exit_unlock: + release_sock(sk); + return err; +} + +void psp_assocs_key_rotated(struct psp_dev *psd) +{ + struct psp_assoc *pas, *next; + + /* Mark the stale associations as invalid, they will no longer + * be able to Rx any traffic. + */ + list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list) + pas->generation |= ~PSP_GEN_VALID_MASK; + list_splice_init(&psd->prev_assocs, &psd->stale_assocs); + list_splice_init(&psd->active_assocs, &psd->prev_assocs); + + /* TODO: we should inform the sockets that got shut down */ +} + +void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) +{ + struct psp_assoc *pas = psp_sk_assoc(sk); + + if (pas) + refcount_inc(&pas->refcnt); + rcu_assign_pointer(tw->psp_assoc, pas); + tw->tw_validate_xmit_skb = psp_validate_xmit; +} + +void psp_twsk_assoc_free(struct inet_timewait_sock *tw) +{ + struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); + + rcu_assign_pointer(tw->psp_assoc, NULL); + psp_assoc_put(pas); +} + +void psp_reply_set_decrypted(struct sk_buff *skb) +{ + struct psp_assoc *pas; + + rcu_read_lock(); + pas = psp_sk_get_assoc_rcu(skb->sk); + if (pas && pas->tx.spi) + skb->decrypted = 1; + rcu_read_unlock(); +} +EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted); |
