diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index ca1041c88c6883ae942c1c9aea1a9cbc7e312efb..1f80eac5d6ae05f1bf50e6d6820e02ff1fe0a7d0 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -341,6 +341,8 @@ static bool vhost_worker(void *data)
 
 	node = llist_del_all(&worker->work_list);
 	if (node) {
+		__set_current_state(TASK_RUNNING);
+
 		node = llist_reverse_order(node);
 		/* make sure flag is seen after deletion */
 		smp_wmb();
diff --git a/kernel/vhost_task.c b/kernel/vhost_task.c
index f80d5c51ae67106c0bd72eb19efd39f4b834e147..da35e5b7f04738fc590a426136944968a905db1a 100644
--- a/kernel/vhost_task.c
+++ b/kernel/vhost_task.c
@@ -28,10 +28,6 @@ static int vhost_task_fn(void *data)
 	for (;;) {
 		bool did_work;
 
-		/* mb paired w/ vhost_task_stop */
-		if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags))
-			break;
-
 		if (!dead && signal_pending(current)) {
 			struct ksignal ksig;
 			/*
@@ -48,11 +44,17 @@ static int vhost_task_fn(void *data)
 				clear_thread_flag(TIF_SIGPENDING);
 		}
 
+		/* mb paired w/ vhost_task_stop */
+		set_current_state(TASK_INTERRUPTIBLE);
+
+		if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
+			__set_current_state(TASK_RUNNING);
+			break;
+		}
+
 		did_work = vtsk->fn(vtsk->data);
-		if (!did_work) {
-			set_current_state(TASK_INTERRUPTIBLE);
+		if (!did_work)
 			schedule();
-		}
 	}
 
 	complete(&vtsk->exited);