1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "cgroup_helpers.h"
19 
/* Test cgroup that main() creates and joins, and whose ID is expected at
 * ancestor level 2 in check_ancestor_cgroup_ids().
 */
#define CGROUP_PATH		"/skb_cgroup_test"
/* Number of ancestor levels (map keys 0..N-1) that are checked. */
#define NUM_CGROUP_LEVELS	4

/* All-nodes link-local multicast address: RFC 4291, Section 2.7.1 */
#define LINKLOCAL_MULTICAST	"ff02::1"
25 
/*
 * Build an IPv6 destination address for UDP port 1025.
 *
 * @ip:    textual IPv6 address to parse.
 * @iface: interface name; its index becomes the scope id.
 * @dst:   output; fully zeroed, then filled in.
 *
 * Returns 0 on success. Returns -1 (after logging) if @ip is not a valid
 * IPv6 address or @iface has no interface index.
 */
static int mk_dst_addr(const char *ip, const char *iface,
		       struct sockaddr_in6 *dst)
{
	unsigned int ifindex;
	int parsed;

	memset(dst, 0, sizeof(*dst));
	dst->sin6_family = AF_INET6;
	dst->sin6_port = htons(1025);

	parsed = inet_pton(AF_INET6, ip, &dst->sin6_addr);
	if (parsed != 1) {
		log_err("Invalid IPv6: %s", ip);
		return -1;
	}

	/* A zero index means the interface could not be resolved. */
	ifindex = if_nametoindex(iface);
	dst->sin6_scope_id = ifindex;
	if (ifindex == 0) {
		log_err("Failed to get index of iface: %s", iface);
		return -1;
	}

	return 0;
}
47 
send_packet(const char * iface)48 static int send_packet(const char *iface)
49 {
50 	struct sockaddr_in6 dst;
51 	char msg[] = "msg";
52 	int err = 0;
53 	int fd = -1;
54 
55 	if (mk_dst_addr(LINKLOCAL_MULTICAST, iface, &dst))
56 		goto err;
57 
58 	fd = socket(AF_INET6, SOCK_DGRAM, 0);
59 	if (fd == -1) {
60 		log_err("Failed to create UDP socket");
61 		goto err;
62 	}
63 
64 	if (sendto(fd, &msg, sizeof(msg), 0, (const struct sockaddr *)&dst,
65 		   sizeof(dst)) == -1) {
66 		log_err("Failed to send datagram");
67 		goto err;
68 	}
69 
70 	goto out;
71 err:
72 	err = -1;
73 out:
74 	if (fd >= 0)
75 		close(fd);
76 	return err;
77 }
78 
get_map_fd_by_prog_id(int prog_id)79 int get_map_fd_by_prog_id(int prog_id)
80 {
81 	struct bpf_prog_info info = {};
82 	__u32 info_len = sizeof(info);
83 	__u32 map_ids[1];
84 	int prog_fd = -1;
85 	int map_fd = -1;
86 
87 	prog_fd = bpf_prog_get_fd_by_id(prog_id);
88 	if (prog_fd < 0) {
89 		log_err("Failed to get fd by prog id %d", prog_id);
90 		goto err;
91 	}
92 
93 	info.nr_map_ids = 1;
94 	info.map_ids = (__u64) (unsigned long) map_ids;
95 
96 	if (bpf_prog_get_info_by_fd(prog_fd, &info, &info_len)) {
97 		log_err("Failed to get info by prog fd %d", prog_fd);
98 		goto err;
99 	}
100 
101 	if (!info.nr_map_ids) {
102 		log_err("No maps found for prog fd %d", prog_fd);
103 		goto err;
104 	}
105 
106 	map_fd = bpf_map_get_fd_by_id(map_ids[0]);
107 	if (map_fd < 0)
108 		log_err("Failed to get fd by map id %d", map_ids[0]);
109 err:
110 	if (prog_fd >= 0)
111 		close(prog_fd);
112 	return map_fd;
113 }
114 
check_ancestor_cgroup_ids(int prog_id)115 int check_ancestor_cgroup_ids(int prog_id)
116 {
117 	__u64 actual_ids[NUM_CGROUP_LEVELS], expected_ids[NUM_CGROUP_LEVELS];
118 	__u32 level;
119 	int err = 0;
120 	int map_fd;
121 
122 	expected_ids[0] = get_cgroup_id("/..");	/* root cgroup */
123 	expected_ids[1] = get_cgroup_id("");
124 	expected_ids[2] = get_cgroup_id(CGROUP_PATH);
125 	expected_ids[3] = 0; /* non-existent cgroup */
126 
127 	map_fd = get_map_fd_by_prog_id(prog_id);
128 	if (map_fd < 0)
129 		goto err;
130 
131 	for (level = 0; level < NUM_CGROUP_LEVELS; ++level) {
132 		if (bpf_map_lookup_elem(map_fd, &level, &actual_ids[level])) {
133 			log_err("Failed to lookup key %d", level);
134 			goto err;
135 		}
136 		if (actual_ids[level] != expected_ids[level]) {
137 			log_err("%llx (actual) != %llx (expected), level: %u\n",
138 				actual_ids[level], expected_ids[level], level);
139 			goto err;
140 		}
141 	}
142 
143 	goto out;
144 err:
145 	err = -1;
146 out:
147 	if (map_fd >= 0)
148 		close(map_fd);
149 	return err;
150 }
151 
main(int argc,char ** argv)152 int main(int argc, char **argv)
153 {
154 	int cgfd = -1;
155 	int err = 0;
156 
157 	if (argc < 3) {
158 		fprintf(stderr, "Usage: %s iface prog_id\n", argv[0]);
159 		exit(EXIT_FAILURE);
160 	}
161 
162 	/* Use libbpf 1.0 API mode */
163 	libbpf_set_strict_mode(LIBBPF_STRICT_ALL);
164 
165 	cgfd = cgroup_setup_and_join(CGROUP_PATH);
166 	if (cgfd < 0)
167 		goto err;
168 
169 	if (send_packet(argv[1]))
170 		goto err;
171 
172 	if (check_ancestor_cgroup_ids(atoi(argv[2])))
173 		goto err;
174 
175 	goto out;
176 err:
177 	err = -1;
178 out:
179 	close(cgfd);
180 	cleanup_cgroup_environment();
181 	printf("[%s]\n", err ? "FAIL" : "PASS");
182 	return err;
183 }
184