diff --git a/drivers/gpu/drm/drm_dp_mst_topology.c b/drivers/gpu/drm/drm_dp_mst_topology.c
index 35c848535d097a6a7b7635ec1c8afaf4af00446f..3d421a60d13789db59a51260aee6294e6386769d 100644
--- a/drivers/gpu/drm/drm_dp_mst_topology.c
+++ b/drivers/gpu/drm/drm_dp_mst_topology.c
@@ -2233,6 +2233,12 @@ int drm_dp_mst_topology_mgr_set_mst(struct drm_dp_mst_topology_mgr *mgr, bool ms
 	if (mst_state) {
 		WARN_ON(mgr->mst_primary);
 
+		if (mgr->cbs->pre_enable) {
+			ret = mgr->cbs->pre_enable(mgr);
+			if (ret)
+				goto out_unlock;
+		}
+
 		/* get dpcd info */
 		ret = drm_dp_dpcd_read(mgr->aux, DP_DPCD_REV, mgr->dpcd, DP_RECEIVER_CAP_SIZE);
 		if (ret != DP_RECEIVER_CAP_SIZE) {
@@ -2277,13 +2283,16 @@ int drm_dp_mst_topology_mgr_set_mst(struct drm_dp_mst_topology_mgr *mgr, bool ms
 		mgr->mst_state = true;
 		ret = 0;
 	} else {
+		if (mgr->cbs->post_disable)
+			mgr->cbs->post_disable(mgr);
+
 		/* disable MST on the device */
 		mstb = mgr->mst_primary;
 		mgr->mst_primary = NULL;
 		/* this can fail if the device is gone */
 		drm_dp_dpcd_writeb(mgr->aux, DP_MSTM_CTRL, 0);
-		ret = 0;
-		memset(mgr->payloads, 0, mgr->max_payloads * sizeof(struct drm_dp_payload));
+		memset(mgr->payloads, 0,
+		       mgr->max_payloads * sizeof(struct drm_dp_payload));
 		mgr->payload_mask = 0;
 		set_bit(0, &mgr->payload_mask);
 		mgr->vcpi_mask = 0;
@@ -2291,6 +2300,9 @@ int drm_dp_mst_topology_mgr_set_mst(struct drm_dp_mst_topology_mgr *mgr, bool ms
 	}
 
 out_unlock:
+	if (ret && mst_state && mgr->cbs->post_disable)
+		mgr->cbs->post_disable(mgr);
+
 	mutex_unlock(&mgr->lock);
 	if (mstb)
 		drm_dp_put_mst_branch_device(mstb);
diff --git a/drivers/gpu/drm/nouveau/dispnv50/disp.c b/drivers/gpu/drm/nouveau/dispnv50/disp.c
index 05a58cb36ceec413abc6ba9625c8bfe20b23d36e..35c6efbdd8669dc0af9803a4020f82ab919b2fd2 100644
--- a/drivers/gpu/drm/nouveau/dispnv50/disp.c
+++ b/drivers/gpu/drm/nouveau/dispnv50/disp.c
@@ -1128,12 +1128,62 @@ nv50_mstm_add_connector(struct drm_dp_mst_topology_mgr *mgr,
 	return &mstc->connector;
 }
 
+static int
+nv50_mstm_enable(struct nv50_mstm *mstm, int state)
+{
+	struct nouveau_encoder *outp = mstm->outp;
+	struct {
+		struct nv50_disp_mthd_v1 base;
+		struct nv50_disp_sor_dp_mst_link_v0 mst;
+	} args = {
+		.base.version = 1,
+		.base.method = NV50_DISP_MTHD_V1_SOR_DP_MST_LINK,
+		.base.hasht = outp->dcb->hasht,
+		.base.hashm = outp->dcb->hashm,
+		.mst.state = state,
+	};
+	struct nouveau_drm *drm = nouveau_drm(outp->base.base.dev);
+	struct nvif_object *disp = &drm->display->disp.object;
+	int ret;
+
+	/* Even if we're enabling MST, start with disabling the branching unit
+	 * to clear any sink-side MST topology state that wasn't set by us
+	 */
+	ret = drm_dp_dpcd_writeb(mstm->mgr.aux, DP_MSTM_CTRL, 0);
+	if (ret < 0)
+		return ret;
+
+	if (state) {
+		/* Now, start initializing */
+		ret = drm_dp_dpcd_writeb(mstm->mgr.aux, DP_MSTM_CTRL,
+					 DP_MST_EN);
+		if (ret < 0)
+			return ret;
+	}
+
+	return nvif_mthd(disp, 0, &args, sizeof(args));
+}
+
+static int
+nv50_mstm_pre_enable(struct drm_dp_mst_topology_mgr *mgr)
+{
+	return nv50_mstm_enable(nv50_mstm(mgr), true);
+}
+
+static void
+nv50_mstm_post_disable(struct drm_dp_mst_topology_mgr *mgr)
+{
+	nv50_mstm_enable(nv50_mstm(mgr), false);
+}
+
 static const struct drm_dp_mst_topology_cbs
 nv50_mstm = {
 	.add_connector = nv50_mstm_add_connector,
 	.register_connector = nv50_mstm_register_connector,
 	.destroy_connector = nv50_mstm_destroy_connector,
 	.hotplug = nv50_mstm_hotplug,
+	.pre_enable = nv50_mstm_pre_enable,
+	.post_disable = nv50_mstm_post_disable,
 };
 
 void
@@ -1169,42 +1219,6 @@ nv50_mstm_remove(struct nv50_mstm *mstm)
 		drm_dp_mst_topology_mgr_set_mst(&mstm->mgr, false);
 }
 
-static int
-nv50_mstm_enable(struct nv50_mstm *mstm, int state)
-{
-	struct nouveau_encoder *outp = mstm->outp;
-	struct {
-		struct nv50_disp_mthd_v1 base;
-		struct nv50_disp_sor_dp_mst_link_v0 mst;
-	} args = {
-		.base.version = 1,
-		.base.method = NV50_DISP_MTHD_V1_SOR_DP_MST_LINK,
-		.base.hasht = outp->dcb->hasht,
-		.base.hashm = outp->dcb->hashm,
-		.mst.state = state,
-	};
-	struct nouveau_drm *drm = nouveau_drm(outp->base.base.dev);
-	struct nvif_object *disp = &drm->display->disp.object;
-	int ret;
-
-	/* Even if we're enabling MST, start with disabling the branching unit
-	 * to clear any sink-side MST topology state that wasn't set by us
-	 */
-	ret = drm_dp_dpcd_writeb(mstm->mgr.aux, DP_MSTM_CTRL, 0);
-	if (ret < 0)
-		return ret;
-
-	if (state) {
-		/* Now, start initializing */
-		ret = drm_dp_dpcd_writeb(mstm->mgr.aux, DP_MSTM_CTRL,
-					 DP_MST_EN);
-		if (ret < 0)
-			return ret;
-	}
-
-	return nvif_mthd(disp, 0, &args, sizeof(args));
-}
-
 int
 nv50_mstm_detect(struct nv50_mstm *mstm, u8 dpcd[8], int allow)
 {
@@ -1216,9 +1230,7 @@ nv50_mstm_detect(struct nv50_mstm *mstm, u8 dpcd[8], int allow)
 	if (!mstm)
 		return 0;
 
-	mutex_lock(&mstm->mgr.lock);
-
-	old_state = mstm->mgr.mst_state;
+	old_state = READ_ONCE(mstm->mgr.mst_state);
 	new_state = old_state;
 	aux = mstm->mgr.aux;
 
@@ -1232,7 +1244,7 @@ nv50_mstm_detect(struct nv50_mstm *mstm, u8 dpcd[8], int allow)
 	} else if (dpcd[0] >= 0x12) {
 		ret = drm_dp_dpcd_readb(aux, DP_MSTM_CAP, &dpcd[1]);
 		if (ret < 0)
-			goto probe_error;
+			return ret;
 
 		if (!(dpcd[1] & DP_MST_CAP))
 			dpcd[0] = 0x11;
@@ -1240,26 +1252,10 @@ nv50_mstm_detect(struct nv50_mstm *mstm, u8 dpcd[8], int allow)
 			new_state = allow;
 	}
 
-	if (new_state == old_state) {
-		mutex_unlock(&mstm->mgr.lock);
+	if (new_state == old_state)
 		return new_state;
-	}
-
-	ret = nv50_mstm_enable(mstm, new_state);
-	if (ret)
-		goto probe_error;
 
-	mutex_unlock(&mstm->mgr.lock);
-
-	ret = drm_dp_mst_topology_mgr_set_mst(&mstm->mgr, new_state);
-	if (ret)
-		return nv50_mstm_enable(mstm, 0);
-
-	return new_state;
-
-probe_error:
-	mutex_unlock(&mstm->mgr.lock);
-	return ret;
+	return drm_dp_mst_topology_mgr_set_mst(&mstm->mgr, new_state);
 }
 
 static void
diff --git a/include/drm/drm_dp_mst_helper.h b/include/drm/drm_dp_mst_helper.h
index c252e42677e5f7a3f1d7d7dae399c9733d55694a..46292ed5f50ed7ff9c6e538580bf39cf3b06586c 100644
--- a/include/drm/drm_dp_mst_helper.h
+++ b/include/drm/drm_dp_mst_helper.h
@@ -326,11 +326,15 @@ struct drm_dp_sideband_msg_tx {
 struct drm_dp_mst_topology_mgr;
 struct drm_dp_mst_topology_cbs {
 	/* create a connector for a port */
-	struct drm_connector *(*add_connector)(struct drm_dp_mst_topology_mgr *mgr, struct drm_dp_mst_port *port, const char *path);
+	struct drm_connector *(*add_connector)(struct drm_dp_mst_topology_mgr *mgr,
+					       struct drm_dp_mst_port *port,
+					       const char *path);
 	void (*register_connector)(struct drm_connector *connector);
 	void (*destroy_connector)(struct drm_dp_mst_topology_mgr *mgr,
 				  struct drm_connector *connector);
 	void (*hotplug)(struct drm_dp_mst_topology_mgr *mgr);
+	int (*pre_enable)(struct drm_dp_mst_topology_mgr *mgr);
+	void (*post_disable)(struct drm_dp_mst_topology_mgr *mgr);
 
 };