diff --git a/include/net/tls.h b/include/net/tls.h
index 9f4117ae22973e30696aa42623aef44d8a3a6c57..a5a938583295c0789df287c737b7a1c87556c9f1 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -199,10 +199,6 @@ struct tls_offload_context_tx {
 	(ALIGN(sizeof(struct tls_offload_context_tx), sizeof(void *)) +        \
 	 TLS_DRIVER_STATE_SIZE)
 
-enum {
-	TLS_PENDING_CLOSED_RECORD
-};
-
 struct cipher_context {
 	char *iv;
 	char *rec_seq;
@@ -335,17 +331,14 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 			    int flags);
 
-int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
-				   int flags, long *timeo);
-
 static inline struct tls_msg *tls_msg(struct sk_buff *skb)
 {
 	return (struct tls_msg *)strp_msg(skb);
 }
 
-static inline bool tls_is_pending_closed_record(struct tls_context *ctx)
+static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
 {
-	return test_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+	return !!ctx->partially_sent_record;
 }
 
 static inline int tls_complete_pending_work(struct sock *sk,
@@ -357,17 +350,12 @@ static inline int tls_complete_pending_work(struct sock *sk,
 	if (unlikely(sk->sk_write_pending))
 		rc = wait_on_pending_writer(sk, timeo);
 
-	if (!rc && tls_is_pending_closed_record(ctx))
-		rc = tls_push_pending_closed_record(sk, ctx, flags, timeo);
+	if (!rc && tls_is_partially_sent_record(ctx))
+		rc = tls_push_partial_record(sk, ctx, flags);
 
 	return rc;
 }
 
-static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
-{
-	return !!ctx->partially_sent_record;
-}
-
 static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
 {
 	return tls_ctx->pending_open_record_frags;
@@ -531,6 +519,10 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk)
 	return !!tls_sw_ctx_tx(ctx);
 }
 
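+/* Write-space handlers for the SW and device-offload TX paths */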
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
+
 static inline struct tls_offload_context_rx *
 tls_offload_ctx_rx(const struct tls_context *tls_ctx)
 {
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index a5c17c47d08a8d97c4bc7bc8c86c8f96939402ea..4a1da837a733d0ec583b7b6f33f3c73466812645 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -271,7 +271,6 @@ static int tls_push_record(struct sock *sk,
 	list_add_tail(&record->list, &offload_ctx->records_list);
 	spin_unlock_irq(&offload_ctx->lock);
 	offload_ctx->open_record = NULL;
-	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
 	tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);
 
 	for (i = 0; i < record->num_frags; i++) {
@@ -368,9 +367,14 @@ static int tls_push_data(struct sock *sk,
 		return -sk->sk_err;
 
 	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
-	if (rc < 0)
-		return rc;
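+	/* A previous call may have left a record partially transmitted;
+	 * push out the remainder before queuing any new data.
+	 */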
+	if (tls_is_partially_sent_record(tls_ctx)) {
+		rc = tls_push_partial_record(sk, tls_ctx, flags);
+		if (rc < 0)
+			return rc;
+	}
 
 	pfrag = sk_page_frag(sk);
 
@@ -545,6 +549,26 @@ static int tls_device_push_pending_record(struct sock *sk, int flags)
 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
 }
 
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	int rc = 0;
+
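+	/* The write-space callback can run in softirq context, so the
+	 * push must not sleep: send non-blocking with atomic allocations.
+	 */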
+	if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
+		gfp_t sk_allocation = sk->sk_allocation;
+
+		sk->sk_allocation = GFP_ATOMIC;
+		rc = tls_push_partial_record(sk, ctx,
+					     MSG_DONTWAIT | MSG_NOSIGNAL);
+		sk->sk_allocation = sk_allocation;
+	}
+
+	if (!rc)
+		ctx->sk_write_space(sk);
+}
+
 void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index caff15b2f9b263b665157ea048eab494049aa1d8..17e8667917aa3c14fbe0111f426072c559db053e 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -209,23 +209,9 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 	return tls_push_sg(sk, ctx, sg, offset, flags);
 }
 
-int tls_push_pending_closed_record(struct sock *sk,
-				   struct tls_context *tls_ctx,
-				   int flags, long *timeo)
-{
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-
-	if (tls_is_partially_sent_record(tls_ctx) ||
-	    !list_empty(&ctx->tx_list))
-		return tls_tx_records(sk, flags);
-	else
-		return tls_ctx->push_pending_record(sk, flags);
-}
-
 static void tls_write_space(struct sock *sk)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
 
 	/* If in_tcp_sendpages call lower protocol write space handler
 	 * to ensure we wake up any waiting operations there. For example
@@ -236,14 +222,15 @@ static void tls_write_space(struct sock *sk)
 		return;
 	}
 
-	/* Schedule the transmission if tx list is ready */
-	if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
-		/* Schedule the transmission */
-		if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
-			schedule_delayed_work(&tx_ctx->tx_work.work, 0);
-	}
-
-	ctx->sk_write_space(sk);
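+	/* Device offload resends its partial record inline; the SW path
+	 * reschedules pending records from its tx work queue instead.
+	 */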
+#ifdef CONFIG_TLS_DEVICE
+	if (ctx->tx_conf == TLS_HW)
+		tls_device_write_space(sk, ctx);
+	else
+#endif
+		tls_sw_write_space(sk, ctx);
 }
 
 static void tls_ctx_free(struct tls_context *ctx)
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 1cc830582fa8af2e0ff978005eba5c5a609ae51b..425351ac2a9b156aacf9234e566d7c5ba0dc5867 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1467,23 +1467,26 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 	struct strp_msg *rxm = strp_msg(skb);
 	int err = 0;
 
+	if (!ctx->decrypted) {
 #ifdef CONFIG_TLS_DEVICE
-	err = tls_device_decrypted(sk, skb);
-	if (err < 0)
-		return err;
+		err = tls_device_decrypted(sk, skb);
+		if (err < 0)
+			return err;
 #endif
-	if (!ctx->decrypted) {
-		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
-		if (err < 0) {
-			if (err == -EINPROGRESS)
-				tls_advance_record_sn(sk, &tls_ctx->rx,
-						      version);
+		/* Still not decrypted after tls_device */
+		if (!ctx->decrypted) {
+			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
+					       async);
+			if (err < 0) {
+				if (err == -EINPROGRESS)
+					tls_advance_record_sn(sk, &tls_ctx->rx,
+							      version);
 
-			return err;
+				return err;
+			}
 		}
 
 		rxm->full_len -= padding_length(ctx, tls_ctx, skb);
-
 		rxm->offset += prot->prepend_size;
 		rxm->full_len -= prot->overhead_size;
 		tls_advance_record_sn(sk, &tls_ctx->rx, version);
@@ -1693,7 +1696,8 @@ int tls_sw_recvmsg(struct sock *sk,
 		bool zc = false;
 		int to_decrypt;
 		int chunk = 0;
-		bool async;
+		bool async_capable;
+		bool async = false;
 
 		skb = tls_wait_data(sk, psock, flags, timeo, &err);
 		if (!skb) {
@@ -1727,21 +1731,24 @@
 
 		/* Do not use async mode if record is non-data */
 		if (ctx->control == TLS_RECORD_TYPE_DATA)
-			async = ctx->async_capable;
+			async_capable = ctx->async_capable;
 		else
-			async = false;
+			async_capable = false;
 
 		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-					 &chunk, &zc, async);
+					 &chunk, &zc, async_capable);
 		if (err < 0 && err != -EINPROGRESS) {
 			tls_err_abort(sk, EBADMSG);
 			goto recv_end;
 		}
 
-		if (err == -EINPROGRESS)
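+		/* Async decrypt in flight: recv_end waits for completion */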
+		if (err == -EINPROGRESS) {
+			async = true;
 			num_async++;
-		else if (prot->version == TLS_1_3_VERSION)
+		} else if (prot->version == TLS_1_3_VERSION) {
 			tlm->control = ctx->control;
+		}
 
 		/* If the type of records being processed is not known yet,
 		 * set it to record type just dequeued. If it is already known,
@@ -2126,6 +2133,20 @@ static void tx_work_handler(struct work_struct *work)
 	release_sock(sk);
 }
 
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
+
+	/* Schedule the transmission if tx list is ready */
+	if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
+		/* Only schedule if a transmission isn't already pending */
+		if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
+			schedule_delayed_work(&tx_ctx->tx_work.work, 0);
+	}
+
+	ctx->sk_write_space(sk);
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);