| #! /usr/bin/env python3 |
| # SPDX-License-Identifier: GPL-2.0 |
| |
| import argparse |
| import ctypes |
| import errno |
| import hashlib |
| import os |
| import select |
| import signal |
| import socket |
| import subprocess |
| import sys |
| import atexit |
| from pwd import getpwuid |
| from os import stat |
| from lib.py import ip |
| |
| |
# Load libc with use_errno=True so that a failing setns() call leaves its
# errno retrievable through ctypes.get_errno().
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns

# Names of the two network namespaces used by the test.
net0 = 'net0'
net1 = 'net1'

# Names of the two veth interfaces; one is moved into each namespace.
veth0 = 'veth0'
veth1 = 'veth1'
| |
# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    A forked child enters the namespace with setns(), creates the socket
    there and passes its file descriptor back over a UNIX socket pair, so
    the calling process itself never changes namespace.

    Args:
        netns: name of a namespace file under /var/run/netns.
        *args: arguments forwarded to socket.socket() (family, type and
            optionally proto).

    Returns:
        A socket.socket object that owns the descriptor received from the
        child.

    Raises:
        RuntimeError: if the child terminated without sending a descriptor
            (for example because setns() or socket creation failed).
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        status = 1
        try:
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                # setns is a raw ctypes call: it signals failure through
                # its return value and never raises on its own.
                if setns(f.fileno(), 0) != 0:
                    # errno is only recorded when libc was loaded with
                    # use_errno=True; otherwise get_errno() yields 0.
                    err = ctypes.get_errno()
                    reason = os.strerror(err) if err else 'unknown error'
                    raise OSError(f'setns({netns}) failed: {reason}')

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr, flush=True)
        finally:
            # Leave through os._exit() so the child never runs the parent's
            # exit handlers, never re-flushes inherited stdio buffers and
            # never falls through into the rest of the script.
            os._exit(status)

    # Drop our copy of the sending end: once the child is gone, recv_fds()
    # then sees end-of-file instead of blocking forever.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(
            f'failed to create socket in network namespace {netns}')

    # Adopt the received descriptor directly; socket.fromfd() would dup()
    # it and leave the original descriptor open forever.
    return socket.socket(*args, fileno=fds[0])
| |
def signal_handler(sig, frame):
    """Report that the test timed out and terminate with exit status 1.

    Takes the (signal number, stack frame) pair that the signal module
    passes to handlers; neither argument is used.
    """
    print('Test timed out')
    raise SystemExit(1)
| |
# Parse the command line arguments. We take an optional timeout, an
# optional log output folder and optional packet loss, corruption and
# duplication percentages.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# Percent strings that are later handed to the tc netem command line.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'
| |
# Create the two namespaces and the veth pair that connects them. The
# pair is created with explicit names: relying on the kernel to
# auto-assign veth0/veth1 would, on a host that already has interfaces
# with those names, move unrelated interfaces into the test namespaces.
ip(f"netns add {net0}")
ip(f"netns add {net1}")
ip(f"link add {veth0} type veth peer name {veth1}")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
ip(f"link set {veth0} netns {net0} up")
ip(f"link set {veth1} netns {net1} up")

# add addresses
ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}")
ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}")

# add routes
ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}")
ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")
| |
# Start a packet capture on each network. Each capture runs in a forked
# child, which stays inside the ip() call for as long as tcpdump runs.
for net in (net0, net1):
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        # parent: carry on with the next namespace
        continue

    pcap = f'{logdir}/{net}.pcap'
    # Create the capture file up front and pass its owner to tcpdump -Z.
    subprocess.check_call(['touch', pcap])
    user = getpwuid(stat(pcap).st_uid).pw_name
    ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}")
    sys.exit(0)

# simulate packet loss, duplication and corruption
for net, iface in zip((net0, net1), (veth0, veth1)):
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem "
       f"corrupt {packet_corruption} loss {packet_loss} duplicate "
       f"{packet_duplicate}")
| |
# add a timeout
if args.timeout > 0:
    # Install the handler before arming the alarm, so SIGALRM can never
    # arrive while the default (terminate) action is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)
| |
# One RDS socket per namespace, in the same order as addrs.
sockets = [
    netns_socket(ns, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for ns in (net0, net1)
]

# Bind each socket to its address and make it non-blocking.
for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(False)

# Lookup tables used by the send/receive loop.
fileno_to_socket = {}
for s in sockets:
    fileno_to_socket[s.fileno()] = s

addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running digests of everything sent/received, keyed by
# (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

# Number of packets to send, and counters of packets sent/received so far.
n = 50000
nr_send = 0
nr_recv = 0
| |
# Alternate between a send phase and a receive phase until all n packets
# have been sent.
while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # The payload is the hex SHA-256 of the packet's sequence label.
        label = f'packet {nr_send}'.encode('utf-8')
        send_data = hashlib.sha256(label).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
            pair = (sender.fileno(), receiver.fileno())
            send_hashes.setdefault(pair, hashlib.sha256()).update(
                f'<{send_data}>'.encode('utf-8'))
            nr_send += 1
        except BlockingIOError:
            break
        except OSError as e:
            # These errors only end the current send phase; anything
            # else is fatal.
            if e.errno not in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                raise
            break

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if not eventmask & select.EPOLLRDNORM:
                continue

            # Drain this socket until it would block.
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break

                sender = addr_to_socket[address]
                pair = (sender.fileno(), receiver.fileno())
                recv_hashes.setdefault(pair, hashlib.sha256()).update(
                    f'<{recv_data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in (net0, net1):
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000")
        ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000")

print("done", nr_send, nr_recv)
| |
# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Query every RDS_INFO_* option on every socket and count the outcomes.
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            # Every failure, ENOSPC included, is only counted and
            # otherwise ignored.
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")
| |
print("Stopping network packet captures")
subprocess.check_call(['killall', '-q', 'tcpdump'])

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent: every (sender, receiver) pair that sent
# data must have a receive digest identical to its send digest.
for (sender, receiver), send_hash in send_hashes.items():
    if (sender, receiver) not in recv_hashes:
        print("FAIL: No data received")
        sys.exit(1)

    expected = send_hash.hexdigest()
    received = recv_hashes[(sender, receiver)].hexdigest()

    if expected != received:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected)
        print("hash received:", received)
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)