1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */
3 
4 #include <linux/bpf.h>
5 #include <bpf/bpf_helpers.h>
6 #include "bpf_misc.h"
7 #include "test_user_ringbuf.h"
8 
/* License string, placed in the "license" ELF section for the loader. */
char _license[] SEC("license") = "GPL";

/*
 * Ring buffer whose samples the programs below consume with
 * bpf_user_ringbuf_drain().
 * NOTE(review): no max_entries is given here; presumably the userspace half
 * of the test sizes the map before load — confirm.
 */
struct {
	__uint(type, BPF_MAP_TYPE_USER_RINGBUF);
} user_ringbuf SEC(".maps");

/*
 * Ring buffer that publish_next_kern_msg() reserves/submits struct test_msg
 * records into. Sized outside this file (no max_entries here).
 */
struct {
	__uint(type, BPF_MAP_TYPE_RINGBUF);
} kernel_ringbuf SEC(".maps");
18 
19 /* inputs */
20 int pid, err, val;
21 
22 int read = 0;
23 
24 /* Counter used for end-to-end protocol test */
25 __u64 kern_mutated = 0;
26 __u64 user_mutated = 0;
27 __u64 expected_user_mutated = 0;
28 
29 static int
is_test_process(void)30 is_test_process(void)
31 {
32 	int cur_pid = bpf_get_current_pid_tgid() >> 32;
33 
34 	return cur_pid == pid;
35 }
36 
37 static long
record_sample(struct bpf_dynptr * dynptr,void * context)38 record_sample(struct bpf_dynptr *dynptr, void *context)
39 {
40 	const struct sample *sample = NULL;
41 	struct sample stack_sample;
42 	int status;
43 	static int num_calls;
44 
45 	if (num_calls++ % 2 == 0) {
46 		status = bpf_dynptr_read(&stack_sample, sizeof(stack_sample), dynptr, 0, 0);
47 		if (status) {
48 			bpf_printk("bpf_dynptr_read() failed: %d\n", status);
49 			err = 1;
50 			return 1;
51 		}
52 	} else {
53 		sample = bpf_dynptr_data(dynptr, 0, sizeof(*sample));
54 		if (!sample) {
55 			bpf_printk("Unexpectedly failed to get sample\n");
56 			err = 2;
57 			return 1;
58 		}
59 		stack_sample = *sample;
60 	}
61 
62 	__sync_fetch_and_add(&read, 1);
63 	return 0;
64 }
65 
66 static void
handle_sample_msg(const struct test_msg * msg)67 handle_sample_msg(const struct test_msg *msg)
68 {
69 	switch (msg->msg_op) {
70 	case TEST_MSG_OP_INC64:
71 		kern_mutated += msg->operand_64;
72 		break;
73 	case TEST_MSG_OP_INC32:
74 		kern_mutated += msg->operand_32;
75 		break;
76 	case TEST_MSG_OP_MUL64:
77 		kern_mutated *= msg->operand_64;
78 		break;
79 	case TEST_MSG_OP_MUL32:
80 		kern_mutated *= msg->operand_32;
81 		break;
82 	default:
83 		bpf_printk("Unrecognized op %d\n", msg->msg_op);
84 		err = 2;
85 	}
86 }
87 
88 static long
read_protocol_msg(struct bpf_dynptr * dynptr,void * context)89 read_protocol_msg(struct bpf_dynptr *dynptr, void *context)
90 {
91 	const struct test_msg *msg = NULL;
92 
93 	msg = bpf_dynptr_data(dynptr, 0, sizeof(*msg));
94 	if (!msg) {
95 		err = 1;
96 		bpf_printk("Unexpectedly failed to get msg\n");
97 		return 0;
98 	}
99 
100 	handle_sample_msg(msg);
101 
102 	return 0;
103 }
104 
/*
 * bpf_loop() callback that publishes one protocol message to userspace.
 *
 * Reserves a struct test_msg in kernel_ringbuf and fills it with the operation
 * selected by @index % TEST_MSG_OP_NUM_OPS. It applies the same operation to
 * expected_user_mutated, which publish_kern_messages() later compares with
 * user_mutated, and then submits the message.
 *
 * @index:   iteration number supplied by bpf_loop().
 * @context: unused.
 *
 * Return: 0 to keep looping. 1 to stop, after setting err: 4 if the ring
 * buffer reservation failed, or 5 if @index mapped to an unknown operation
 * (the reservation is then discarded).
 */
