diff --git a/net/switchdev/switchdev.c b/net/switchdev/switchdev.c
index 6a00c390547b87d89a6d20c8da436584883fe834..28d2ccfe109cb12a9e76ddf5bc1f8888e0a0fa3c 100644
--- a/net/switchdev/switchdev.c
+++ b/net/switchdev/switchdev.c
@@ -564,7 +564,7 @@ static int __switchdev_handle_port_obj_add(struct net_device *dev,
 				      struct netlink_ext_ack *extack))
 {
 	struct switchdev_notifier_info *info = &port_obj_info->info;
-	struct net_device *br, *lower_dev;
+	struct net_device *br, *lower_dev, *switchdev;
 	struct netlink_ext_ack *extack;
 	struct list_head *iter;
 	int err = -EOPNOTSUPP;
@@ -614,7 +614,11 @@ static int __switchdev_handle_port_obj_add(struct net_device *dev,
 	if (!br || !netif_is_bridge_master(br))
 		return err;
 
-	if (!switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb))
+	switchdev = switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb);
+	if (!switchdev)
+		return err;
+
+	if (!foreign_dev_check_cb(switchdev, dev))
 		return err;
 
 	return __switchdev_handle_port_obj_add(br, port_obj_info, check_cb,
@@ -674,7 +678,7 @@ static int __switchdev_handle_port_obj_del(struct net_device *dev,
 				      const struct switchdev_obj *obj))
 {
 	struct switchdev_notifier_info *info = &port_obj_info->info;
-	struct net_device *br, *lower_dev;
+	struct net_device *br, *lower_dev, *switchdev;
 	struct list_head *iter;
 	int err = -EOPNOTSUPP;
 
@@ -721,7 +725,11 @@ static int __switchdev_handle_port_obj_del(struct net_device *dev,
 	if (!br || !netif_is_bridge_master(br))
 		return err;
 
-	if (!switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb))
+	switchdev = switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb);
+	if (!switchdev)
+		return err;
+
+	if (!foreign_dev_check_cb(switchdev, dev))
 		return err;
 
 	return __switchdev_handle_port_obj_del(br, port_obj_info, check_cb,