crypto: sha3-generic - simplify code

In preparation for exposing the generic SHA3 implementation to other
versions as a fallback, simplify the code, and remove an inconsistency
in the output handling (endian swabbing rsizw words of state before
writing the output does not make sense).

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/crypto/sha3_generic.c b/crypto/sha3_generic.c
index 5fecb60..c7084a2 100644
--- a/crypto/sha3_generic.c
+++ b/crypto/sha3_generic.c
@@ -18,7 +18,6 @@
 #include <linux/module.h>
 #include <linux/types.h>
 #include <crypto/sha3.h>
-#include <asm/byteorder.h>
 #include <asm/unaligned.h>
 
 #define KECCAK_ROUNDS 24
@@ -146,43 +145,16 @@ static void __attribute__((__optimize__("O3"))) keccakf(u64 st[25])
 	}
 }
 
-static void sha3_init(struct sha3_state *sctx, unsigned int digest_sz)
+static int sha3_init(struct shash_desc *desc)
 {
-	memset(sctx, 0, sizeof(*sctx));
-	sctx->md_len = digest_sz;
-	sctx->rsiz = 200 - 2 * digest_sz;
+	struct sha3_state *sctx = shash_desc_ctx(desc);
+	unsigned int digest_size = crypto_shash_digestsize(desc->tfm);
+
+	sctx->rsiz = 200 - 2 * digest_size;
 	sctx->rsizw = sctx->rsiz / 8;
-}
+	sctx->partial = 0;
 
-static int sha3_224_init(struct shash_desc *desc)
-{
-	struct sha3_state *sctx = shash_desc_ctx(desc);
-
-	sha3_init(sctx, SHA3_224_DIGEST_SIZE);
-	return 0;
-}
-
-static int sha3_256_init(struct shash_desc *desc)
-{
-	struct sha3_state *sctx = shash_desc_ctx(desc);
-
-	sha3_init(sctx, SHA3_256_DIGEST_SIZE);
-	return 0;
-}
-
-static int sha3_384_init(struct shash_desc *desc)
-{
-	struct sha3_state *sctx = shash_desc_ctx(desc);
-
-	sha3_init(sctx, SHA3_384_DIGEST_SIZE);
-	return 0;
-}
-
-static int sha3_512_init(struct shash_desc *desc)
-{
-	struct sha3_state *sctx = shash_desc_ctx(desc);
-
-	sha3_init(sctx, SHA3_512_DIGEST_SIZE);
+	memset(sctx->st, 0, sizeof(sctx->st));
 	return 0;
 }
 
@@ -227,6 +199,8 @@ static int sha3_final(struct shash_desc *desc, u8 *out)
 {
 	struct sha3_state *sctx = shash_desc_ctx(desc);
 	unsigned int i, inlen = sctx->partial;
+	unsigned int digest_size = crypto_shash_digestsize(desc->tfm);
+	__le64 *digest = (__le64 *)out;
 
 	sctx->buf[inlen++] = 0x06;
 	memset(sctx->buf + inlen, 0, sctx->rsiz - inlen);
@@ -237,110 +211,70 @@ static int sha3_final(struct shash_desc *desc, u8 *out)
 
 	keccakf(sctx->st);
 
-	for (i = 0; i < sctx->rsizw; i++)
-		sctx->st[i] = cpu_to_le64(sctx->st[i]);
+	for (i = 0; i < digest_size / 8; i++)
+		put_unaligned_le64(sctx->st[i], digest++);
 
-	memcpy(out, sctx->st, sctx->md_len);
+	if (digest_size & 4)
+		put_unaligned_le32(sctx->st[i], (__le32 *)digest);
 
 	memset(sctx, 0, sizeof(*sctx));
 	return 0;
 }
 
