1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020, Tessares SA. */ 3 /* Copyright (c) 2022, SUSE. */ 4 5 #include <linux/const.h> 6 #include <netinet/in.h> 7 #include <test_progs.h> 8 #include "cgroup_helpers.h" 9 #include "network_helpers.h" 10 #include "mptcp_sock.skel.h" 11 #include "mptcpify.skel.h" 12 13 #define NS_TEST "mptcp_ns" 14 15 #ifndef IPPROTO_MPTCP 16 #define IPPROTO_MPTCP 262 17 #endif 18 19 #ifndef SOL_MPTCP 20 #define SOL_MPTCP 284 21 #endif 22 #ifndef MPTCP_INFO 23 #define MPTCP_INFO 1 24 #endif 25 #ifndef MPTCP_INFO_FLAG_FALLBACK 26 #define MPTCP_INFO_FLAG_FALLBACK _BITUL(0) 27 #endif 28 #ifndef MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED 29 #define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) 30 #endif 31 32 #ifndef TCP_CA_NAME_MAX 33 #define TCP_CA_NAME_MAX 16 34 #endif 35 36 struct __mptcp_info { 37 __u8 mptcpi_subflows; 38 __u8 mptcpi_add_addr_signal; 39 __u8 mptcpi_add_addr_accepted; 40 __u8 mptcpi_subflows_max; 41 __u8 mptcpi_add_addr_signal_max; 42 __u8 mptcpi_add_addr_accepted_max; 43 __u32 mptcpi_flags; 44 __u32 mptcpi_token; 45 __u64 mptcpi_write_seq; 46 __u64 mptcpi_snd_una; 47 __u64 mptcpi_rcv_nxt; 48 __u8 mptcpi_local_addr_used; 49 __u8 mptcpi_local_addr_max; 50 __u8 mptcpi_csum_enabled; 51 __u32 mptcpi_retransmits; 52 __u64 mptcpi_bytes_retrans; 53 __u64 mptcpi_bytes_sent; 54 __u64 mptcpi_bytes_received; 55 __u64 mptcpi_bytes_acked; 56 }; 57 58 struct mptcp_storage { 59 __u32 invoked; 60 __u32 is_mptcp; 61 struct sock *sk; 62 __u32 token; 63 struct sock *first; 64 char ca_name[TCP_CA_NAME_MAX]; 65 }; 66 67 static struct nstoken *create_netns(void) 68 { 69 SYS(fail, "ip netns add %s", NS_TEST); 70 SYS(fail, "ip -net %s link set dev lo up", NS_TEST); 71 72 return open_netns(NS_TEST); 73 fail: 74 return NULL; 75 } 76 77 static void cleanup_netns(struct nstoken *nstoken) 78 { 79 if (nstoken) 80 close_netns(nstoken); 81 82 SYS_NOFAIL("ip netns del %s &> /dev/null", NS_TEST); 83 } 84 85 static int verify_tsk(int map_fd, int client_fd) 86 { 87 int err, cfd = client_fd; 88 struct mptcp_storage val; 89 90 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 91 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 92 return err; 93 94 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 95 err++; 96 97 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp")) 98 err++; 99 100 return err; 101 } 102 103 static void get_msk_ca_name(char ca_name[]) 104 { 105 size_t len; 106 int fd; 107 108 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY); 109 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control")) 110 return; 111 112 len = read(fd, ca_name, TCP_CA_NAME_MAX); 113 if (!ASSERT_GT(len, 0, "failed to read ca_name")) 114 goto err; 115 116 if (len > 0 && ca_name[len - 1] == '\n') 117 ca_name[len - 1] = '\0'; 118 119 err: 120 close(fd); 121 } 122 123 static int verify_msk(int map_fd, int client_fd, __u32 token) 124 { 125 char ca_name[TCP_CA_NAME_MAX]; 126 int err, cfd = client_fd; 127 struct mptcp_storage val; 128 129 if (!ASSERT_GT(token, 0, "invalid token")) 130 return -1; 131 132 get_msk_ca_name(ca_name); 133 134 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 135 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 136 return err; 137 138 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 139 err++; 140 141 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp")) 142 err++; 143 144 if (!ASSERT_EQ(val.token, token, "unexpected token")) 145 err++; 146 147 if (!ASSERT_EQ(val.first, val.sk, "unexpected first")) 148 err++; 149 150 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name")) 151 err++; 152 153 return err; 154 } 155 156 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) 157 { 158 int client_fd, prog_fd, map_fd, err; 159 struct mptcp_sock *sock_skel; 160 161 sock_skel = mptcp_sock__open_and_load(); 162 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) 163 return libbpf_get_error(sock_skel); 164 165 err = mptcp_sock__attach(sock_skel); 166 if (!ASSERT_OK(err, "skel_attach")) 167 goto out; 168 169 prog_fd = bpf_program__fd(sock_skel->progs._sockops); 170 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); 171 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 172 if (!ASSERT_OK(err, "bpf_prog_attach")) 173 goto out; 174 175 client_fd = connect_to_fd(server_fd, 0); 176 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 177 err = -EIO; 178 goto out; 179 } 180 181 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) : 182 verify_tsk(map_fd, client_fd); 183 184 close(client_fd); 185 186 out: 187 mptcp_sock__destroy(sock_skel); 188 return err; 189 } 190 191 static void test_base(void) 192 { 193 struct nstoken *nstoken = NULL; 194 int server_fd, cgroup_fd; 195 196 cgroup_fd = test__join_cgroup("/mptcp"); 197 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 198 return; 199 200 nstoken = create_netns(); 201 if (!ASSERT_OK_PTR(nstoken, "create_netns")) 202 goto fail; 203 204 /* without MPTCP */ 205 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 206 if (!ASSERT_GE(server_fd, 0, "start_server")) 207 goto with_mptcp; 208 209 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp"); 210 211 close(server_fd); 212 213 with_mptcp: 214 /* with MPTCP */ 215 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 216 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) 217 goto fail; 218 219 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp"); 220 221 close(server_fd); 222 223 fail: 224 cleanup_netns(nstoken); 225 close(cgroup_fd); 226 } 227 228 static void send_byte(int fd) 229 { 230 char b = 0x55; 231 232 ASSERT_EQ(write(fd, &b, sizeof(b)), 1, "send single byte"); 233 } 234 235 static int verify_mptcpify(int server_fd, int client_fd) 236 { 237 struct __mptcp_info info; 238 socklen_t optlen; 239 int protocol; 240 int err = 0; 241 242 optlen = sizeof(protocol); 243 if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen), 244 "getsockopt(SOL_PROTOCOL)")) 245 return -1; 246 247 if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP")) 248 err++; 249 250 optlen = sizeof(info); 251 if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen), 252 "getsockopt(MPTCP_INFO)")) 253 return -1; 254 255 if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags")) 256 err++; 257 if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK, 258 "MPTCP fallback")) 259 err++; 260 if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED, 261 "no remote key received")) 262 err++; 263 264 return err; 265 } 266 267 static int run_mptcpify(int cgroup_fd) 268 { 269 int server_fd, client_fd, err = 0; 270 struct mptcpify *mptcpify_skel; 271 272 mptcpify_skel = mptcpify__open_and_load(); 273 if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load")) 274 return libbpf_get_error(mptcpify_skel); 275 276 err = mptcpify__attach(mptcpify_skel); 277 if (!ASSERT_OK(err, "skel_attach")) 278 goto out; 279 280 /* without MPTCP */ 281 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 282 if (!ASSERT_GE(server_fd, 0, "start_server")) { 283 err = -EIO; 284 goto out; 285 } 286 287 client_fd = connect_to_fd(server_fd, 0); 288 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 289 err = -EIO; 290 goto close_server; 291 } 292 293 send_byte(client_fd); 294 295 err = verify_mptcpify(server_fd, client_fd); 296 297 close(client_fd); 298 close_server: 299 close(server_fd); 300 out: 301 mptcpify__destroy(mptcpify_skel); 302 return err; 303 } 304 305 static void test_mptcpify(void) 306 { 307 struct nstoken *nstoken = NULL; 308 int cgroup_fd; 309 310 cgroup_fd = test__join_cgroup("/mptcpify"); 311 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 312 return; 313 314 nstoken = create_netns(); 315 if (!ASSERT_OK_PTR(nstoken, "create_netns")) 316 goto fail; 317 318 ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify"); 319 320 fail: 321 cleanup_netns(nstoken); 322 close(cgroup_fd); 323 } 324 325 void test_mptcp(void) 326 { 327 if (test__start_subtest("base")) 328 test_base(); 329 if (test__start_subtest("mptcpify")) 330 test_mptcpify(); 331 } 332