blob: 6d37bab35c8fcc42e77d1fc081a87d61658751c3 [file] [log] [blame]
// SPDX-License-Identifier: GPL-2.0
/*
* Copyright (c) 2023 Oracle and/or its affiliates.
*
* KUnit test of the handshake upcall mechanism.
*/
#include <kunit/test.h>
#include <kunit/visibility.h>
#include <linux/kernel.h>
#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>
#include <uapi/linux/handshake.h>
#include "handshake.h"
MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
static int test_accept_func(struct handshake_req *req, struct genl_info *info,
int fd)
{
return 0;
}
static void test_done_func(struct handshake_req *req, unsigned int status,
struct genl_info *info)
{
}
struct handshake_req_alloc_test_param {
const char *desc;
struct handshake_proto *proto;
gfp_t gfp;
bool expect_success;
};
static struct handshake_proto handshake_req_alloc_proto_2 = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_NONE,
};
static struct handshake_proto handshake_req_alloc_proto_3 = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_MAX,
};
static struct handshake_proto handshake_req_alloc_proto_4 = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
};
static struct handshake_proto handshake_req_alloc_proto_5 = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
.hp_accept = test_accept_func,
};
static struct handshake_proto handshake_req_alloc_proto_6 = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
.hp_privsize = UINT_MAX,
.hp_accept = test_accept_func,
.hp_done = test_done_func,
};
static struct handshake_proto handshake_req_alloc_proto_good = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
.hp_accept = test_accept_func,
.hp_done = test_done_func,
};
static const
struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
{
.desc = "handshake_req_alloc NULL proto",
.proto = NULL,
.gfp = GFP_KERNEL,
.expect_success = false,
},
{
.desc = "handshake_req_alloc CLASS_NONE",
.proto = &handshake_req_alloc_proto_2,
.gfp = GFP_KERNEL,
.expect_success = false,
},
{
.desc = "handshake_req_alloc CLASS_MAX",
.proto = &handshake_req_alloc_proto_3,
.gfp = GFP_KERNEL,
.expect_success = false,
},
{
.desc = "handshake_req_alloc no callbacks",
.proto = &handshake_req_alloc_proto_4,
.gfp = GFP_KERNEL,
.expect_success = false,
},
{
.desc = "handshake_req_alloc no done callback",
.proto = &handshake_req_alloc_proto_5,
.gfp = GFP_KERNEL,
.expect_success = false,
},
{
.desc = "handshake_req_alloc excessive privsize",
.proto = &handshake_req_alloc_proto_6,
.gfp = GFP_KERNEL | __GFP_NOWARN,
.expect_success = false,
},
{
.desc = "handshake_req_alloc all good",
.proto = &handshake_req_alloc_proto_good,
.gfp = GFP_KERNEL,
.expect_success = true,
},
};
static void
handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
char *desc)
{
strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
}
/* Creates the function handshake_req_alloc_gen_params */
KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
handshake_req_alloc_get_desc);
static void handshake_req_alloc_case(struct kunit *test)
{
const struct handshake_req_alloc_test_param *param = test->param_value;
struct handshake_req *result;
/* Arrange */
/* Act */
result = handshake_req_alloc(param->proto, param->gfp);
/* Assert */
if (param->expect_success)
KUNIT_EXPECT_NOT_NULL(test, result);
else
KUNIT_EXPECT_NULL(test, result);
kfree(result);
}
static void handshake_req_submit_test1(struct kunit *test)
{
struct socket *sock;
int err, result;
/* Arrange */
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
/* Act */
result = handshake_req_submit(sock, NULL, GFP_KERNEL);
/* Assert */
KUNIT_EXPECT_EQ(test, result, -EINVAL);
sock_release(sock);
}
static void handshake_req_submit_test2(struct kunit *test)
{
struct handshake_req *req;
int result;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
/* Act */
result = handshake_req_submit(NULL, req, GFP_KERNEL);
/* Assert */
KUNIT_EXPECT_EQ(test, result, -EINVAL);
/* handshake_req_submit() destroys @req on error */
}
static void handshake_req_submit_test3(struct kunit *test)
{
struct handshake_req *req;
struct socket *sock;
int err, result;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
sock->file = NULL;
/* Act */
result = handshake_req_submit(sock, req, GFP_KERNEL);
/* Assert */
KUNIT_EXPECT_EQ(test, result, -EINVAL);
/* handshake_req_submit() destroys @req on error */
sock_release(sock);
}
static void handshake_req_submit_test4(struct kunit *test)
{
struct handshake_req *req, *result;
struct socket *sock;
struct file *filp;
int err;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
KUNIT_ASSERT_NOT_NULL(test, sock->sk);
sock->file = filp;
err = handshake_req_submit(sock, req, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
/* Act */
result = handshake_req_hash_lookup(sock->sk);
/* Assert */
KUNIT_EXPECT_NOT_NULL(test, result);
KUNIT_EXPECT_PTR_EQ(test, req, result);
handshake_req_cancel(sock->sk);
sock_release(sock);
}
static void handshake_req_submit_test5(struct kunit *test)
{
struct handshake_req *req;
struct handshake_net *hn;
struct socket *sock;
struct file *filp;
struct net *net;
int saved, err;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
KUNIT_ASSERT_NOT_NULL(test, sock->sk);
sock->file = filp;
net = sock_net(sock->sk);
hn = handshake_pernet(net);
KUNIT_ASSERT_NOT_NULL(test, hn);
saved = hn->hn_pending;
hn->hn_pending = hn->hn_pending_max + 1;
/* Act */
err = handshake_req_submit(sock, req, GFP_KERNEL);
/* Assert */
KUNIT_EXPECT_EQ(test, err, -EAGAIN);
sock_release(sock);
hn->hn_pending = saved;
}
static void handshake_req_submit_test6(struct kunit *test)
{
struct handshake_req *req1, *req2;
struct socket *sock;
struct file *filp;
int err;
/* Arrange */
req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req1);
req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req2);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
KUNIT_ASSERT_NOT_NULL(test, sock->sk);
sock->file = filp;
/* Act */
err = handshake_req_submit(sock, req1, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
err = handshake_req_submit(sock, req2, GFP_KERNEL);
/* Assert */
KUNIT_EXPECT_EQ(test, err, -EBUSY);
handshake_req_cancel(sock->sk);
sock_release(sock);
}
static void handshake_req_cancel_test1(struct kunit *test)
{
struct handshake_req *req;
struct socket *sock;
struct file *filp;
bool result;
int err;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
sock->file = filp;
err = handshake_req_submit(sock, req, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
/* NB: handshake_req hasn't been accepted */
/* Act */
result = handshake_req_cancel(sock->sk);
/* Assert */
KUNIT_EXPECT_TRUE(test, result);
sock_release(sock);
}
static void handshake_req_cancel_test2(struct kunit *test)
{
struct handshake_req *req, *next;
struct handshake_net *hn;
struct socket *sock;
struct file *filp;
struct net *net;
bool result;
int err;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
sock->file = filp;
err = handshake_req_submit(sock, req, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
net = sock_net(sock->sk);
hn = handshake_pernet(net);
KUNIT_ASSERT_NOT_NULL(test, hn);
/* Pretend to accept this request */
next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
KUNIT_ASSERT_PTR_EQ(test, req, next);
/* Act */
result = handshake_req_cancel(sock->sk);
/* Assert */
KUNIT_EXPECT_TRUE(test, result);
sock_release(sock);
}
static void handshake_req_cancel_test3(struct kunit *test)
{
struct handshake_req *req, *next;
struct handshake_net *hn;
struct socket *sock;
struct file *filp;
struct net *net;
bool result;
int err;
/* Arrange */
req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
sock->file = filp;
err = handshake_req_submit(sock, req, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
net = sock_net(sock->sk);
hn = handshake_pernet(net);
KUNIT_ASSERT_NOT_NULL(test, hn);
/* Pretend to accept this request */
next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
KUNIT_ASSERT_PTR_EQ(test, req, next);
/* Pretend to complete this request */
handshake_complete(next, -ETIMEDOUT, NULL);
/* Act */
result = handshake_req_cancel(sock->sk);
/* Assert */
KUNIT_EXPECT_FALSE(test, result);
sock_release(sock);
}
static struct handshake_req *handshake_req_destroy_test;
static void test_destroy_func(struct handshake_req *req)
{
handshake_req_destroy_test = req;
}
static struct handshake_proto handshake_req_alloc_proto_destroy = {
.hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
.hp_accept = test_accept_func,
.hp_done = test_done_func,
.hp_destroy = test_destroy_func,
};
static void handshake_req_destroy_test1(struct kunit *test)
{
struct handshake_req *req;
struct socket *sock;
struct file *filp;
int err;
/* Arrange */
handshake_req_destroy_test = NULL;
req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
KUNIT_ASSERT_NOT_NULL(test, req);
err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
&sock, 1);
KUNIT_ASSERT_EQ(test, err, 0);
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
sock->file = filp;
err = handshake_req_submit(sock, req, GFP_KERNEL);
KUNIT_ASSERT_EQ(test, err, 0);
handshake_req_cancel(sock->sk);
/* Act */
sock_release(sock);
/* Assert */
KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
}
static struct kunit_case handshake_api_test_cases[] = {
{
.name = "req_alloc API fuzzing",
.run_case = handshake_req_alloc_case,
.generate_params = handshake_req_alloc_gen_params,
},
{
.name = "req_submit NULL req arg",
.run_case = handshake_req_submit_test1,
},
{
.name = "req_submit NULL sock arg",
.run_case = handshake_req_submit_test2,
},
{
.name = "req_submit NULL sock->file",
.run_case = handshake_req_submit_test3,
},
{
.name = "req_lookup works",
.run_case = handshake_req_submit_test4,
},
{
.name = "req_submit max pending",
.run_case = handshake_req_submit_test5,
},
{
.name = "req_submit multiple",
.run_case = handshake_req_submit_test6,
},
{
.name = "req_cancel before accept",
.run_case = handshake_req_cancel_test1,
},
{
.name = "req_cancel after accept",
.run_case = handshake_req_cancel_test2,
},
{
.name = "req_cancel after done",
.run_case = handshake_req_cancel_test3,
},
{
.name = "req_destroy works",
.run_case = handshake_req_destroy_test1,
},
{}
};
static struct kunit_suite handshake_api_suite = {
.name = "Handshake API tests",
.test_cases = handshake_api_test_cases,
};
kunit_test_suites(&handshake_api_suite);
MODULE_DESCRIPTION("Test handshake upcall API functions");
MODULE_LICENSE("GPL");