diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 161 |
1 files changed, 126 insertions, 35 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 641cfbc278d8..5681ce3aebca 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -62,6 +62,7 @@ #include <net/netlink.h> #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) +#define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long)) struct netlink_sock { /* struct sock has to be the first member of netlink_sock */ @@ -314,10 +315,12 @@ netlink_update_listeners(struct sock *sk) unsigned long mask; unsigned int i; - for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) { + for (i = 0; i < NLGRPLONGS(tbl->groups); i++) { mask = 0; - sk_for_each_bound(sk, node, &tbl->mc_list) - mask |= nlk_sk(sk)->groups[i]; + sk_for_each_bound(sk, node, &tbl->mc_list) { + if (i < NLGRPLONGS(nlk_sk(sk)->ngroups)) + mask |= nlk_sk(sk)->groups[i]; + } tbl->listeners[i] = mask; } /* this function is only called with the netlink table "grabbed", which @@ -555,26 +558,37 @@ netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions) nlk->subscriptions = subscriptions; } -static int netlink_alloc_groups(struct sock *sk) +static int netlink_realloc_groups(struct sock *sk) { struct netlink_sock *nlk = nlk_sk(sk); unsigned int groups; + unsigned long *new_groups; int err = 0; - netlink_lock_table(); + netlink_table_grab(); + groups = nl_table[sk->sk_protocol].groups; - if (!nl_table[sk->sk_protocol].registered) + if (!nl_table[sk->sk_protocol].registered) { err = -ENOENT; - netlink_unlock_table(); + goto out_unlock; + } - if (err) - return err; + if (nlk->ngroups >= groups) + goto out_unlock; - nlk->groups = kzalloc(NLGRPSZ(groups), GFP_KERNEL); - if (nlk->groups == NULL) - return -ENOMEM; + new_groups = krealloc(nlk->groups, NLGRPSZ(groups), GFP_ATOMIC); + if (new_groups == NULL) { + err = -ENOMEM; + goto out_unlock; + } + memset((char*)new_groups + NLGRPSZ(nlk->ngroups), 0, + NLGRPSZ(groups) - NLGRPSZ(nlk->ngroups)); + + nlk->groups = new_groups; nlk->ngroups = groups; - return 0; + out_unlock: + netlink_table_ungrab(); + return err; } static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len) @@ -591,11 +605,9 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len if (nladdr->nl_groups) { if (!netlink_capable(sock, NL_NONROOT_RECV)) return -EPERM; - if (nlk->groups == NULL) { - err = netlink_alloc_groups(sk); - if (err) - return err; - } + err = netlink_realloc_groups(sk); + if (err) + return err; } if (nlk->pid) { @@ -839,10 +851,18 @@ retry: int netlink_has_listeners(struct sock *sk, unsigned int group) { int res = 0; + unsigned long *listeners; BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET)); + + rcu_read_lock(); + listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners); + if (group - 1 < nl_table[sk->sk_protocol].groups) - res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners); + res = test_bit(group - 1, listeners); + + rcu_read_unlock(); + return res; } EXPORT_SYMBOL_GPL(netlink_has_listeners); @@ -1007,6 +1027,23 @@ void netlink_set_err(struct sock *ssk, u32 pid, u32 group, int code) read_unlock(&nl_table_lock); } +/* must be called with netlink table grabbed */ +static void netlink_update_socket_mc(struct netlink_sock *nlk, + unsigned int group, + int is_new) +{ + int old, new = !!is_new, subscriptions; + + old = test_bit(group - 1, nlk->groups); + subscriptions = nlk->subscriptions - old + new; + if (new) + __set_bit(group - 1, nlk->groups); + else + __clear_bit(group - 1, nlk->groups); + netlink_update_subscriptions(&nlk->sk, subscriptions); + netlink_update_listeners(&nlk->sk); +} + static int netlink_setsockopt(struct socket *sock, int level, int optname, char __user *optval, int optlen) { @@ -1032,27 +1069,16 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, break; case NETLINK_ADD_MEMBERSHIP: case NETLINK_DROP_MEMBERSHIP: { - unsigned int subscriptions; - int old, new = optname == NETLINK_ADD_MEMBERSHIP ? 1 : 0; - if (!netlink_capable(sock, NL_NONROOT_RECV)) return -EPERM; - if (nlk->groups == NULL) { - err = netlink_alloc_groups(sk); - if (err) - return err; - } + err = netlink_realloc_groups(sk); + if (err) + return err; if (!val || val - 1 >= nlk->ngroups) return -EINVAL; netlink_table_grab(); - old = test_bit(val - 1, nlk->groups); - subscriptions = nlk->subscriptions - old + new; - if (new) - __set_bit(val - 1, nlk->groups); - else - __clear_bit(val - 1, nlk->groups); - netlink_update_subscriptions(sk, subscriptions); - netlink_update_listeners(sk); + netlink_update_socket_mc(nlk, val, + optname == NETLINK_ADD_MEMBERSHIP); netlink_table_ungrab(); err = 0; break; @@ -1328,6 +1354,71 @@ out_sock_release: return NULL; } +/** + * netlink_change_ngroups - change number of multicast groups + * + * This changes the number of multicast groups that are available + * on a certain netlink family. Note that it is not possible to + * change the number of groups to below 32. Also note that it does + * not implicitly call netlink_clear_multicast_users() when the + * number of groups is reduced. + * + * @sk: The kernel netlink socket, as returned by netlink_kernel_create(). + * @groups: The new number of groups. + */ +int netlink_change_ngroups(struct sock *sk, unsigned int groups) +{ + unsigned long *listeners, *old = NULL; + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + int err = 0; + + if (groups < 32) + groups = 32; + + netlink_table_grab(); + if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) { + listeners = kzalloc(NLGRPSZ(groups), GFP_ATOMIC); + if (!listeners) { + err = -ENOMEM; + goto out_ungrab; + } + old = tbl->listeners; + memcpy(listeners, old, NLGRPSZ(tbl->groups)); + rcu_assign_pointer(tbl->listeners, listeners); + } + tbl->groups = groups; + + out_ungrab: + netlink_table_ungrab(); + synchronize_rcu(); + kfree(old); + return err; +} +EXPORT_SYMBOL(netlink_change_ngroups); + +/** + * netlink_clear_multicast_users - kick off multicast listeners + * + * This function removes all listeners from the given group. + * @ksk: The kernel netlink socket, as returned by + * netlink_kernel_create(). + * @group: The multicast group to clear. + */ +void netlink_clear_multicast_users(struct sock *ksk, unsigned int group) +{ + struct sock *sk; + struct hlist_node *node; + struct netlink_table *tbl = &nl_table[ksk->sk_protocol]; + + netlink_table_grab(); + + sk_for_each_bound(sk, node, &tbl->mc_list) + netlink_update_socket_mc(nlk_sk(sk), group, 0); + + netlink_table_ungrab(); +} +EXPORT_SYMBOL(netlink_clear_multicast_users); + void netlink_set_nonroot(int protocol, unsigned int flags) { if ((unsigned int)protocol < MAX_LINKS) |