diff --git a/net/mctp/route.c b/net/mctp/route.c
index f9a80b82dc511dbcde9e116838702259c7765759..ce10ba7ae83933371df0fbf78654964d9458edbe 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -147,6 +147,7 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
 	key->valid = true;
 	spin_lock_init(&key->lock);
 	refcount_set(&key->refs, 1);
+	sock_hold(key->sk);
 
 	return key;
 }
@@ -165,6 +166,7 @@ void mctp_key_unref(struct mctp_sk_key *key)
 	mctp_dev_release_key(key->dev, key);
 	spin_unlock_irqrestore(&key->lock, flags);
 
+	sock_put(key->sk);
 	kfree(key);
 }
 
@@ -419,14 +421,14 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 			 * this function.
 			 */
 			rc = mctp_key_add(key, msk);
-			if (rc) {
-				kfree(key);
-			} else {
+			if (!rc)
 				trace_mctp_key_acquire(key);
 
-				/* we don't need to release key->lock on exit */
-				mctp_key_unref(key);
-			}
+			/* we don't need to release key->lock on exit, so
+			 * clean up here and suppress the unlock via
+			 * setting to NULL
+			 */
+			mctp_key_unref(key);
 			key = NULL;
 
 		} else {