From 8798481b667fa7c9bbd5aa843bf1557ada699964 Mon Sep 17 00:00:00 2001 From: Pedro Tammela Date: Fri, 28 Jul 2023 12:35:33 -0300 Subject: net/sched: wrap open coded Qdics class filter counter The 'filter_cnt' counter is used to control a Qdisc class lifetime. Each filter referecing this class by its id will eventually increment/decrement this counter in their respective 'add/update/delete' routines. As these operations are always serialized under rtnl lock, we don't need an atomic type like 'refcount_t'. It also means that we lose the overflow/underflow checks already present in refcount_t, which are valuable to hunt down bugs where the unsigned counter wraps around as it aids automated tools like syzkaller to scream in such situations. Wrap the open coded increment/decrement into helper functions and add overflow checks to the operations. Acked-by: Jamal Hadi Salim Signed-off-by: Pedro Tammela Reviewed-by: Simon Horman Signed-off-by: Paolo Abeni --- include/net/sch_generic.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) (limited to 'include') diff --git a/include/net/sch_generic.h b/include/net/sch_generic.h index 15be2d96b06d..f232512505f8 100644 --- a/include/net/sch_generic.h +++ b/include/net/sch_generic.h @@ -599,6 +599,7 @@ get_default_qdisc_ops(const struct net_device *dev, int ntx) struct Qdisc_class_common { u32 classid; + unsigned int filter_cnt; struct hlist_node hnode; }; @@ -633,6 +634,31 @@ qdisc_class_find(const struct Qdisc_class_hash *hash, u32 id) return NULL; } +static inline bool qdisc_class_in_use(const struct Qdisc_class_common *cl) +{ + return cl->filter_cnt > 0; +} + +static inline void qdisc_class_get(struct Qdisc_class_common *cl) +{ + unsigned int res; + + if (check_add_overflow(cl->filter_cnt, 1, &res)) + WARN(1, "Qdisc class overflow"); + + cl->filter_cnt = res; +} + +static inline void qdisc_class_put(struct Qdisc_class_common *cl) +{ + unsigned int res; + + if (check_sub_overflow(cl->filter_cnt, 1, &res)) + WARN(1, "Qdisc class underflow"); + + cl->filter_cnt = res; +} + static inline int tc_classid_to_hwtc(struct net_device *dev, u32 classid) { u32 hwtc = TC_H_MIN(classid) - TC_H_MIN_PRIORITY; -- cgit v1.2.3