1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020, Tessares SA. */ 3 /* Copyright (c) 2022, SUSE. */ 4 5 #include <test_progs.h> 6 #include "cgroup_helpers.h" 7 #include "network_helpers.h" 8 #include "mptcp_sock.skel.h" 9 10 #ifndef TCP_CA_NAME_MAX 11 #define TCP_CA_NAME_MAX 16 12 #endif 13 14 struct mptcp_storage { 15 __u32 invoked; 16 __u32 is_mptcp; 17 struct sock *sk; 18 __u32 token; 19 struct sock *first; 20 char ca_name[TCP_CA_NAME_MAX]; 21 }; 22 23 static int verify_tsk(int map_fd, int client_fd) 24 { 25 int err, cfd = client_fd; 26 struct mptcp_storage val; 27 28 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 29 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 30 return err; 31 32 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 33 err++; 34 35 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp")) 36 err++; 37 38 return err; 39 } 40 41 static void get_msk_ca_name(char ca_name[]) 42 { 43 size_t len; 44 int fd; 45 46 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY); 47 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control")) 48 return; 49 50 len = read(fd, ca_name, TCP_CA_NAME_MAX); 51 if (!ASSERT_GT(len, 0, "failed to read ca_name")) 52 goto err; 53 54 if (len > 0 && ca_name[len - 1] == '\n') 55 ca_name[len - 1] = '\0'; 56 57 err: 58 close(fd); 59 } 60 61 static int verify_msk(int map_fd, int client_fd, __u32 token) 62 { 63 char ca_name[TCP_CA_NAME_MAX]; 64 int err, cfd = client_fd; 65 struct mptcp_storage val; 66 67 if (!ASSERT_GT(token, 0, "invalid token")) 68 return -1; 69 70 get_msk_ca_name(ca_name); 71 72 err = bpf_map_lookup_elem(map_fd, &cfd, &val); 73 if (!ASSERT_OK(err, "bpf_map_lookup_elem")) 74 return err; 75 76 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count")) 77 err++; 78 79 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp")) 80 err++; 81 82 if (!ASSERT_EQ(val.token, token, "unexpected token")) 83 err++; 84 85 if (!ASSERT_EQ(val.first, val.sk, "unexpected first")) 86 err++; 87 88 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name")) 89 err++; 90 91 return err; 92 } 93 94 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp) 95 { 96 int client_fd, prog_fd, map_fd, err; 97 struct mptcp_sock *sock_skel; 98 99 sock_skel = mptcp_sock__open_and_load(); 100 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load")) 101 return -EIO; 102 103 err = mptcp_sock__attach(sock_skel); 104 if (!ASSERT_OK(err, "skel_attach")) 105 goto out; 106 107 prog_fd = bpf_program__fd(sock_skel->progs._sockops); 108 if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) { 109 err = -EIO; 110 goto out; 111 } 112 113 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map); 114 if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) { 115 err = -EIO; 116 goto out; 117 } 118 119 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 120 if (!ASSERT_OK(err, "bpf_prog_attach")) 121 goto out; 122 123 client_fd = connect_to_fd(server_fd, 0); 124 if (!ASSERT_GE(client_fd, 0, "connect to fd")) { 125 err = -EIO; 126 goto out; 127 } 128 129 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) : 130 verify_tsk(map_fd, client_fd); 131 132 close(client_fd); 133 134 out: 135 mptcp_sock__destroy(sock_skel); 136 return err; 137 } 138 139 static void test_base(void) 140 { 141 int server_fd, cgroup_fd; 142 143 cgroup_fd = test__join_cgroup("/mptcp"); 144 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup")) 145 return; 146 147 /* without MPTCP */ 148 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0); 149 if (!ASSERT_GE(server_fd, 0, "start_server")) 150 goto with_mptcp; 151 152 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp"); 153 154 close(server_fd); 155 156 with_mptcp: 157 /* with MPTCP */ 158 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0); 159 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) 160 goto close_cgroup_fd; 161 162 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp"); 163 164 close(server_fd); 165 166 close_cgroup_fd: 167 close(cgroup_fd); 168 } 169 170 void test_mptcp(void) 171 { 172 if (test__start_subtest("base")) 173 test_base(); 174 } 175