crypto: algif_aead - Require setkey before accept(2)

Some cipher implementations will crash if you try to use them
without calling setkey first.  This patch adds a check so that
the accept(2) call will fail with -ENOKEY if setkey hasn't been
done on the socket yet.

Fixes: 400c40cf78da ("crypto: algif - add AEAD support")
Cc: <stable@vger.kernel.org>
Signed-off-by: Stephan Mueller <smueller@chronox.de>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/crypto/algif_aead.c b/crypto/algif_aead.c
index 5a80537..e0d55ea 100644
--- a/crypto/algif_aead.c
+++ b/crypto/algif_aead.c
@@ -44,6 +44,11 @@ struct aead_async_req {
 	char iv[];
 };
 
+struct aead_tfm {
+	struct crypto_aead *aead;
+	bool has_key;
+};
+
 struct aead_ctx {
 	struct aead_sg_list tsgl;
 	struct aead_async_rsgl first_rsgl;
@@ -723,24 +728,146 @@ static struct proto_ops algif_aead_ops = {
 	.poll		=	aead_poll,
 };
 
+static int aead_check_key(struct socket *sock)
+{
+	int err = 0;
+	struct sock *psk;
+	struct alg_sock *pask;
+	struct aead_tfm *tfm;
+	struct sock *sk = sock->sk;
+	struct alg_sock *ask = alg_sk(sk);
+
+	lock_sock(sk);
+	if (ask->refcnt)
+		goto unlock_child;
+
+	psk = ask->parent;
+	pask = alg_sk(ask->parent);
+	tfm = pask->private;
+
+	err = -ENOKEY;
+	lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
+	if (!tfm->has_key)
+		goto unlock;
+
+	if (!pask->refcnt++)
+		sock_hold(psk);
+
+	ask->refcnt = 1;
+	sock_put(psk);
+
+	err = 0;
+
+unlock:
+	release_sock(psk);
+unlock_child:
+	release_sock(sk);
+
+	return err;
+}
+
+static int aead_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
+				  size_t size)
+{
+	int err;
+
+	err = aead_check_key(sock);
+	if (err)
+		return err;
+
+	return aead_sendmsg(sock, msg, size);
+}
+
+static ssize_t aead_sendpage_nokey(struct socket *sock, struct page *page,
+				       int offset, size_t size, int flags)
+{
+	int err;
+
+	err = aead_check_key(sock);
+	if (err)
+		return err;
+
+	return aead_sendpage(sock, page, offset, size, flags);
+}
+
+static int aead_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
+				  size_t ignored, int flags)
+{
+	int err;
+
+	err = aead_check_key(sock);
+	if (err)
+		return err;
+
+	return aead_recvmsg(sock, msg, ignored, flags);
+}
+
+static struct proto_ops algif_aead_ops_nokey = {
+	.family		=	PF_ALG,
+
+	.connect	=	sock_no_connect,
+	.socketpair	=	sock_no_socketpair,
+	.getname	=	sock_no_getname,
+	.ioctl		=	sock_no_ioctl,
+	.listen		=	sock_no_listen,
+	.shutdown	=	sock_no_shutdown,
+	.getsockopt	=	sock_no_getsockopt,
+	.mmap		=	sock_no_mmap,
+	.bind		=	sock_no_bind,
+	.accept		=	sock_no_accept,
+	.setsockopt	=	sock_no_setsockopt,
+
+	.release	=	af_alg_release,
+	.sendmsg	=	aead_sendmsg_nokey,
+	.sendpage	=	aead_sendpage_nokey,
+	.recvmsg	=	aead_recvmsg_nokey,
+	.poll		=	aead_poll,
+};
+
 static void *aead_bind(const char *name, u32 type, u32 mask)
 {
-	return crypto_alloc_aead(name, type, mask);
+	struct aead_tfm *tfm;
+	struct crypto_aead *aead;
+
+	tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
+	if (!tfm)
+		return ERR_PTR(-ENOMEM);
+
+	aead = crypto_alloc_aead(name, type, mask);
+	if (IS_ERR(aead)) {
+		kfree(tfm);
+		return ERR_CAST(aead);
+	}
+
+	tfm->aead = aead;
+
+	return tfm;
 }
 
 static void aead_release(void *private)
 {
-	crypto_free_aead(private);
+	struct aead_tfm *tfm = private;
+
+	crypto_free_aead(tfm->aead);
+	kfree(tfm);
 }
 
 static int aead_setauthsize(void *private, unsigned int authsize)
 {
-	return crypto_aead_setauthsize(private, authsize);
+	struct aead_tfm *tfm = private;
+
+	return crypto_aead_setauthsize(tfm->aead, authsize);
 }
 
 static int aead_setkey(void *private, const u8 *key, unsigned int keylen)
 {
-	return crypto_aead_setkey(private, key, keylen);
+	struct aead_tfm *tfm = private;
+	int err;
+
+	err = crypto_aead_setkey(tfm->aead, key, keylen);
+	tfm->has_key = !err;
+
+	return err;
 }
 
 static void aead_sock_destruct(struct sock *sk)
@@ -757,12 +884,14 @@ static void aead_sock_destruct(struct sock *sk)
 	af_alg_release_parent(sk);
 }
 
-static int aead_accept_parent(void *private, struct sock *sk)
+static int aead_accept_parent_nokey(void *private, struct sock *sk)
 {
 	struct aead_ctx *ctx;
 	struct alg_sock *ask = alg_sk(sk);
-	unsigned int len = sizeof(*ctx) + crypto_aead_reqsize(private);
-	unsigned int ivlen = crypto_aead_ivsize(private);
+	struct aead_tfm *tfm = private;
+	struct crypto_aead *aead = tfm->aead;
+	unsigned int len = sizeof(*ctx) + crypto_aead_reqsize(aead);
+	unsigned int ivlen = crypto_aead_ivsize(aead);
 
 	ctx = sock_kmalloc(sk, len, GFP_KERNEL);
 	if (!ctx)
@@ -789,7 +918,7 @@ static int aead_accept_parent(void *private, struct sock *sk)
 
 	ask->private = ctx;
 
-	aead_request_set_tfm(&ctx->aead_req, private);
+	aead_request_set_tfm(&ctx->aead_req, aead);
 	aead_request_set_callback(&ctx->aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 				  af_alg_complete, &ctx->completion);
 
@@ -798,13 +927,25 @@ static int aead_accept_parent(void *private, struct sock *sk)
 	return 0;
 }
 
+static int aead_accept_parent(void *private, struct sock *sk)
+{
+	struct aead_tfm *tfm = private;
+
+	if (!tfm->has_key)
+		return -ENOKEY;
+
+	return aead_accept_parent_nokey(private, sk);
+}
+
 static const struct af_alg_type algif_type_aead = {
 	.bind		=	aead_bind,
 	.release	=	aead_release,
 	.setkey		=	aead_setkey,
 	.setauthsize	=	aead_setauthsize,
 	.accept		=	aead_accept_parent,
+	.accept_nokey	=	aead_accept_parent_nokey,
 	.ops		=	&algif_aead_ops,
+	.ops_nokey	=	&algif_aead_ops_nokey,
 	.name		=	"aead",
 	.owner		=	THIS_MODULE
 };