diff --git a/Documentation/networking/tls-offload.rst b/Documentation/networking/tls-offload.rst
index 048e5ca44824b4bd7304ee26b5a4c0b7185f828c..8a1eeb393316b56f7fee8c1b7b9a3cfed506ef12 100644
--- a/Documentation/networking/tls-offload.rst
+++ b/Documentation/networking/tls-offload.rst
@@ -513,3 +513,9 @@ Redirects leak clear text
 
 In the RX direction, if segment has already been decrypted by the device
 and it gets redirected or mirrored - clear text will be transmitted out.
+
+shutdown() doesn't clear TLS state
+----------------------------------
+
+shutdown() system call allows for a TLS socket to be reused as a different
+connection. Offload doesn't currently handle that.
diff --git a/include/net/tls.h b/include/net/tls.h
index 235508e35fd4538eef03a605e08584e57e09a0a5..9e425ac2de45a7197999fe9b7309fb58772e9419 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -271,6 +271,8 @@ struct tls_context {
 	unsigned long flags;
 
 	/* cache cold stuff */
+	struct proto *sk_proto;
+
 	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
 
@@ -288,6 +290,8 @@ struct tls_context {
 
 	struct list_head list;
 	refcount_t refcount;
+
+	struct work_struct gc;
 };
 
 enum tls_offload_ctx_dir {
@@ -359,7 +363,6 @@ void tls_sw_strparser_done(struct tls_context *tls_ctx);
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
 		    int offset, size_t size, int flags);
-void tls_sw_close(struct sock *sk, long timeout);
 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
 void tls_sw_release_resources_tx(struct sock *sk);
 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d152a00a7a27fd024251903218032d42242c7b78..48f1c26459d0478b24900afcbf0bfd8c1c6c5aa3 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -261,6 +261,33 @@ void tls_ctx_free(struct tls_context *ctx)
 	kfree(ctx);
 }
 
+static void tls_ctx_free_deferred(struct work_struct *gc)
+{
+	struct tls_context *ctx = container_of(gc, struct tls_context, gc);
+
+	/* Ensure any remaining work items are completed. The sk will
+	 * already have lost its tls_ctx reference by the time we get
+	 * here so no xmit operation will actually be performed.
+	 */
+	if (ctx->tx_conf == TLS_SW) {
+		tls_sw_cancel_work_tx(ctx);
+		tls_sw_free_ctx_tx(ctx);
+	}
+
+	if (ctx->rx_conf == TLS_SW) {
+		tls_sw_strparser_done(ctx);
+		tls_sw_free_ctx_rx(ctx);
+	}
+
+	tls_ctx_free(ctx);
+}
+
+static void tls_ctx_free_wq(struct tls_context *ctx)
+{
+	INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
+	schedule_work(&ctx->gc);
+}
+
 static void tls_sk_proto_cleanup(struct sock *sk,
 				 struct tls_context *ctx, long timeo)
 {
@@ -288,6 +315,26 @@ static void tls_sk_proto_cleanup(struct sock *sk,
 #endif
 }
 
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	struct tls_context *ctx;
+
+	if (unlikely(!icsk->icsk_ulp_data)) {
+		if (sk->sk_prot->unhash)
+			sk->sk_prot->unhash(sk);
+	}
+
+	ctx = tls_get_ctx(sk);
+	tls_sk_proto_cleanup(sk, ctx, timeo);
+	icsk->icsk_ulp_data = NULL;
+
+	if (ctx->sk_proto->unhash)
+		ctx->sk_proto->unhash(sk);
+	tls_ctx_free_wq(ctx);
+}
+
 static void tls_sk_proto_close(struct sock *sk, long timeout)
 {
 	void (*sk_proto_close)(struct sock *sk, long timeout);
@@ -305,6 +352,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
 		tls_sk_proto_cleanup(sk, ctx, timeo);
 
+	sk->sk_prot = ctx->sk_proto;
 	release_sock(sk);
 	if (ctx->tx_conf == TLS_SW)
 		tls_sw_free_ctx_tx(ctx);
@@ -608,6 +656,7 @@ static struct tls_context *create_ctx(struct sock *sk)
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
+	ctx->unhash = sk->sk_prot->unhash;
 	return ctx;
 }
 
@@ -731,6 +780,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
 	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
 	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;
+	prot[TLS_BASE][TLS_BASE].unhash		= tls_sk_proto_unhash;
 
 	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
@@ -748,16 +798,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 
 #ifdef CONFIG_TLS_DEVICE
 	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+	prot[TLS_HW][TLS_BASE].unhash		= base->unhash;
 	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;
 
 	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+	prot[TLS_HW][TLS_SW].unhash		= base->unhash;
 	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;
 
 	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+	prot[TLS_BASE][TLS_HW].unhash		= base->unhash;
 
 	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+	prot[TLS_SW][TLS_HW].unhash		= base->unhash;
 
 	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
@@ -794,6 +848,7 @@ static int tls_init(struct sock *sk)
 	tls_build_proto(sk);
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
+	ctx->sk_proto = sk->sk_prot;
 	update_sk_prot(sk, ctx);
 out:
 	return rc;