diff --git a/crypto/ahash.c b/crypto/ahash.c
index d2c8895bb2feacfe5ec510ec0360a9fd9cade0e0..266fc1d64f61d537430059f38f3234be7f3a17e1 100644
--- a/crypto/ahash.c
+++ b/crypto/ahash.c
@@ -193,11 +193,18 @@ int crypto_ahash_setkey(struct crypto_ahash *tfm, const u8 *key,
 			unsigned int keylen)
 {
 	unsigned long alignmask = crypto_ahash_alignmask(tfm);
+	int err;
 
 	if ((unsigned long)key & alignmask)
-		return ahash_setkey_unaligned(tfm, key, keylen);
+		err = ahash_setkey_unaligned(tfm, key, keylen);
+	else
+		err = tfm->setkey(tfm, key, keylen);
+
+	if (err)
+		return err;
 
-	return tfm->setkey(tfm, key, keylen);
+	crypto_ahash_clear_flags(tfm, CRYPTO_TFM_NEED_KEY);
+	return 0;
 }
 EXPORT_SYMBOL_GPL(crypto_ahash_setkey);
 
@@ -368,7 +375,12 @@ EXPORT_SYMBOL_GPL(crypto_ahash_finup);
 
 int crypto_ahash_digest(struct ahash_request *req)
 {
-	return crypto_ahash_op(req, crypto_ahash_reqtfm(req)->digest);
+	struct crypto_ahash *tfm = crypto_ahash_reqtfm(req);
+
+	if (crypto_ahash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return crypto_ahash_op(req, tfm->digest);
 }
 EXPORT_SYMBOL_GPL(crypto_ahash_digest);
 
@@ -450,7 +462,6 @@ static int crypto_ahash_init_tfm(struct crypto_tfm *tfm)
 	struct ahash_alg *alg = crypto_ahash_alg(hash);
 
 	hash->setkey = ahash_nosetkey;
-	hash->has_setkey = false;
 	hash->export = ahash_no_export;
 	hash->import = ahash_no_import;
 
@@ -465,7 +476,8 @@ static int crypto_ahash_init_tfm(struct crypto_tfm *tfm)
 
 	if (alg->setkey) {
 		hash->setkey = alg->setkey;
-		hash->has_setkey = true;
+		if (!(alg->halg.base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY))
+			crypto_ahash_set_flags(hash, CRYPTO_TFM_NEED_KEY);
 	}
 	if (alg->export)
 		hash->export = alg->export;
diff --git a/crypto/algif_hash.c b/crypto/algif_hash.c
index 76d2e716c7925afec7f4427eb7003f946bca26a8..6c9b1927a52084a909093b40ef47fe92b2703db0 100644
--- a/crypto/algif_hash.c
+++ b/crypto/algif_hash.c
@@ -34,11 +34,6 @@ struct hash_ctx {
 	struct ahash_request req;
 };
 
-struct algif_hash_tfm {
-	struct crypto_ahash *hash;
-	bool has_key;
-};
-
 static int hash_alloc_result(struct sock *sk, struct hash_ctx *ctx)
 {
 	unsigned ds;
@@ -307,7 +302,7 @@ static int hash_check_key(struct socket *sock)
 	int err = 0;
 	struct sock *psk;
 	struct alg_sock *pask;
-	struct algif_hash_tfm *tfm;
+	struct crypto_ahash *tfm;
 	struct sock *sk = sock->sk;
 	struct alg_sock *ask = alg_sk(sk);
 
@@ -321,7 +316,7 @@ static int hash_check_key(struct socket *sock)
 
 	err = -ENOKEY;
 	lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
-	if (!tfm->has_key)
+	if (crypto_ahash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
 		goto unlock;
 
 	if (!pask->refcnt++)
@@ -412,41 +407,17 @@ static struct proto_ops algif_hash_ops_nokey = {
 
 static void *hash_bind(const char *name, u32 type, u32 mask)
 {
-	struct algif_hash_tfm *tfm;
-	struct crypto_ahash *hash;
-
-	tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
-	if (!tfm)
-		return ERR_PTR(-ENOMEM);
-
-	hash = crypto_alloc_ahash(name, type, mask);
-	if (IS_ERR(hash)) {
-		kfree(tfm);
-		return ERR_CAST(hash);
-	}
-
-	tfm->hash = hash;
-
-	return tfm;
+	return crypto_alloc_ahash(name, type, mask);
 }
 
 static void hash_release(void *private)
 {
-	struct algif_hash_tfm *tfm = private;
-
-	crypto_free_ahash(tfm->hash);
-	kfree(tfm);
+	crypto_free_ahash(private);
 }
 
 static int hash_setkey(void *private, const u8 *key, unsigned int keylen)
 {
-	struct algif_hash_tfm *tfm = private;
-	int err;
-
-	err = crypto_ahash_setkey(tfm->hash, key, keylen);
-	tfm->has_key = !err;
-
-	return err;
+	return crypto_ahash_setkey(private, key, keylen);
 }
 
 static void hash_sock_destruct(struct sock *sk)
@@ -461,11 +432,10 @@ static void hash_sock_destruct(struct sock *sk)
 
 static int hash_accept_parent_nokey(void *private, struct sock *sk)
 {
-	struct hash_ctx *ctx;
+	struct crypto_ahash *tfm = private;
 	struct alg_sock *ask = alg_sk(sk);
-	struct algif_hash_tfm *tfm = private;
-	struct crypto_ahash *hash = tfm->hash;
-	unsigned len = sizeof(*ctx) + crypto_ahash_reqsize(hash);
+	struct hash_ctx *ctx;
+	unsigned int len = sizeof(*ctx) + crypto_ahash_reqsize(tfm);
 
 	ctx = sock_kmalloc(sk, len, GFP_KERNEL);
 	if (!ctx)
@@ -478,7 +448,7 @@ static int hash_accept_parent_nokey(void *private, struct sock *sk)
 
 	ask->private = ctx;
 
-	ahash_request_set_tfm(&ctx->req, hash);
+	ahash_request_set_tfm(&ctx->req, tfm);
 	ahash_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 				   crypto_req_done, &ctx->wait);
 
@@ -489,9 +459,9 @@ static int hash_accept_parent_nokey(void *private, struct sock *sk)
 
 static int hash_accept_parent(void *private, struct sock *sk)
 {
-	struct algif_hash_tfm *tfm = private;
+	struct crypto_ahash *tfm = private;
 
-	if (!tfm->has_key && crypto_ahash_has_setkey(tfm->hash))
+	if (crypto_ahash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
 		return -ENOKEY;
 
 	return hash_accept_parent_nokey(private, sk);
diff --git a/crypto/shash.c b/crypto/shash.c
index e849d3ee2e2728d346df1f21f6a8d4db57fc42c5..5d732c6bb4b2158f59e7e16e0f96508b8ae90f91 100644
--- a/crypto/shash.c
+++ b/crypto/shash.c
@@ -58,11 +58,18 @@ int crypto_shash_setkey(struct crypto_shash *tfm, const u8 *key,
 {
 	struct shash_alg *shash = crypto_shash_alg(tfm);
 	unsigned long alignmask = crypto_shash_alignmask(tfm);
+	int err;
 
 	if ((unsigned long)key & alignmask)
-		return shash_setkey_unaligned(tfm, key, keylen);
+		err = shash_setkey_unaligned(tfm, key, keylen);
+	else
+		err = shash->setkey(tfm, key, keylen);
+
+	if (err)
+		return err;
 
-	return shash->setkey(tfm, key, keylen);
+	crypto_shash_clear_flags(tfm, CRYPTO_TFM_NEED_KEY);
+	return 0;
 }
 EXPORT_SYMBOL_GPL(crypto_shash_setkey);
 
@@ -181,6 +188,9 @@ int crypto_shash_digest(struct shash_desc *desc, const u8 *data,
 	struct shash_alg *shash = crypto_shash_alg(tfm);
 	unsigned long alignmask = crypto_shash_alignmask(tfm);
 
+	if (crypto_shash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
 	if (((unsigned long)data | (unsigned long)out) & alignmask)
 		return shash_digest_unaligned(desc, data, len, out);
 
@@ -360,7 +370,8 @@ int crypto_init_shash_ops_async(struct crypto_tfm *tfm)
 	crt->digest = shash_async_digest;
 	crt->setkey = shash_async_setkey;
 
-	crt->has_setkey = alg->setkey != shash_no_setkey;
+	crypto_ahash_set_flags(crt, crypto_shash_get_flags(shash) &
+				    CRYPTO_TFM_NEED_KEY);
 
 	if (alg->export)
 		crt->export = shash_async_export;
@@ -375,8 +386,14 @@ int crypto_init_shash_ops_async(struct crypto_tfm *tfm)
 static int crypto_shash_init_tfm(struct crypto_tfm *tfm)
 {
 	struct crypto_shash *hash = __crypto_shash_cast(tfm);
+	struct shash_alg *alg = crypto_shash_alg(hash);
+
+	hash->descsize = alg->descsize;
+
+	if (crypto_shash_alg_has_setkey(alg) &&
+	    !(alg->base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY))
+		crypto_shash_set_flags(hash, CRYPTO_TFM_NEED_KEY);
 
-	hash->descsize = crypto_shash_alg(hash)->descsize;
 	return 0;
 }
 
diff --git a/include/crypto/hash.h b/include/crypto/hash.h
index 0ed31fd80242cf6dafda1605446e8c5c1d73af60..3880793e280eb821aff6733c63f1b8ed2cedf81f 100644
--- a/include/crypto/hash.h
+++ b/include/crypto/hash.h
@@ -210,7 +210,6 @@ struct crypto_ahash {
 		      unsigned int keylen);
 
 	unsigned int reqsize;
-	bool has_setkey;
 	struct crypto_tfm base;
 };
 
@@ -410,11 +409,6 @@ static inline void *ahash_request_ctx(struct ahash_request *req)
 int crypto_ahash_setkey(struct crypto_ahash *tfm, const u8 *key,
 			unsigned int keylen);
 
-static inline bool crypto_ahash_has_setkey(struct crypto_ahash *tfm)
-{
-	return tfm->has_setkey;
-}
-
 /**
  * crypto_ahash_finup() - update and finalize message digest
  * @req: reference to the ahash_request handle that holds all information
@@ -487,7 +481,12 @@ static inline int crypto_ahash_export(struct ahash_request *req, void *out)
  */
 static inline int crypto_ahash_import(struct ahash_request *req, const void *in)
 {
-	return crypto_ahash_reqtfm(req)->import(req, in);
+	struct crypto_ahash *tfm = crypto_ahash_reqtfm(req);
+
+	if (crypto_ahash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return tfm->import(req, in);
 }
 
 /**
@@ -503,7 +502,12 @@ static inline int crypto_ahash_import(struct ahash_request *req, const void *in)
  */
 static inline int crypto_ahash_init(struct ahash_request *req)
 {
-	return crypto_ahash_reqtfm(req)->init(req);
+	struct crypto_ahash *tfm = crypto_ahash_reqtfm(req);
+
+	if (crypto_ahash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return tfm->init(req);
 }
 
 /**
@@ -855,7 +859,12 @@ static inline int crypto_shash_export(struct shash_desc *desc, void *out)
  */
 static inline int crypto_shash_import(struct shash_desc *desc, const void *in)
 {
-	return crypto_shash_alg(desc->tfm)->import(desc, in);
+	struct crypto_shash *tfm = desc->tfm;
+
+	if (crypto_shash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return crypto_shash_alg(tfm)->import(desc, in);
 }
 
 /**
@@ -871,7 +880,12 @@ static inline int crypto_shash_import(struct shash_desc *desc, const void *in)
  */
 static inline int crypto_shash_init(struct shash_desc *desc)
 {
-	return crypto_shash_alg(desc->tfm)->init(desc);
+	struct crypto_shash *tfm = desc->tfm;
+
+	if (crypto_shash_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
+		return -ENOKEY;
+
+	return crypto_shash_alg(tfm)->init(desc);
 }
 
 /**
diff --git a/include/linux/crypto.h b/include/linux/crypto.h
index d2e33a90825b4d358783879935e0a64df0192541..7e6e84cf6383525a6fe5eb68bd5fd6b49e9a18d8 100644
--- a/include/linux/crypto.h
+++ b/include/linux/crypto.h
@@ -115,6 +115,8 @@
 /*
  * Transform masks and values (for crt_flags).
  */
+#define CRYPTO_TFM_NEED_KEY		0x00000001
+
 #define CRYPTO_TFM_REQ_MASK		0x000fff00
 #define CRYPTO_TFM_RES_MASK		0xfff00000