diff --git a/include/linux/pid.h b/include/linux/pid.h
index af308e15f174c9866d8a594375cffcfa8d69a615..343abf22092e6b54a973dc6d457e06e317e6c9ed 100644
--- a/include/linux/pid.h
+++ b/include/linux/pid.h
@@ -78,6 +78,7 @@ struct file;
 
 extern struct pid *pidfd_pid(const struct file *file);
 struct pid *pidfd_get_pid(unsigned int fd, unsigned int *flags);
+struct task_struct *pidfd_get_task(int pidfd, unsigned int *flags);
 int pidfd_create(struct pid *pid, unsigned int flags);
 
 static inline struct pid *get_pid(struct pid *pid)
diff --git a/kernel/pid.c b/kernel/pid.c
index efe87db4468364f8d2650d0b517aa2b567048f2c..2fc0a16ec77b1d9b6efe3b6e098f7f23921abe17 100644
--- a/kernel/pid.c
+++ b/kernel/pid.c
@@ -539,6 +539,42 @@ struct pid *pidfd_get_pid(unsigned int fd, unsigned int *flags)
 	return pid;
 }
 
+/**
+ * pidfd_get_task() - Get the task associated with a pidfd
+ *
+ * @pidfd: pidfd for which to get the task
+ * @flags: flags associated with this pidfd
+ *
+ * Return the task associated with @pidfd. The function takes a reference on
+ * the returned task. The caller is responsible for releasing that reference.
+ *
+ * Currently, the process identified by @pidfd is always a thread-group leader.
+ * This restriction currently exists for all aspects of pidfds including pidfd
+ * creation (CLONE_PIDFD cannot be used with CLONE_THREAD) and pidfd polling
+ * (only supports thread group leaders).
+ *
+ * Return: On success, the task_struct associated with the pidfd.
+ *	   On error, a negative errno number will be returned.
+ */
+struct task_struct *pidfd_get_task(int pidfd, unsigned int *flags)
+{
+	unsigned int f_flags;
+	struct pid *pid;
+	struct task_struct *task;
+
+	pid = pidfd_get_pid(pidfd, &f_flags);
+	if (IS_ERR(pid))
+		return ERR_CAST(pid);
+
+	task = get_pid_task(pid, PIDTYPE_TGID);
+	put_pid(pid);
+	if (!task)
+		return ERR_PTR(-ESRCH);
+
+	*flags = f_flags;
+	return task;
+}
+
 /**
  * pidfd_create() - Create a new pid file descriptor.
  *
diff --git a/mm/madvise.c b/mm/madvise.c
index 0734db8d53a7a9e8ebb65112c6ae76e9c9b07f2d..8c927202bbe61d389777f6668d692bc7871a1547 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -1235,7 +1235,6 @@ SYSCALL_DEFINE5(process_madvise, int, pidfd, const struct iovec __user *, vec,
 	struct iovec iovstack[UIO_FASTIOV], iovec;
 	struct iovec *iov = iovstack;
 	struct iov_iter iter;
-	struct pid *pid;
 	struct task_struct *task;
 	struct mm_struct *mm;
 	size_t total_len;
@@ -1250,18 +1249,12 @@ SYSCALL_DEFINE5(process_madvise, int, pidfd, const struct iovec __user *, vec,
 	if (ret < 0)
 		goto out;
 
-	pid = pidfd_get_pid(pidfd, &f_flags);
-	if (IS_ERR(pid)) {
-		ret = PTR_ERR(pid);
+	task = pidfd_get_task(pidfd, &f_flags);
+	if (IS_ERR(task)) {
+		ret = PTR_ERR(task);
 		goto free_iov;
 	}
 
-	task = get_pid_task(pid, PIDTYPE_PID);
-	if (!task) {
-		ret = -ESRCH;
-		goto put_pid;
-	}
-
 	if (!process_madvise_behavior_valid(behavior)) {
 		ret = -EINVAL;
 		goto release_task;
@@ -1301,8 +1294,6 @@ SYSCALL_DEFINE5(process_madvise, int, pidfd, const struct iovec __user *, vec,
 	mmput(mm);
 release_task:
 	put_task_struct(task);
-put_pid:
-	put_pid(pid);
 free_iov:
 	kfree(iov);
 out:
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 195b3661da3d746ccbd73b123d974131695903d5..1ddabefcfb5aba566205264d127702976e622ec6 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -1150,21 +1150,14 @@ SYSCALL_DEFINE2(process_mrelease, int, pidfd, unsigned int, flags)
 	struct task_struct *p;
 	unsigned int f_flags;
 	bool reap = false;
-	struct pid *pid;
 	long ret = 0;
 
 	if (flags)
 		return -EINVAL;
 
-	pid = pidfd_get_pid(pidfd, &f_flags);
-	if (IS_ERR(pid))
-		return PTR_ERR(pid);
-
-	task = get_pid_task(pid, PIDTYPE_TGID);
-	if (!task) {
-		ret = -ESRCH;
-		goto put_pid;
-	}
+	task = pidfd_get_task(pidfd, &f_flags);
+	if (IS_ERR(task))
+		return PTR_ERR(task);
 
 	/*
 	 * Make sure to choose a thread which still has a reference to mm
@@ -1204,8 +1197,6 @@ SYSCALL_DEFINE2(process_mrelease, int, pidfd, unsigned int, flags)
 		mmput(mm);
 put_task:
 	put_task_struct(task);
-put_pid:
-	put_pid(pid);
 	return ret;
 #else
 	return -ENOSYS;