diff --git a/include/net/act_api.h b/include/net/act_api.h
index 3a1a72990fceb7744725b65dd65a0827ad1d9636..4be8b0daedf030c18a3e2202b6582b1a6444065b 100644
--- a/include/net/act_api.h
+++ b/include/net/act_api.h
@@ -78,6 +78,8 @@ static inline void tcf_tm_dump(struct tcf_t *dtm, const struct tcf_t *stm)
 #define ACT_P_CREATED 1
 #define ACT_P_DELETED 1
 
+typedef void (*tc_action_priv_destructor)(void *priv);
+
 struct tc_action_ops {
 	struct list_head head;
 	char    kind[IFNAMSIZ];
@@ -101,6 +103,9 @@ struct tc_action_ops {
 	size_t  (*get_fill_size)(const struct tc_action *act);
 	struct net_device *(*get_dev)(const struct tc_action *a);
 	void	(*put_dev)(struct net_device *dev);
+	struct psample_group *
+	(*get_psample_group)(const struct tc_action *a,
+			     tc_action_priv_destructor *destructor);
 };
 
 struct tc_action_net {
diff --git a/include/net/psample.h b/include/net/psample.h
index 6b578ce69cd8a2b78d0fcaed3005f435380c4823..68ae16bb0a4a85986aeea06b2d11cc13436a32e0 100644
--- a/include/net/psample.h
+++ b/include/net/psample.h
@@ -15,6 +15,7 @@ struct psample_group {
 };
 
 struct psample_group *psample_group_get(struct net *net, u32 group_num);
+void psample_group_take(struct psample_group *group);
 void psample_group_put(struct psample_group *group);
 
 #if IS_ENABLED(CONFIG_PSAMPLE)
diff --git a/include/net/tc_act/tc_sample.h b/include/net/tc_act/tc_sample.h
index b4fce0fae645690207be3b4cd289cf27f8578f03..b5d76305e8544cce7facafadee59534faae38199 100644
--- a/include/net/tc_act/tc_sample.h
+++ b/include/net/tc_act/tc_sample.h
@@ -41,10 +41,4 @@ static inline int tcf_sample_trunc_size(const struct tc_action *a)
 	return to_sample(a)->trunc_size;
 }
 
-static inline struct psample_group *
-tcf_sample_psample_group(const struct tc_action *a)
-{
-	return rcu_dereference_rtnl(to_sample(a)->psample_group);
-}
-
 #endif /* __NET_TC_SAMPLE_H */
diff --git a/net/psample/psample.c b/net/psample/psample.c
index 66e4b61a350d5472776286e415192fba140d8a36..a6ceb0533b5bbb6c03c467f411361dad343e1257 100644
--- a/net/psample/psample.c
+++ b/net/psample/psample.c
@@ -73,7 +73,7 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
 	int idx = 0;
 	int err;
 
-	spin_lock(&psample_groups_lock);
+	spin_lock_bh(&psample_groups_lock);
 	list_for_each_entry(group, &psample_groups_list, list) {
 		if (!net_eq(group->net, sock_net(msg->sk)))
 			continue;
@@ -89,7 +89,7 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
 		idx++;
 	}
 
-	spin_unlock(&psample_groups_lock);
+	spin_unlock_bh(&psample_groups_lock);
 	cb->args[0] = idx;
 	return msg->len;
 }
