| // SPDX-License-Identifier: GPL-2.0 |
| /* Author: Dmitry Safonov <dima@arista.com> */ |
| #include <inttypes.h> |
| #include "aolib.h" |
| |
| static union tcp_addr local_addr; |
| |
| static void __setup_lo_intf(const char *lo_intf, |
| const char *addr_str, uint8_t prefix) |
| { |
| if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1) |
| test_error("Can't convert local ip address"); |
| |
| if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix)) |
| test_error("Failed to add %s ip address", lo_intf); |
| |
| if (link_set_up(lo_intf)) |
| test_error("Failed to bring %s up", lo_intf); |
| } |
| |
/*
 * Configure @lo_intf with the loopback address of the address family the
 * test is built for: ::1/128 for IPV6_TEST builds, 127.0.0.1/8 otherwise.
 */
static void setup_lo_intf(const char *lo_intf)
{
#ifdef IPV6_TEST
	const char *loopback = "::1";
	const uint8_t plen = 128;
#else
	const char *loopback = "127.0.0.1";
	const uint8_t plen = 8;
#endif

	__setup_lo_intf(lo_intf, loopback, plen);
}
| |
| static void tcp_self_connect(const char *tst, unsigned int port, |
| bool different_keyids, bool check_restore) |
| { |
| uint64_t before_challenge_ack, after_challenge_ack; |
| uint64_t before_syn_challenge, after_syn_challenge; |
| struct tcp_ao_counters before_ao, after_ao; |
| uint64_t before_aogood, after_aogood; |
| struct netstat *ns_before, *ns_after; |
| const size_t nr_packets = 20; |
| struct tcp_ao_repair ao_img; |
| struct tcp_sock_state img; |
| sockaddr_af addr; |
| int sk; |
| |
| tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port)); |
| |
| sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); |
| if (sk < 0) |
| test_error("socket()"); |
| |
| if (different_keyids) { |
| if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| } else { |
| if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| } |
| |
| if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0) |
| test_error("bind()"); |
| |
| ns_before = netstat_read(); |
| before_aogood = netstat_get(ns_before, "TCPAOGood", NULL); |
| before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL); |
| before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL); |
| if (test_get_tcp_ao_counters(sk, &before_ao)) |
| test_error("test_get_tcp_ao_counters()"); |
| |
| if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr, |
| sizeof(addr), TEST_TIMEOUT_SEC) < 0) { |
| ns_after = netstat_read(); |
| netstat_print_diff(ns_before, ns_after); |
| test_error("failed to connect()"); |
| } |
| |
| if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) { |
| test_fail("%s: tcp connection verify failed", tst); |
| close(sk); |
| return; |
| } |
| |
| ns_after = netstat_read(); |
| after_aogood = netstat_get(ns_after, "TCPAOGood", NULL); |
| after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL); |
| after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL); |
| if (test_get_tcp_ao_counters(sk, &after_ao)) |
| test_error("test_get_tcp_ao_counters()"); |
| if (!check_restore) { |
| /* to debug: netstat_print_diff(ns_before, ns_after); */ |
| netstat_free(ns_before); |
| } |
| netstat_free(ns_after); |
| |
| if (after_aogood <= before_aogood) { |
| test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu", |
| tst, after_aogood, before_aogood); |
| close(sk); |
| return; |
| } |
| if (after_challenge_ack <= before_challenge_ack || |
| after_syn_challenge <= before_syn_challenge) { |
| /* |
| * It's also meant to test simultaneous open, so check |
| * these counters as well. |
| */ |
| test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu", |
| tst, after_challenge_ack, before_challenge_ack, |
| after_syn_challenge, before_syn_challenge); |
| close(sk); |
| return; |
| } |
| |
| if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) { |
| close(sk); |
| return; |
| } |
| |
| if (!check_restore) { |
| test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64, |
| tst, before_aogood, after_aogood); |
| close(sk); |
| return; |
| } |
| |
| test_enable_repair(sk); |
| test_sock_checkpoint(sk, &img, &addr); |
| #ifdef IPV6_TEST |
| addr.sin6_port = htons(port + 1); |
| #else |
| addr.sin_port = htons(port + 1); |
| #endif |
| test_ao_checkpoint(sk, &ao_img); |
| test_kill_sk(sk); |
| |
| sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); |
| if (sk < 0) |
| test_error("socket()"); |
| |
| test_enable_repair(sk); |
| __test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr)); |
| if (different_keyids) { |
| if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, |
| local_addr, -1, 7, 5)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, |
| local_addr, -1, 5, 7)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| } else { |
| if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, |
| local_addr, -1, 100, 100)) |
| test_error("setsockopt(TCP_AO_ADD_KEY)"); |
| } |
| test_ao_restore(sk, &ao_img); |
| test_disable_repair(sk); |
| test_sock_state_free(&img); |
| if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) { |
| test_fail("%s: tcp connection verify failed", tst); |
| close(sk); |
| return; |
| } |
| ns_after = netstat_read(); |
| after_aogood = netstat_get(ns_after, "TCPAOGood", NULL); |
| /* to debug: netstat_print_diff(ns_before, ns_after); */ |
| netstat_free(ns_before); |
| netstat_free(ns_after); |
| close(sk); |
| if (after_aogood <= before_aogood) { |
| test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu", |
| tst, after_aogood, before_aogood); |
| return; |
| } |
| test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64, |
| tst, before_aogood, after_aogood); |
| } |
| |
| static void *client_fn(void *arg) |
| { |
| unsigned int port = test_server_port; |
| |
| setup_lo_intf("lo"); |
| |
| tcp_self_connect("self-connect(same keyids)", port++, false, false); |
| tcp_self_connect("self-connect(different keyids)", port++, true, false); |
| tcp_self_connect("self-connect(restore)", port, false, true); |
| port += 2; |
| tcp_self_connect("self-connect(restore, different keyids)", port, true, true); |
| port += 2; |
| |
| return NULL; |
| } |
| |
| int main(int argc, char *argv[]) |
| { |
| test_init(4, client_fn, NULL); |
| return 0; |
| } |