diff --git a/arch/arm64/crypto/aes-neonbs-core.S b/arch/arm64/crypto/aes-neonbs-core.S
index 8d0cdaa2768dfb2306ba533964cac36c39d4ec33..ca047250043371daac6dfe8a9876af8e228081d4 100644
--- a/arch/arm64/crypto/aes-neonbs-core.S
+++ b/arch/arm64/crypto/aes-neonbs-core.S
@@ -853,13 +853,15 @@ ENDPROC(aesbs_xts_decrypt)
 
 	/*
 	 * aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
-	 *		     int rounds, int blocks, u8 iv[], bool final)
+	 *		     int rounds, int blocks, u8 iv[], u8 final[])
 	 */
 ENTRY(aesbs_ctr_encrypt)
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
-	add		x4, x4, x6		// do one extra block if final
+	cmp		x6, #0
+	cset		x10, ne
+	add		x4, x4, x10		// do one extra block if final
 
 	ldp		x7, x8, [x5]
 	ld1		{v0.16b}, [x5]
@@ -874,19 +876,26 @@ CPU_LE(	rev		x8, x8		)
 	csel		x4, x4, xzr, pl
 	csel		x9, x9, xzr, le
 
+	tbnz		x9, #1, 0f
 	next_ctr	v1
+	tbnz		x9, #2, 0f
 	next_ctr	v2
+	tbnz		x9, #3, 0f
 	next_ctr	v3
+	tbnz		x9, #4, 0f
 	next_ctr	v4
+	tbnz		x9, #5, 0f
 	next_ctr	v5
+	tbnz		x9, #6, 0f
 	next_ctr	v6
+	tbnz		x9, #7, 0f
 	next_ctr	v7
 
 0:	mov		bskey, x2
 	mov		rounds, x3
 	bl		aesbs_encrypt8
 
-	lsr		x9, x9, x6		// disregard the extra block
+	lsr		x9, x9, x10		// disregard the extra block
 	tbnz		x9, #0, 0f
 
 	ld1		{v8.16b}, [x1], #16
@@ -928,36 +937,36 @@ CPU_LE(	rev		x8, x8		)
 	eor		v5.16b, v5.16b, v15.16b
 	st1		{v5.16b}, [x0], #16
 
-	next_ctr	v0
+8:	next_ctr	v0
 	cbnz		x4, 99b
 
 0:	st1		{v0.16b}, [x5]
-8:	ldp		x29, x30, [sp], #16
+	ldp		x29, x30, [sp], #16
 	ret
 
 	/*
-	 * If we are handling the tail of the input (x6 == 1), return the
-	 * final keystream block back to the caller via the IV buffer.
+	 * If we are handling the tail of the input (x6 != NULL), return the
+	 * final keystream block back to the caller.
 	 */
 1:	cbz		x6, 8b
-	st1		{v1.16b}, [x5]
+	st1		{v1.16b}, [x6]
 	b		8b
 2:	cbz		x6, 8b
-	st1		{v4.16b}, [x5]
+	st1		{v4.16b}, [x6]
 	b		8b
 3:	cbz		x6, 8b
-	st1		{v6.16b}, [x5]
+	st1		{v6.16b}, [x6]
 	b		8b
 4:	cbz		x6, 8b
-	st1		{v3.16b}, [x5]
+	st1		{v3.16b}, [x6]
 	b		8b
 5:	cbz		x6, 8b
-	st1		{v7.16b}, [x5]
+	st1		{v7.16b}, [x6]
 	b		8b
 6:	cbz		x6, 8b
-	st1		{v2.16b}, [x5]
+	st1		{v2.16b}, [x6]
 	b		8b
 7:	cbz		x6, 8b
-	st1		{v5.16b}, [x5]
+	st1		{v5.16b}, [x6]
 	b		8b
 ENDPROC(aesbs_ctr_encrypt)
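The `cmp`/`cset` pair above replaces the old trick of adding the `bool` argument straight into the block count, which only worked while x6 carried a 0/1 value; now that x6 is a pointer, the flag must be derived explicitly. A minimal C model of that accounting, with illustrative names only (nothing below is a kernel symbol):

/*
 * C model of the extra-block bookkeeping in aesbs_ctr_encrypt above.
 * x4 holds the block count, x6 the final[] pointer, x10 the derived flag.
 */
#include <stddef.h>
#include <stdio.h>

static int blocks_to_process(int blocks, const unsigned char *final)
{
	/* cmp x6, #0 ; cset x10, ne  =>  x10 = (final != NULL) */
	int extra = (final != NULL);

	/*
	 * add x4, x4, x10 => do one extra block if final; the later
	 * "lsr x9, x9, x10" shifts that extra block back out of the mask.
	 */
	return blocks + extra;
}

int main(void)
{
	unsigned char buf[16];

	printf("%d %d\n", blocks_to_process(4, NULL),	/* 4 */
			  blocks_to_process(4, buf));	/* 5 */
	return 0;
}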
diff --git a/arch/arm64/crypto/aes-neonbs-glue.c b/arch/arm64/crypto/aes-neonbs-glue.c
index 863e436ecf894375a2ddb90f3a69d64accd4675a..db2501d93550c35720c3d1adeb78fed49a9df1a2 100644
--- a/arch/arm64/crypto/aes-neonbs-glue.c
+++ b/arch/arm64/crypto/aes-neonbs-glue.c
@@ -34,7 +34,7 @@ asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
 				  int rounds, int blocks, u8 iv[]);
 
 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
-				  int rounds, int blocks, u8 iv[], bool final);
+				  int rounds, int blocks, u8 iv[], u8 final[]);
 
 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
 				  int rounds, int blocks, u8 iv[]);
@@ -201,6 +201,7 @@ static int ctr_encrypt(struct skcipher_request *req)
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
 	struct skcipher_walk walk;
+	u8 buf[AES_BLOCK_SIZE];
 	int err;
 
 	err = skcipher_walk_virt(&walk, req, true);
@@ -208,12 +209,12 @@ static int ctr_encrypt(struct skcipher_request *req)
 	kernel_neon_begin();
 	while (walk.nbytes > 0) {
 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
-		bool final = (walk.total % AES_BLOCK_SIZE) != 0;
+		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
 
 		if (walk.nbytes < walk.total) {
 			blocks = round_down(blocks,
 					    walk.stride / AES_BLOCK_SIZE);
-			final = false;
+			final = NULL;
 		}
 
 		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
@@ -225,7 +226,7 @@ static int ctr_encrypt(struct skcipher_request *req)
 
 			if (dst != src)
 				memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
-			crypto_xor(dst, walk.iv, walk.total % AES_BLOCK_SIZE);
+			crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);
 
 			err = skcipher_walk_done(&walk, 0);
 			break;
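The hunk above ends inside the `if (final)` branch, where the keystream block that the assembly wrote into `buf` is folded into the sub-block tail. A standalone sketch of that pattern, assuming a placeholder `ctr_keystream()` in place of the real `aesbs_ctr_encrypt()` call and a plain XOR loop in place of `crypto_xor()`:

#include <stdio.h>
#include <string.h>

#define AES_BLOCK_SIZE 16

/* Placeholder for the final keystream block returned via final[]. */
static void ctr_keystream(unsigned char ks[AES_BLOCK_SIZE])
{
	memset(ks, 0xAA, AES_BLOCK_SIZE);	/* dummy keystream */
}

int main(void)
{
	unsigned char src[5] = "tail";		/* 5-byte partial block */
	unsigned char dst[sizeof(src)];
	unsigned char buf[AES_BLOCK_SIZE];	/* scratch, like 'buf' above */
	unsigned int tail = sizeof(src);	/* walk.total % AES_BLOCK_SIZE */

	ctr_keystream(buf);	/* the asm stores the keystream to final[] */

	/* memcpy + XOR over the tail, as in the hunk above */
	memcpy(dst, src, tail);
	for (unsigned int i = 0; i < tail; i++)
		dst[i] ^= buf[i];

	printf("first ct byte: %02x\n", dst[0]);
	return 0;
}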