@@ -172,7 +172,7 @@ struct psample_group *psample_group_get(struct net *net, u32 group_num)
 {
 	struct psample_group *group;
 
-	spin_lock(&psample_groups_lock);
+	spin_lock_bh(&psample_groups_lock);
 
 	group = psample_group_lookup(net, group_num);
 	if (!group) {
@@ -183,19 +183,27 @@ struct psample_group *psample_group_get(struct net *net, u32 group_num)
 	group->refcount++;
 
 out:
-	spin_unlock(&psample_groups_lock);
+	spin_unlock_bh(&psample_groups_lock);
 	return group;
 }
 EXPORT_SYMBOL_GPL(psample_group_get);
 
+void psample_group_take(struct psample_group *group)
+{
+	spin_lock_bh(&psample_groups_lock);
+	group->refcount++;
+	spin_unlock_bh(&psample_groups_lock);
+}
+EXPORT_SYMBOL_GPL(psample_group_take);
+
 void psample_group_put(struct psample_group *group)
 {
-	spin_lock(&psample_groups_lock);
+	spin_lock_bh(&psample_groups_lock);
 
 	if (--group->refcount == 0)
 		psample_group_destroy(group);
 
-	spin_unlock(&psample_groups_lock);
+	spin_unlock_bh(&psample_groups_lock);
 }
 EXPORT_SYMBOL_GPL(psample_group_put);
 
diff --git a/net/sched/act_sample.c b/net/sched/act_sample.c
index 10229124a9924efda7745e5917a37fd5c6bff057..692c4c9040fd6b0b0322c4b0d7f96ad8dd77d21c 100644
--- a/net/sched/act_sample.c
+++ b/net/sched/act_sample.c
@@ -252,6 +252,32 @@ static int tcf_sample_search(struct net *net, struct tc_action **a, u32 index)
 	return tcf_idr_search(tn, a, index);
 }
 
+static void tcf_psample_group_put(void *priv)
+{
+	struct psample_group *group = priv;
+
+	psample_group_put(group);
+}
+
+static struct psample_group *
+tcf_sample_get_group(const struct tc_action *a,
+		     tc_action_priv_destructor *destructor)
+{
+	struct tcf_sample *s = to_sample(a);
+	struct psample_group *group;
+
+	spin_lock_bh(&s->tcf_lock);
+	group = rcu_dereference_protected(s->psample_group,
+					  lockdep_is_held(&s->tcf_lock));
+	if (group) {
+		psample_group_take(group);
+		*destructor = tcf_psample_group_put;
+	}
+	spin_unlock_bh(&s->tcf_lock);
+
+	return group;
+}
+
 static struct tc_action_ops act_sample_ops = {
 	.kind	  = "sample",
 	.id	  = TCA_ID_SAMPLE,
@@ -262,6 +288,7 @@ static struct tc_action_ops act_sample_ops = {
 	.cleanup  = tcf_sample_cleanup,
 	.walk	  = tcf_sample_walker,
 	.lookup	  = tcf_sample_search,
+	.get_psample_group = tcf_sample_get_group,
 	.size	  = sizeof(struct tcf_sample),
 };
 
diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c
index c668195379bda46f49451e4e6263fb008a7b0e88..60d44b14750a41c8dd876d1e71d6afa2795ad0dd 100644
--- a/net/sched/cls_api.c
+++ b/net/sched/cls_api.c
@@ -3324,6 +3324,16 @@ static int tcf_tunnel_encap_get_tunnel(struct flow_action_entry *entry,
 	return 0;
 }
 
+static void tcf_sample_get_group(struct flow_action_entry *entry,
+				 const struct tc_action *act)
+{
+#ifdef CONFIG_NET_CLS_ACT
+	entry->sample.psample_group =
+		act->ops->get_psample_group(act, &entry->destructor);
+	entry->destructor_priv = entry->sample.psample_group;
+#endif
+}
+
 int tc_setup_flow_action(struct flow_action *flow_action,
 			 const struct tcf_exts *exts, bool rtnl_held)
 {
@@ -3417,11 +3427,10 @@ int tc_setup_flow_action(struct flow_action *flow_action,
 			entry->mark = tcf_skbedit_mark(act);
 		} else if (is_tcf_sample(act)) {
 			entry->id = FLOW_ACTION_SAMPLE;
-			entry->sample.psample_group =
-				tcf_sample_psample_group(act);
 			entry->sample.trunc_size = tcf_sample_trunc_size(act);
 			entry->sample.truncate = tcf_sample_truncate(act);
 			entry->sample.rate = tcf_sample_rate(act);
+			tcf_sample_get_group(entry, act);
 		} else if (is_tcf_police(act)) {
 			entry->id = FLOW_ACTION_POLICE;
 			entry->police.burst = tcf_police_tcfp_burst(act);