-static struct shash_alg sha3_224 = {
-	.digestsize	=	SHA3_224_DIGEST_SIZE,
-	.init		=	sha3_224_init,
-	.update		=	sha3_update,
-	.final		=	sha3_final,
-	.descsize	=	sizeof(struct sha3_state),
-	.base		=	{
-		.cra_name	=	"sha3-224",
-		.cra_driver_name =	"sha3-224-generic",
-		.cra_flags	=	CRYPTO_ALG_TYPE_SHASH,
-		.cra_blocksize	=	SHA3_224_BLOCK_SIZE,
-		.cra_module	=	THIS_MODULE,
-	}
-};
-
-static struct shash_alg sha3_256 = {
-	.digestsize	=	SHA3_256_DIGEST_SIZE,
-	.init		=	sha3_256_init,
-	.update		=	sha3_update,
-	.final		=	sha3_final,
-	.descsize	=	sizeof(struct sha3_state),
-	.base		=	{
-		.cra_name	=	"sha3-256",
-		.cra_driver_name =	"sha3-256-generic",
-		.cra_flags	=	CRYPTO_ALG_TYPE_SHASH,
-		.cra_blocksize	=	SHA3_256_BLOCK_SIZE,
-		.cra_module	=	THIS_MODULE,
-	}
-};
-
-static struct shash_alg sha3_384 = {
-	.digestsize	=	SHA3_384_DIGEST_SIZE,
-	.init		=	sha3_384_init,
-	.update		=	sha3_update,
-	.final		=	sha3_final,
-	.descsize	=	sizeof(struct sha3_state),
-	.base		=	{
-		.cra_name	=	"sha3-384",
-		.cra_driver_name =	"sha3-384-generic",
-		.cra_flags	=	CRYPTO_ALG_TYPE_SHASH,
-		.cra_blocksize	=	SHA3_384_BLOCK_SIZE,
-		.cra_module	=	THIS_MODULE,
-	}
-};
-
-static struct shash_alg sha3_512 = {
-	.digestsize	=	SHA3_512_DIGEST_SIZE,
-	.init		=	sha3_512_init,
-	.update		=	sha3_update,
-	.final		=	sha3_final,
-	.descsize	=	sizeof(struct sha3_state),
-	.base		=	{
-		.cra_name	=	"sha3-512",
-		.cra_driver_name =	"sha3-512-generic",
-		.cra_flags	=	CRYPTO_ALG_TYPE_SHASH,
-		.cra_blocksize	=	SHA3_512_BLOCK_SIZE,
-		.cra_module	=	THIS_MODULE,
-	}
-};
+static struct shash_alg algs[] = { {
+	.digestsize		= SHA3_224_DIGEST_SIZE,
+	.init			= sha3_init,
+	.update			= sha3_update,
+	.final			= sha3_final,
+	.descsize		= sizeof(struct sha3_state),
+	.base.cra_name		= "sha3-224",
+	.base.cra_driver_name	= "sha3-224-generic",
+	.base.cra_flags		= CRYPTO_ALG_TYPE_SHASH,
+	.base.cra_blocksize	= SHA3_224_BLOCK_SIZE,
+	.base.cra_module	= THIS_MODULE,
+}, {
+	.digestsize		= SHA3_256_DIGEST_SIZE,
+	.init			= sha3_init,
+	.update			= sha3_update,
+	.final			= sha3_final,
+	.descsize		= sizeof(struct sha3_state),
+	.base.cra_name		= "sha3-256",
+	.base.cra_driver_name	= "sha3-256-generic",
+	.base.cra_flags		= CRYPTO_ALG_TYPE_SHASH,
+	.base.cra_blocksize	= SHA3_256_BLOCK_SIZE,
+	.base.cra_module	= THIS_MODULE,
+}, {
+	.digestsize		= SHA3_384_DIGEST_SIZE,
+	.init			= sha3_init,
+	.update			= sha3_update,
+	.final			= sha3_final,
+	.descsize		= sizeof(struct sha3_state),
+	.base.cra_name		= "sha3-384",
+	.base.cra_driver_name	= "sha3-384-generic",
+	.base.cra_flags		= CRYPTO_ALG_TYPE_SHASH,
+	.base.cra_blocksize	= SHA3_384_BLOCK_SIZE,
+	.base.cra_module	= THIS_MODULE,
+}, {
+	.digestsize		= SHA3_512_DIGEST_SIZE,
+	.init			= sha3_init,
+	.update			= sha3_update,
+	.final			= sha3_final,
+	.descsize		= sizeof(struct sha3_state),
+	.base.cra_name		= "sha3-512",
+	.base.cra_driver_name	= "sha3-512-generic",
+	.base.cra_flags		= CRYPTO_ALG_TYPE_SHASH,
+	.base.cra_blocksize	= SHA3_512_BLOCK_SIZE,
+	.base.cra_module	= THIS_MODULE,
+} };
 
 static int __init sha3_generic_mod_init(void)
 {
-	int ret;
-
-	ret = crypto_register_shash(&sha3_224);
-	if (ret < 0)
-		goto err_out;
-	ret = crypto_register_shash(&sha3_256);
-	if (ret < 0)
-		goto err_out_224;
-	ret = crypto_register_shash(&sha3_384);
-	if (ret < 0)
-		goto err_out_256;
-	ret = crypto_register_shash(&sha3_512);
-	if (ret < 0)
-		goto err_out_384;
-
-	return 0;
-
-err_out_384:
-	crypto_unregister_shash(&sha3_384);
-err_out_256:
-	crypto_unregister_shash(&sha3_256);
-err_out_224:
-	crypto_unregister_shash(&sha3_224);
-err_out:
-	return ret;
+	return crypto_register_shashes(algs, ARRAY_SIZE(algs));
 }
 
 static void __exit sha3_generic_mod_fini(void)
 {
-	crypto_unregister_shash(&sha3_224);
-	crypto_unregister_shash(&sha3_256);
-	crypto_unregister_shash(&sha3_384);
-	crypto_unregister_shash(&sha3_512);
+	crypto_unregister_shashes(algs, ARRAY_SIZE(algs));
 }
 
 module_init(sha3_generic_mod_init);