diff options
Diffstat (limited to 'net/ipv4/fou.c')
-rw-r--r-- | net/ipv4/fou.c | 234 |
1 files changed, 187 insertions, 47 deletions
diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c index ff069f6597ac..af150b43b214 100644 --- a/net/ipv4/fou.c +++ b/net/ipv4/fou.c @@ -16,14 +16,12 @@ #include <uapi/linux/fou.h> #include <uapi/linux/genetlink.h> -static DEFINE_SPINLOCK(fou_lock); -static LIST_HEAD(fou_list); - struct fou { struct socket *sock; u8 protocol; u8 flags; - u16 port; + __be16 port; + u16 type; struct udp_offload udp_offloads; struct list_head list; }; @@ -37,6 +35,13 @@ struct fou_cfg { struct udp_port_cfg udp_config; }; +static unsigned int fou_net_id; + +struct fou_net { + struct list_head fou_list; + struct mutex fou_lock; +}; + static inline struct fou *fou_from_sock(struct sock *sk) { return sk->sk_user_data; @@ -387,20 +392,21 @@ out_unlock: return err; } -static int fou_add_to_port_list(struct fou *fou) +static int fou_add_to_port_list(struct net *net, struct fou *fou) { + struct fou_net *fn = net_generic(net, fou_net_id); struct fou *fout; - spin_lock(&fou_lock); - list_for_each_entry(fout, &fou_list, list) { + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { if (fou->port == fout->port) { - spin_unlock(&fou_lock); + mutex_unlock(&fn->fou_lock); return -EALREADY; } } - list_add(&fou->list, &fou_list); - spin_unlock(&fou_lock); + list_add(&fou->list, &fn->fou_list); + mutex_unlock(&fn->fou_lock); return 0; } @@ -410,14 +416,10 @@ static void fou_release(struct fou *fou) struct socket *sock = fou->sock; struct sock *sk = sock->sk; - udp_del_offload(&fou->udp_offloads); - + if (sk->sk_family == AF_INET) + udp_del_offload(&fou->udp_offloads); list_del(&fou->list); - - /* Remove hooks into tunnel socket */ - sk->sk_user_data = NULL; - - sock_release(sock); + udp_tunnel_sock_release(sock); kfree(fou); } @@ -447,10 +449,10 @@ static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg) static int fou_create(struct net *net, struct fou_cfg *cfg, struct socket **sockp) { - struct fou *fou = NULL; - int err; struct socket *sock = NULL; + struct fou *fou = NULL; struct sock *sk; + int err; /* Open UDP socket */ err = udp_sock_create(net, &cfg->udp_config, &sock); @@ -486,6 +488,8 @@ static int fou_create(struct net *net, struct fou_cfg *cfg, goto error; } + fou->type = cfg->type; + udp_sk(sk)->encap_type = 1; udp_encap_enable(); @@ -502,7 +506,7 @@ static int fou_create(struct net *net, struct fou_cfg *cfg, goto error; } - err = fou_add_to_port_list(fou); + err = fou_add_to_port_list(net, fou); if (err) goto error; @@ -514,27 +518,27 @@ static int fou_create(struct net *net, struct fou_cfg *cfg, error: kfree(fou); if (sock) - sock_release(sock); + udp_tunnel_sock_release(sock); return err; } static int fou_destroy(struct net *net, struct fou_cfg *cfg) { - struct fou *fou; - u16 port = cfg->udp_config.local_udp_port; + struct fou_net *fn = net_generic(net, fou_net_id); + __be16 port = cfg->udp_config.local_udp_port; int err = -EINVAL; + struct fou *fou; - spin_lock(&fou_lock); - list_for_each_entry(fou, &fou_list, list) { + mutex_lock(&fn->fou_lock); + list_for_each_entry(fou, &fn->fou_list, list) { if (fou->port == port) { - udp_del_offload(&fou->udp_offloads); fou_release(fou); err = 0; break; } } - spin_unlock(&fou_lock); + mutex_unlock(&fn->fou_lock); return err; } @@ -573,7 +577,7 @@ static int parse_nl_config(struct genl_info *info, } if (info->attrs[FOU_ATTR_PORT]) { - u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]); + __be16 port = nla_get_be16(info->attrs[FOU_ATTR_PORT]); cfg->udp_config.local_udp_port = port; } @@ -592,6 +596,7 @@ static int parse_nl_config(struct genl_info *info, static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info) { + struct net *net = genl_info_net(info); struct fou_cfg cfg; int err; @@ -599,16 +604,120 @@ static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info) if (err) return err; - return fou_create(&init_net, &cfg, NULL); + return fou_create(net, &cfg, NULL); } static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info) { + struct net *net = genl_info_net(info); + struct fou_cfg cfg; + int err; + + err = parse_nl_config(info, &cfg); + if (err) + return err; + + return fou_destroy(net, &cfg); +} + +static int fou_fill_info(struct fou *fou, struct sk_buff *msg) +{ + if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) || + nla_put_be16(msg, FOU_ATTR_PORT, fou->port) || + nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) || + nla_put_u8(msg, FOU_ATTR_TYPE, fou->type)) + return -1; + + if (fou->flags & FOU_F_REMCSUM_NOPARTIAL) + if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL)) + return -1; + return 0; +} + +static int fou_dump_info(struct fou *fou, u32 portid, u32 seq, + u32 flags, struct sk_buff *skb, u8 cmd) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + if (fou_fill_info(fou, skb) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct fou_net *fn = net_generic(net, fou_net_id); + struct sk_buff *msg; struct fou_cfg cfg; + struct fou *fout; + __be16 port; + int ret; + + ret = parse_nl_config(info, &cfg); + if (ret) + return ret; + port = cfg.udp_config.local_udp_port; + if (port == 0) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = -ESRCH; + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { + if (port == fout->port) { + ret = fou_dump_info(fout, info->snd_portid, + info->snd_seq, 0, msg, + info->genlhdr->cmd); + break; + } + } + mutex_unlock(&fn->fou_lock); + if (ret < 0) + goto out_free; - parse_nl_config(info, &cfg); + return genlmsg_reply(msg, info); - return fou_destroy(&init_net, &cfg); +out_free: + nlmsg_free(msg); + return ret; +} + +static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct fou_net *fn = net_generic(net, fou_net_id); + struct fou *fout; + int idx = 0, ret; + + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { + if (idx++ < cb->args[0]) + continue; + ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + skb, FOU_CMD_GET); + if (ret) + goto done; + } + mutex_unlock(&fn->fou_lock); + +done: + cb->args[0] = idx; + return skb->len; } static const struct genl_ops fou_nl_ops[] = { @@ -624,6 +733,12 @@ static const struct genl_ops fou_nl_ops[] = { .policy = fou_nl_policy, .flags = GENL_ADMIN_PERM, }, + { + .cmd = FOU_CMD_GET, + .doit = fou_nl_cmd_get_port, + .dumpit = fou_nl_dump, + .policy = fou_nl_policy, + }, }; size_t fou_encap_hlen(struct ip_tunnel_encap *e) @@ -771,12 +886,12 @@ EXPORT_SYMBOL(gue_build_header); #ifdef CONFIG_NET_FOU_IP_TUNNELS -static const struct ip_tunnel_encap_ops __read_mostly fou_iptun_ops = { +static const struct ip_tunnel_encap_ops fou_iptun_ops = { .encap_hlen = fou_encap_hlen, .build_header = fou_build_header, }; -static const struct ip_tunnel_encap_ops __read_mostly gue_iptun_ops = { +static const struct ip_tunnel_encap_ops gue_iptun_ops = { .encap_hlen = gue_encap_hlen, .build_header = gue_build_header, }; @@ -820,38 +935,63 @@ static void ip_tunnel_encap_del_fou_ops(void) #endif +static __net_init int fou_init_net(struct net *net) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + + INIT_LIST_HEAD(&fn->fou_list); + mutex_init(&fn->fou_lock); + return 0; +} + +static __net_exit void fou_exit_net(struct net *net) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + struct fou *fou, *next; + + /* Close all the FOU sockets */ + mutex_lock(&fn->fou_lock); + list_for_each_entry_safe(fou, next, &fn->fou_list, list) + fou_release(fou); + mutex_unlock(&fn->fou_lock); +} + +static struct pernet_operations fou_net_ops = { + .init = fou_init_net, + .exit = fou_exit_net, + .id = &fou_net_id, + .size = sizeof(struct fou_net), +}; + static int __init fou_init(void) { int ret; + ret = register_pernet_device(&fou_net_ops); + if (ret) + goto exit; + ret = genl_register_family_with_ops(&fou_nl_family, fou_nl_ops); - if (ret < 0) - goto exit; + goto unregister; ret = ip_tunnel_encap_add_fou_ops(); - if (ret < 0) - genl_unregister_family(&fou_nl_family); + if (ret == 0) + return 0; + genl_unregister_family(&fou_nl_family); +unregister: + unregister_pernet_device(&fou_net_ops); exit: return ret; } static void __exit fou_fini(void) { - struct fou *fou, *next; - ip_tunnel_encap_del_fou_ops(); - genl_unregister_family(&fou_nl_family); - - /* Close all the FOU sockets */ - - spin_lock(&fou_lock); - list_for_each_entry_safe(fou, next, &fou_list, list) - fou_release(fou); - spin_unlock(&fou_lock); + unregister_pernet_device(&fou_net_ops); } module_init(fou_init); |