static int publish_next_kern_msg(__u32 index, void *context)
{
	/*
	 * Give each operand the width of the field it is written into. The 64-bit
	 * operand used to be a plain int, which would silently truncate
	 * TEST_OP_64 if it ever exceeded INT_MAX.
	 * NOTE(review): assumes struct test_msg's operands are signed 64/32-bit —
	 * confirm against test_user_ringbuf.h.
	 */
	const __s64 operand_64 = TEST_OP_64;
	const __s32 operand_32 = TEST_OP_32;
	struct test_msg *msg;

	msg = bpf_ringbuf_reserve(&kernel_ringbuf, sizeof(*msg), 0);
	if (!msg) {
		err = 4;
		return 1;
	}

	switch (index % TEST_MSG_OP_NUM_OPS) {
	case TEST_MSG_OP_INC64:
		msg->operand_64 = operand_64;
		msg->msg_op = TEST_MSG_OP_INC64;
		expected_user_mutated += operand_64;
		break;
	case TEST_MSG_OP_INC32:
		msg->operand_32 = operand_32;
		msg->msg_op = TEST_MSG_OP_INC32;
		expected_user_mutated += operand_32;
		break;
	case TEST_MSG_OP_MUL64:
		msg->operand_64 = operand_64;
		msg->msg_op = TEST_MSG_OP_MUL64;
		expected_user_mutated *= operand_64;
		break;
	case TEST_MSG_OP_MUL32:
		msg->operand_32 = operand_32;
		msg->msg_op = TEST_MSG_OP_MUL32;
		expected_user_mutated *= operand_32;
		break;
	default:
		/* A reserved record must be either submitted or discarded. */
		bpf_ringbuf_discard(msg, 0);
		err = 5;
		return 1;
	}

	bpf_ringbuf_submit(msg, 0);

	return 0;
}
148 
149 static void
publish_kern_messages(void)150 publish_kern_messages(void)
151 {
152 	if (expected_user_mutated != user_mutated) {
153 		bpf_printk("%lu != %lu\n", expected_user_mutated, user_mutated);
154 		err = 3;
155 		return;
156 	}
157 
158 	bpf_loop(8, publish_next_kern_msg, NULL, 0);
159 }
160 
161 SEC("fentry/" SYS_PREFIX "sys_prctl")
test_user_ringbuf_protocol(void * ctx)162 int test_user_ringbuf_protocol(void *ctx)
163 {
164 	long status = 0;
165 
166 	if (!is_test_process())
167 		return 0;
168 
169 	status = bpf_user_ringbuf_drain(&user_ringbuf, read_protocol_msg, NULL, 0);
170 	if (status < 0) {
171 		bpf_printk("Drain returned: %ld\n", status);
172 		err = 1;
173 		return 0;
174 	}
175 
176 	publish_kern_messages();
177 
178 	return 0;
179 }
180 
181 SEC("fentry/" SYS_PREFIX "sys_getpgid")
test_user_ringbuf(void * ctx)182 int test_user_ringbuf(void *ctx)
183 {
184 	if (!is_test_process())
185 		return 0;
186 
187 	err = bpf_user_ringbuf_drain(&user_ringbuf, record_sample, NULL, 0);
188 
189 	return 0;
190 }
191 
192 static long
do_nothing_cb(struct bpf_dynptr * dynptr,void * context)193 do_nothing_cb(struct bpf_dynptr *dynptr, void *context)
194 {
195 	__sync_fetch_and_add(&read, 1);
196 	return 0;
197 }
198 
199 SEC("fentry/" SYS_PREFIX "sys_prlimit64")
test_user_ringbuf_epoll(void * ctx)200 int test_user_ringbuf_epoll(void *ctx)
201 {
202 	long num_samples;
203 
204 	if (!is_test_process())
205 		return 0;
206 
207 	num_samples = bpf_user_ringbuf_drain(&user_ringbuf, do_nothing_cb, NULL, 0);
208 	if (num_samples <= 0)
209 		err = 1;
210 
211 	return 0;
212 }
213