1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4 
5 #include <test_progs.h>
6 #include "cgroup_helpers.h"
7 #include "network_helpers.h"
8 #include "mptcp_sock.skel.h"
9 
/* Name of the network namespace that test_base() creates, enters and deletes. */
#define NS_TEST "mptcp_ns"

/*
 * Size of the congestion-control name buffer. Fallback used only when no
 * included header already defines it; presumably mirrors the kernel's
 * TCP_CA_NAME_MAX (16) -- confirm if the BPF side ever changes.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX	16
#endif

/*
 * Userspace view of one value of the skeleton's socket_storage_map, keyed by
 * the client socket fd. bpf_map_lookup_elem() copies the raw map value into
 * this struct, so its layout must match the value type used by the BPF
 * program behind mptcp_sock.skel.h byte-for-byte (field order, sizes and
 * padding). Do not reorder or resize fields here without changing both sides.
 */
struct mptcp_storage {
	/* Expected to be 1 for the connected client (verify_tsk/verify_msk). */
	__u32 invoked;
	/* Expected 1 for an MPTCP connection, 0 for plain TCP. */
	__u32 is_mptcp;
	/* Socket pointer recorded by the BPF program; only compared to 'first'. */
	struct sock *sk;
	/* Compared against the token the BPF program stored in its .bss. */
	__u32 token;
	/* Expected to equal 'sk' for the MPTCP case (see verify_msk). */
	struct sock *first;
	/* Compared against /proc/sys/net/ipv4/tcp_congestion_control. */
	char ca_name[TCP_CA_NAME_MAX];
};
24 
25 static int verify_tsk(int map_fd, int client_fd)
26 {
27 	int err, cfd = client_fd;
28 	struct mptcp_storage val;
29 
30 	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
31 	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
32 		return err;
33 
34 	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
35 		err++;
36 
37 	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
38 		err++;
39 
40 	return err;
41 }
42 
43 static void get_msk_ca_name(char ca_name[])
44 {
45 	size_t len;
46 	int fd;
47 
48 	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
49 	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
50 		return;
51 
52 	len = read(fd, ca_name, TCP_CA_NAME_MAX);
53 	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
54 		goto err;
55 
56 	if (len > 0 && ca_name[len - 1] == '\n')
57 		ca_name[len - 1] = '\0';
58 
59 err:
60 	close(fd);
61 }
62 
63 static int verify_msk(int map_fd, int client_fd, __u32 token)
64 {
65 	char ca_name[TCP_CA_NAME_MAX];
66 	int err, cfd = client_fd;
67 	struct mptcp_storage val;
68 
69 	if (!ASSERT_GT(token, 0, "invalid token"))
70 		return -1;
71 
72 	get_msk_ca_name(ca_name);
73 
74 	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
75 	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
76 		return err;
77 
78 	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
79 		err++;
80 
81 	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
82 		err++;
83 
84 	if (!ASSERT_EQ(val.token, token, "unexpected token"))
85 		err++;
86 
87 	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
88 		err++;
89 
90 	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
91 		err++;
92 
93 	return err;
94 }
95 
96 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
97 {
98 	int client_fd, prog_fd, map_fd, err;
99 	struct mptcp_sock *sock_skel;
100 
101 	sock_skel = mptcp_sock__open_and_load();
102 	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
103 		return -EIO;
104 
105 	err = mptcp_sock__attach(sock_skel);
106 	if (!ASSERT_OK(err, "skel_attach"))
107 		goto out;
108 
109 	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
110 	if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
111 		err = -EIO;
112 		goto out;
113 	}
114 
115 	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
116 	if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
117 		err = -EIO;
118 		goto out;
119 	}
120 
121 	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
122 	if (!ASSERT_OK(err, "bpf_prog_attach"))
123 		goto out;
124 
125 	client_fd = connect_to_fd(server_fd, 0);
126 	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
127 		err = -EIO;
128 		goto out;
129 	}
130 
131 	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
132 			  verify_tsk(map_fd, client_fd);
133 
134 	close(client_fd);
135 
136 out:
137 	mptcp_sock__destroy(sock_skel);
138 	return err;
139 }
140 
141 static void test_base(void)
142 {
143 	struct nstoken *nstoken = NULL;
144 	int server_fd, cgroup_fd;
145 
146 	cgroup_fd = test__join_cgroup("/mptcp");
147 	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
148 		return;
149 
150 	SYS(fail, "ip netns add %s", NS_TEST);
151 	SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
152 
153 	nstoken = open_netns(NS_TEST);
154 	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
155 		goto fail;
156 
157 	/* without MPTCP */
158 	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
159 	if (!ASSERT_GE(server_fd, 0, "start_server"))
160 		goto with_mptcp;
161 
162 	ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
163 
164 	close(server_fd);
165 
166 with_mptcp:
167 	/* with MPTCP */
168 	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
169 	if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
170 		goto fail;
171 
172 	ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
173 
174 	close(server_fd);
175 
176 fail:
177 	if (nstoken)
178 		close_netns(nstoken);
179 
180 	SYS_NOFAIL("ip netns del " NS_TEST " &> /dev/null");
181 
182 	close(cgroup_fd);
183 }
184 
/* test_progs entry point: runs the "base" subtest unless it is filtered out. */
void test_mptcp(void)
{
	if (!test__start_subtest("base"))
		return;

	test_base();
}
190