blob: 5fd6177189915b6a3a09f532d0536220de78dfb7 [file] [log] [blame]
// SPDX-License-Identifier: GPL-2.0
#include <sys/un.h>
#include "test_progs.h"
#include "connect_unix_prog.skel.h"
#include "sendmsg_unix_prog.skel.h"
#include "recvmsg_unix_prog.skel.h"
#include "getsockname_unix_prog.skel.h"
#include "getpeername_unix_prog.skel.h"
#include "network_helpers.h"
#define SERVUN_ADDRESS "bpf_cgroup_unix_test"
#define SERVUN_REWRITE_ADDRESS "bpf_cgroup_unix_test_rewrite"
#define SRCUN_ADDRESS "bpf_cgroup_unix_test_src"
enum sock_addr_test_type {
SOCK_ADDR_TEST_BIND,
SOCK_ADDR_TEST_CONNECT,
SOCK_ADDR_TEST_SENDMSG,
SOCK_ADDR_TEST_RECVMSG,
SOCK_ADDR_TEST_GETSOCKNAME,
SOCK_ADDR_TEST_GETPEERNAME,
};
typedef void *(*load_fn)(int cgroup_fd);
typedef void (*destroy_fn)(void *skel);
struct sock_addr_test {
enum sock_addr_test_type type;
const char *name;
/* BPF prog properties */
load_fn loadfn;
destroy_fn destroyfn;
/* Socket properties */
int socket_family;
int socket_type;
/* IP:port pairs for BPF prog to override */
const char *requested_addr;
unsigned short requested_port;
const char *expected_addr;
unsigned short expected_port;
const char *expected_src_addr;
};
static void *connect_unix_prog_load(int cgroup_fd)
{
struct connect_unix_prog *skel;
skel = connect_unix_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->links.connect_unix_prog = bpf_program__attach_cgroup(
skel->progs.connect_unix_prog, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.connect_unix_prog, "prog_attach"))
goto cleanup;
return skel;
cleanup:
connect_unix_prog__destroy(skel);
return NULL;
}
static void connect_unix_prog_destroy(void *skel)
{
connect_unix_prog__destroy(skel);
}
static void *sendmsg_unix_prog_load(int cgroup_fd)
{
struct sendmsg_unix_prog *skel;
skel = sendmsg_unix_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->links.sendmsg_unix_prog = bpf_program__attach_cgroup(
skel->progs.sendmsg_unix_prog, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.sendmsg_unix_prog, "prog_attach"))
goto cleanup;
return skel;
cleanup:
sendmsg_unix_prog__destroy(skel);
return NULL;
}
static void sendmsg_unix_prog_destroy(void *skel)
{
sendmsg_unix_prog__destroy(skel);
}
static void *recvmsg_unix_prog_load(int cgroup_fd)
{
struct recvmsg_unix_prog *skel;
skel = recvmsg_unix_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->links.recvmsg_unix_prog = bpf_program__attach_cgroup(
skel->progs.recvmsg_unix_prog, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.recvmsg_unix_prog, "prog_attach"))
goto cleanup;
return skel;
cleanup:
recvmsg_unix_prog__destroy(skel);
return NULL;
}
static void recvmsg_unix_prog_destroy(void *skel)
{
recvmsg_unix_prog__destroy(skel);
}
static void *getsockname_unix_prog_load(int cgroup_fd)
{
struct getsockname_unix_prog *skel;
skel = getsockname_unix_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->links.getsockname_unix_prog = bpf_program__attach_cgroup(
skel->progs.getsockname_unix_prog, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.getsockname_unix_prog, "prog_attach"))
goto cleanup;
return skel;
cleanup:
getsockname_unix_prog__destroy(skel);
return NULL;
}
static void getsockname_unix_prog_destroy(void *skel)
{
getsockname_unix_prog__destroy(skel);
}
static void *getpeername_unix_prog_load(int cgroup_fd)
{
struct getpeername_unix_prog *skel;
skel = getpeername_unix_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open"))
goto cleanup;
skel->links.getpeername_unix_prog = bpf_program__attach_cgroup(
skel->progs.getpeername_unix_prog, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.getpeername_unix_prog, "prog_attach"))
goto cleanup;
return skel;
cleanup:
getpeername_unix_prog__destroy(skel);
return NULL;
}
static void getpeername_unix_prog_destroy(void *skel)
{
getpeername_unix_prog__destroy(skel);
}
static struct sock_addr_test tests[] = {
{
SOCK_ADDR_TEST_CONNECT,
"connect_unix",
connect_unix_prog_load,
connect_unix_prog_destroy,
AF_UNIX,
SOCK_STREAM,
SERVUN_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
NULL,
},
{
SOCK_ADDR_TEST_SENDMSG,
"sendmsg_unix",
sendmsg_unix_prog_load,
sendmsg_unix_prog_destroy,
AF_UNIX,
SOCK_DGRAM,
SERVUN_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
NULL,
},
{
SOCK_ADDR_TEST_RECVMSG,
"recvmsg_unix-dgram",
recvmsg_unix_prog_load,
recvmsg_unix_prog_destroy,
AF_UNIX,
SOCK_DGRAM,
SERVUN_REWRITE_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
SERVUN_ADDRESS,
},
{
SOCK_ADDR_TEST_RECVMSG,
"recvmsg_unix-stream",
recvmsg_unix_prog_load,
recvmsg_unix_prog_destroy,
AF_UNIX,
SOCK_STREAM,
SERVUN_REWRITE_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
SERVUN_ADDRESS,
},
{
SOCK_ADDR_TEST_GETSOCKNAME,
"getsockname_unix",
getsockname_unix_prog_load,
getsockname_unix_prog_destroy,
AF_UNIX,
SOCK_STREAM,
SERVUN_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
NULL,
},
{
SOCK_ADDR_TEST_GETPEERNAME,
"getpeername_unix",
getpeername_unix_prog_load,
getpeername_unix_prog_destroy,
AF_UNIX,
SOCK_STREAM,
SERVUN_ADDRESS,
0,
SERVUN_REWRITE_ADDRESS,
0,
NULL,
},
};
typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);
static int cmp_addr(const struct sockaddr_storage *addr1, socklen_t addr1_len,
const struct sockaddr_storage *addr2, socklen_t addr2_len,
bool cmp_port)
{
const struct sockaddr_in *four1, *four2;
const struct sockaddr_in6 *six1, *six2;
const struct sockaddr_un *un1, *un2;
if (addr1->ss_family != addr2->ss_family)
return -1;
if (addr1_len != addr2_len)
return -1;
if (addr1->ss_family == AF_INET) {
four1 = (const struct sockaddr_in *)addr1;
four2 = (const struct sockaddr_in *)addr2;
return !((four1->sin_port == four2->sin_port || !cmp_port) &&
four1->sin_addr.s_addr == four2->sin_addr.s_addr);
} else if (addr1->ss_family == AF_INET6) {
six1 = (const struct sockaddr_in6 *)addr1;
six2 = (const struct sockaddr_in6 *)addr2;
return !((six1->sin6_port == six2->sin6_port || !cmp_port) &&
!memcmp(&six1->sin6_addr, &six2->sin6_addr,
sizeof(struct in6_addr)));
} else if (addr1->ss_family == AF_UNIX) {
un1 = (const struct sockaddr_un *)addr1;
un2 = (const struct sockaddr_un *)addr2;
return memcmp(un1, un2, addr1_len);
}
return -1;
}
static int cmp_sock_addr(info_fn fn, int sock1,
const struct sockaddr_storage *addr2,
socklen_t addr2_len, bool cmp_port)
{
struct sockaddr_storage addr1;
socklen_t len1 = sizeof(addr1);
memset(&addr1, 0, len1);
if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0)
return -1;
return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port);
}
static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2,
socklen_t addr2_len, bool cmp_port)
{
return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, cmp_port);
}
static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2,
socklen_t addr2_len, bool cmp_port)
{
return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, cmp_port);
}
static void test_bind(struct sock_addr_test *test)
{
struct sockaddr_storage expected_addr;
socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
int serv = -1, client = -1, err;
serv = start_server(test->socket_family, test->socket_type,
test->requested_addr, test->requested_port, 0);
if (!ASSERT_GE(serv, 0, "start_server"))
goto cleanup;
err = make_sockaddr(test->socket_family,
test->expected_addr, test->expected_port,
&expected_addr, &expected_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
goto cleanup;
/* Try to connect to server just in case */
client = connect_to_addr(&expected_addr, expected_addr_len, test->socket_type);
if (!ASSERT_GE(client, 0, "connect_to_addr"))
goto cleanup;
cleanup:
if (client != -1)
close(client);
if (serv != -1)
close(serv);
}
static void test_connect(struct sock_addr_test *test)
{
struct sockaddr_storage addr, expected_addr, expected_src_addr;
socklen_t addr_len = sizeof(struct sockaddr_storage),
expected_addr_len = sizeof(struct sockaddr_storage),
expected_src_addr_len = sizeof(struct sockaddr_storage);
int serv = -1, client = -1, err;
serv = start_server(test->socket_family, test->socket_type,
test->expected_addr, test->expected_port, 0);
if (!ASSERT_GE(serv, 0, "start_server"))
goto cleanup;
err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
&addr, &addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
client = connect_to_addr(&addr, addr_len, test->socket_type);
if (!ASSERT_GE(client, 0, "connect_to_addr"))
goto cleanup;
err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
&expected_addr, &expected_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
if (test->expected_src_addr) {
err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
&expected_src_addr, &expected_src_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
}
err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
goto cleanup;
if (test->expected_src_addr) {
err = cmp_local_addr(client, &expected_src_addr, expected_src_addr_len, false);
if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
goto cleanup;
}
cleanup:
if (client != -1)
close(client);
if (serv != -1)
close(serv);
}
static void test_xmsg(struct sock_addr_test *test)
{
struct sockaddr_storage addr, src_addr;
socklen_t addr_len = sizeof(struct sockaddr_storage),
src_addr_len = sizeof(struct sockaddr_storage);
struct msghdr hdr;
struct iovec iov;
char data = 'a';
int serv = -1, client = -1, err;
/* Unlike the other tests, here we test that we can rewrite the src addr
* with a recvmsg() hook.
*/
serv = start_server(test->socket_family, test->socket_type,
test->expected_addr, test->expected_port, 0);
if (!ASSERT_GE(serv, 0, "start_server"))
goto cleanup;
client = socket(test->socket_family, test->socket_type, 0);
if (!ASSERT_GE(client, 0, "socket"))
goto cleanup;
/* AF_UNIX sockets have to be bound to something to trigger the recvmsg bpf program. */
if (test->socket_family == AF_UNIX) {
err = make_sockaddr(AF_UNIX, SRCUN_ADDRESS, 0, &src_addr, &src_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
err = bind(client, (const struct sockaddr *) &src_addr, src_addr_len);
if (!ASSERT_OK(err, "bind"))
goto cleanup;
}
err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
&addr, &addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
if (test->socket_type == SOCK_DGRAM) {
memset(&iov, 0, sizeof(iov));
iov.iov_base = &data;
iov.iov_len = sizeof(data);
memset(&hdr, 0, sizeof(hdr));
hdr.msg_name = (void *)&addr;
hdr.msg_namelen = addr_len;
hdr.msg_iov = &iov;
hdr.msg_iovlen = 1;
err = sendmsg(client, &hdr, 0);
if (!ASSERT_EQ(err, sizeof(data), "sendmsg"))
goto cleanup;
} else {
/* Testing with connection-oriented sockets is only valid for
* recvmsg() tests.
*/
if (!ASSERT_EQ(test->type, SOCK_ADDR_TEST_RECVMSG, "recvmsg"))
goto cleanup;
err = connect(client, (const struct sockaddr *)&addr, addr_len);
if (!ASSERT_OK(err, "connect"))
goto cleanup;
err = send(client, &data, sizeof(data), 0);
if (!ASSERT_EQ(err, sizeof(data), "send"))
goto cleanup;
err = listen(serv, 0);
if (!ASSERT_OK(err, "listen"))
goto cleanup;
err = accept(serv, NULL, NULL);
if (!ASSERT_GE(err, 0, "accept"))
goto cleanup;
close(serv);
serv = err;
}
addr_len = src_addr_len = sizeof(struct sockaddr_storage);
err = recvfrom(serv, &data, sizeof(data), 0, (struct sockaddr *) &src_addr, &src_addr_len);
if (!ASSERT_EQ(err, sizeof(data), "recvfrom"))
goto cleanup;
ASSERT_EQ(data, 'a', "data mismatch");
if (test->expected_src_addr) {
err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
&addr, &addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
err = cmp_addr(&src_addr, src_addr_len, &addr, addr_len, false);
if (!ASSERT_EQ(err, 0, "cmp_addr"))
goto cleanup;
}
cleanup:
if (client != -1)
close(client);
if (serv != -1)
close(serv);
}
static void test_getsockname(struct sock_addr_test *test)
{
struct sockaddr_storage expected_addr;
socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
int serv = -1, err;
serv = start_server(test->socket_family, test->socket_type,
test->requested_addr, test->requested_port, 0);
if (!ASSERT_GE(serv, 0, "start_server"))
goto cleanup;
err = make_sockaddr(test->socket_family,
test->expected_addr, test->expected_port,
&expected_addr, &expected_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
goto cleanup;
cleanup:
if (serv != -1)
close(serv);
}
static void test_getpeername(struct sock_addr_test *test)
{
struct sockaddr_storage addr, expected_addr;
socklen_t addr_len = sizeof(struct sockaddr_storage),
expected_addr_len = sizeof(struct sockaddr_storage);
int serv = -1, client = -1, err;
serv = start_server(test->socket_family, test->socket_type,
test->requested_addr, test->requested_port, 0);
if (!ASSERT_GE(serv, 0, "start_server"))
goto cleanup;
err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
&addr, &addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
client = connect_to_addr(&addr, addr_len, test->socket_type);
if (!ASSERT_GE(client, 0, "connect_to_addr"))
goto cleanup;
err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
&expected_addr, &expected_addr_len);
if (!ASSERT_EQ(err, 0, "make_sockaddr"))
goto cleanup;
err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
goto cleanup;
cleanup:
if (client != -1)
close(client);
if (serv != -1)
close(serv);
}
void test_sock_addr(void)
{
int cgroup_fd = -1;
void *skel;
cgroup_fd = test__join_cgroup("/sock_addr");
if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup"))
goto cleanup;
for (size_t i = 0; i < ARRAY_SIZE(tests); ++i) {
struct sock_addr_test *test = &tests[i];
if (!test__start_subtest(test->name))
continue;
skel = test->loadfn(cgroup_fd);
if (!skel)
continue;
switch (test->type) {
/* Not exercised yet but we leave this code here for when the
* INET and INET6 sockaddr tests are migrated to this file in
* the future.
*/
case SOCK_ADDR_TEST_BIND:
test_bind(test);
break;
case SOCK_ADDR_TEST_CONNECT:
test_connect(test);
break;
case SOCK_ADDR_TEST_SENDMSG:
case SOCK_ADDR_TEST_RECVMSG:
test_xmsg(test);
break;
case SOCK_ADDR_TEST_GETSOCKNAME:
test_getsockname(test);
break;
case SOCK_ADDR_TEST_GETPEERNAME:
test_getpeername(test);
break;
default:
ASSERT_TRUE(false, "Unknown sock addr test type");
break;
}
test->destroyfn(skel);
}
cleanup:
if (cgroup_fd >= 0)
close(cgroup_fd);
}