diff --git a/include/net/sock.h b/include/net/sock.h
index ecea3dcc2217b8373111b30e93231329b562fc71..fefe1f4abf191de3e0b8719e5e50d83a0e28ef20 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -318,6 +318,9 @@ struct sk_filter;
   *	@sk_stamp: time stamp of last packet received
   *	@sk_stamp_seq: lock for accessing sk_stamp on 32 bit architectures only
   *	@sk_tsflags: SO_TIMESTAMPING flags
+  *	@sk_use_task_frag: allow sk_page_frag() to use current->task_frag.
+  *			   Sockets that can be used under memory reclaim should
+  *			   set this to false.
   *	@sk_bind_phc: SO_TIMESTAMPING bind PHC index of PTP virtual clock
   *	              for timestamping
   *	@sk_tskey: counter to disambiguate concurrent tstamp requests
@@ -512,6 +515,7 @@ struct sock {
 	u8			sk_txtime_deadline_mode : 1,
 				sk_txtime_report_errors : 1,
 				sk_txtime_unused : 6;
+	bool			sk_use_task_frag;
 
 	struct socket		*sk_socket;
 	void			*sk_user_data;
@@ -2561,14 +2565,17 @@ static inline void sk_stream_moderate_sndbuf(struct sock *sk)
  * socket operations and end up recursing into sk_page_frag()
  * while it's already in use: explicitly avoid task page_frag
  * usage if the caller is potentially doing any of them.
- * This assumes that page fault handlers use the GFP_NOFS flags.
+ * This assumes that page fault handlers use the GFP_NOFS flags or
+ * explicitly disable sk_use_task_frag.
  *
  * Return: a per task page_frag if context allows that,
  * otherwise a per socket one.
  */
 static inline struct page_frag *sk_page_frag(struct sock *sk)
 {
-	if ((sk->sk_allocation & (__GFP_DIRECT_RECLAIM | __GFP_MEMALLOC | __GFP_FS)) ==
+	if (sk->sk_use_task_frag &&
+	    (sk->sk_allocation & (__GFP_DIRECT_RECLAIM | __GFP_MEMALLOC |
+				  __GFP_FS)) ==
 	    (__GFP_DIRECT_RECLAIM | __GFP_FS))
 		return &current->task_frag;
 
diff --git a/net/core/sock.c b/net/core/sock.c
index d2587d8712dba6217aeedb486b43fe4ec0b956f1..f954d5893e799ac51aca97d5cb2314cc32460c92 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -3390,6 +3390,7 @@ void sock_init_data(struct socket *sock, struct sock *sk)
 	sk->sk_rcvbuf		=	READ_ONCE(sysctl_rmem_default);
 	sk->sk_sndbuf		=	READ_ONCE(sysctl_wmem_default);
 	sk->sk_state		=	TCP_CLOSE;
+	sk->sk_use_task_frag	=	true;
 	sk_set_socket(sk, sock);
 
 	sock_set_flag(sk, SOCK_ZAPPED);