diff --git a/kernel/sched/psi.c b/kernel/sched/psi.c
index 77d53c03a76fd6b62719bb060208ac3c8570b3f0..d71dbc2356ffbfa6eff8a7ae371cf50c94b074f0 100644
--- a/kernel/sched/psi.c
+++ b/kernel/sched/psi.c
@@ -820,20 +820,15 @@ void psi_task_switch(struct task_struct *prev, struct task_struct *next,
 	u64 now = cpu_clock(cpu);
 
 	if (next->pid) {
-		bool identical_state;
-
 		psi_flags_change(next, 0, TSK_ONCPU);
 		/*
-		 * When switching between tasks that have an identical
-		 * runtime state, the cgroup that contains both tasks
-		 * we reach the first common ancestor. Iterate @next's
-		 * ancestors only until we encounter @prev's ONCPU.
+		 * Set TSK_ONCPU on @next's cgroups. If @next shares any
+		 * ancestors with @prev, those will already have @prev's
+		 * TSK_ONCPU bit set, and we can stop the iteration there.
 		 */
-		identical_state = prev->psi_flags == next->psi_flags;
 		iter = NULL;
 		while ((group = iterate_groups(next, &iter))) {
-			if (identical_state &&
-			    per_cpu_ptr(group->pcpu, cpu)->tasks[NR_ONCPU]) {
+			if (per_cpu_ptr(group->pcpu, cpu)->tasks[NR_ONCPU]) {
 				common = group;
 				break;
 			}
@@ -877,10 +872,12 @@ void psi_task_switch(struct task_struct *prev, struct task_struct *next,
 			psi_group_change(group, cpu, clear, set, now, wake_clock);
 
 		/*
-		 * TSK_ONCPU is handled up to the common ancestor. If we're tasked
-		 * with dequeuing too, finish that for the rest of the hierarchy.
+		 * TSK_ONCPU is handled up to the common ancestor. If there are
+		 * any other differences between the two tasks (e.g. prev goes
+		 * to sleep, or only one task is memstall), finish propagating
+		 * those differences all the way up to the root.
 		 */
-		if (sleep) {
+		if ((prev->psi_flags ^ next->psi_flags) & ~TSK_ONCPU) {
 			clear &= ~TSK_ONCPU;
 			for (; group; group = iterate_groups(prev, &iter))
 				psi_group_change(group, cpu, clear, set, now, wake_clock);