1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020, Tessares SA. */ 3 /* Copyright (c) 2022, SUSE. */ 4 5 #include <test_progs.h> 6 #include "cgroup_helpers.h" 7 #include "network_helpers.h" 8 #include "mptcp_sock.skel.h" 9 10 #define NS_TEST "mptcp_ns" 11 12 #ifndef TCP_CA_NAME_MAX 13 #define TCP_CA_NAME_MAX 16 14 #endif 15 16 struct mptcp_storage { 17 __u32 invoked; 18 __u32 is_mptcp; 19 struct sock *sk; 20 __u32 token; 21 struct sock *first; 22 char ca_name[TCP_CA_NAME_MAX]; 23 }; 24 25 static int verify_tsk(int map_fd, int client_fd) 26 { 27 int err, cfd = client_fd; 28 struct mptcp_storage val; 29 30 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 31 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 32 return err; 33 34 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 35 err++; 36 37 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp")) 38 err++; 39 40 return err; 41 } 42 43 static void get_msk_ca_name(char ca_name[]) 44 { 45 size_t len; 46 int fd; 47 48 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY); 49 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control")) 50 return; 51 52 len = read(fd, ca_name, TCP_CA_NAME_MAX); 53 if (!ASSERT_GT(len, 0, "failed to read ca_name")) 54 goto err; 55 56 if (len > 0 && ca_name[len - 1] == '\n') 57 ca_name[len - 1] = '\0'; 58 59 err: 60 close(fd); 61 } 62 63 static int verify_msk(int map_fd, int client_fd, __u32 token) 64 { 65 char ca_name[TCP_CA_NAME_MAX]; 66 int err, cfd = client_fd; 67 struct mptcp_storage val; 68 69 if (!ASSERT_GT(token, 0, "invalid token")) 70 return -1; 71 72 get_msk_ca_name(ca_name); 73 74 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 75 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 76 return err; 77 78 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 79 err++; 80 81 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp")) 82 err++; 83 84 if (!ASSERT_EQ(val.token, token, "unexpected token")) 85 err++; 86 87 if (!ASSERT_EQ(val.first, val.sk, "unexpected first")) 88 err++; 89 90 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name")) 91 err++; 92 93 return err; 94 } 95 96 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) 97 { 98 int client_fd, prog_fd, map_fd, err; 99 struct mptcp_sock *sock_skel; 100 101 sock_skel = mptcp_sock__open_and_load(); 102 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) 103 return -EIO; 104 105 err = mptcp_sock__attach(sock_skel); 106 if (!ASSERT_OK(err, "skel_attach")) 107 goto out; 108 109 prog_fd = bpf_program__fd(sock_skel->progs._sockops); 110 if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) { 111 err = -EIO; 112 goto out; 113 } 114 115 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); 116 if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) { 117 err = -EIO; 118 goto out; 119 } 120 121 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 122 if (!ASSERT_OK(err, "bpf_prog_attach")) 123 goto out; 124 125 client_fd = connect_to_fd(server_fd, 0); 126 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 127 err = -EIO; 128 goto out; 129 } 130 131 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) : 132 verify_tsk(map_fd, client_fd); 133 134 close(client_fd); 135 136 out: 137 mptcp_sock__destroy(sock_skel); 138 return err; 139 } 140 141 static void test_base(void) 142 { 143 struct nstoken *nstoken = NULL; 144 int server_fd, cgroup_fd; 145 146 cgroup_fd = test__join_cgroup("/mptcp"); 147 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 148 return; 149 150 SYS(fail, "ip netns add %s", NS_TEST); 151 SYS(fail, "ip -net %s link set dev lo up", NS_TEST); 152 153 nstoken = open_netns(NS_TEST); 154 if (!ASSERT_OK_PTR(nstoken, "open_netns")) 155 goto fail; 156 157 /* without MPTCP */ 158 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 159 if (!ASSERT_GE(server_fd, 0, "start_server")) 160 goto with_mptcp; 161 162 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp"); 163 164 close(server_fd); 165 166 with_mptcp: 167 /* with MPTCP */ 168 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 169 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) 170 goto fail; 171 172 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp"); 173 174 close(server_fd); 175 176 fail: 177 if (nstoken) 178 close_netns(nstoken); 179 180 SYS_NOFAIL("ip netns del " NS_TEST " &> /dev/null"); 181 182 close(cgroup_fd); 183 } 184 185 void test_mptcp(void) 186 { 187 if (test__start_subtest("base")) 188 test_base(); 189 } 190