1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <sys/types.h>
4 #include <sys/epoll.h>
5 #include <sys/socket.h>
6 #include <linux/netlink.h>
7 #include <linux/connector.h>
8 #include <linux/cn_proc.h>
9 
10 #include <stddef.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <unistd.h>
14 #include <strings.h>
15 #include <errno.h>
16 #include <signal.h>
17 #include <string.h>
18 
19 #include "../kselftest.h"
20 
/* Size of a control message whose payload is a struct proc_input (-f mode). */
#define NL_MESSAGE_SIZE (sizeof(struct nlmsghdr) + sizeof(struct cn_msg) + \
			 sizeof(struct proc_input))
/* Size of a control message whose payload is a bare int mcast op. */
#define NL_MESSAGE_SIZE_NF (sizeof(struct nlmsghdr) + sizeof(struct cn_msg) + \
			 sizeof(int))

/* Number of epoll events fetched per epoll_wait() call. */
#define MAX_EVENTS 1

/*
 * Set by the SIGINT handler. The C standard only defines a signal handler
 * writing a static object when that object is volatile sig_atomic_t (or a
 * lock-free atomic), so a plain volatile int is not sufficient.
 */
static volatile sig_atomic_t interrupted;
static int nl_sock, ret_errno, tcount;
static struct epoll_event evn;

/* Non-zero when run with -f: control messages carry a struct proc_input. */
static int filter;

#ifdef ENABLE_PRINTS
#define Printf printf
#else
#define Printf ksft_print_msg
#endif
39 
40 int send_message(void *pinp)
41 {
42 	char buff[NL_MESSAGE_SIZE];
43 	struct nlmsghdr *hdr;
44 	struct cn_msg *msg;
45 
46 	hdr = (struct nlmsghdr *)buff;
47 	if (filter)
48 		hdr->nlmsg_len = NL_MESSAGE_SIZE;
49 	else
50 		hdr->nlmsg_len = NL_MESSAGE_SIZE_NF;
51 	hdr->nlmsg_type = NLMSG_DONE;
52 	hdr->nlmsg_flags = 0;
53 	hdr->nlmsg_seq = 0;
54 	hdr->nlmsg_pid = getpid();
55 
56 	msg = (struct cn_msg *)NLMSG_DATA(hdr);
57 	msg->id.idx = CN_IDX_PROC;
58 	msg->id.val = CN_VAL_PROC;
59 	msg->seq = 0;
60 	msg->ack = 0;
61 	msg->flags = 0;
62 
63 	if (filter) {
64 		msg->len = sizeof(struct proc_input);
65 		((struct proc_input *)msg->data)->mcast_op =
66 			((struct proc_input *)pinp)->mcast_op;
67 		((struct proc_input *)msg->data)->event_type =
68 			((struct proc_input *)pinp)->event_type;
69 	} else {
70 		msg->len = sizeof(int);
71 		*(int *)msg->data = *(enum proc_cn_mcast_op *)pinp;
72 	}
73 
74 	if (send(nl_sock, hdr, hdr->nlmsg_len, 0) == -1) {
75 		ret_errno = errno;
76 		perror("send failed");
77 		return -3;
78 	}
79 	return 0;
80 }
81 
82 int register_proc_netlink(int *efd, void *input)
83 {
84 	struct sockaddr_nl sa_nl;
85 	int err = 0, epoll_fd;
86 
87 	nl_sock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_CONNECTOR);
88 
89 	if (nl_sock == -1) {
90 		ret_errno = errno;
91 		perror("socket failed");
92 		return -1;
93 	}
94 
95 	bzero(&sa_nl, sizeof(sa_nl));
96 	sa_nl.nl_family = AF_NETLINK;
97 	sa_nl.nl_groups = CN_IDX_PROC;
98 	sa_nl.nl_pid    = getpid();
99 
100 	if (bind(nl_sock, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) == -1) {
101 		ret_errno = errno;
102 		perror("bind failed");
103 		return -2;
104 	}
105 
106 	epoll_fd = epoll_create1(EPOLL_CLOEXEC);
107 	if (epoll_fd < 0) {
108 		ret_errno = errno;
109 		perror("epoll_create1 failed");
110 		return -2;
111 	}
112 
113 	err = send_message(input);
114 
115 	if (err < 0)
116 		return err;
117 
118 	evn.events = EPOLLIN;
119 	evn.data.fd = nl_sock;
120 	if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, nl_sock, &evn) < 0) {
121 		ret_errno = errno;
122 		perror("epoll_ctl failed");
123 		return -3;
124 	}
125 	*efd = epoll_fd;
126 	return 0;
127 }
128 
129 static void sigint(int sig)
130 {
131 	interrupted = 1;
132 }
133 
134 int handle_packet(char *buff, int fd, struct proc_event *event)
135 {
136 	struct nlmsghdr *hdr;
137 
138 	hdr = (struct nlmsghdr *)buff;
139 
140 	if (hdr->nlmsg_type == NLMSG_ERROR) {
141 		perror("NLMSG_ERROR error\n");
142 		return -3;
143 	} else if (hdr->nlmsg_type == NLMSG_DONE) {
144 		event = (struct proc_event *)
145 			((struct cn_msg *)NLMSG_DATA(hdr))->data;
146 		tcount++;
147 		switch (event->what) {
148 		case PROC_EVENT_EXIT:
149 			Printf("Exit process %d (tgid %d) with code %d, signal %d\n",
150 			       event->event_data.exit.process_pid,
151 			       event->event_data.exit.process_tgid,
152 			       event->event_data.exit.exit_code,
153 			       event->event_data.exit.exit_signal);
154 			break;
155 		case PROC_EVENT_FORK:
156 			Printf("Fork process %d (tgid %d), parent %d (tgid %d)\n",
157 			       event->event_data.fork.child_pid,
158 			       event->event_data.fork.child_tgid,
159 			       event->event_data.fork.parent_pid,
160 			       event->event_data.fork.parent_tgid);
161 			break;
162 		case PROC_EVENT_EXEC:
163 			Printf("Exec process %d (tgid %d)\n",
164 			       event->event_data.exec.process_pid,
165 			       event->event_data.exec.process_tgid);
166 			break;
167 		case PROC_EVENT_UID:
168 			Printf("UID process %d (tgid %d) uid %d euid %d\n",
169 			       event->event_data.id.process_pid,
170 			       event->event_data.id.process_tgid,
171 			       event->event_data.id.r.ruid,
172 			       event->event_data.id.e.euid);
173 			break;
174 		case PROC_EVENT_GID:
175 			Printf("GID process %d (tgid %d) gid %d egid %d\n",
176 			       event->event_data.id.process_pid,
177 			       event->event_data.id.process_tgid,
178 			       event->event_data.id.r.rgid,
179 			       event->event_data.id.e.egid);
180 			break;
181 		case PROC_EVENT_SID:
182 			Printf("SID process %d (tgid %d)\n",
183 			       event->event_data.sid.process_pid,
184 			       event->event_data.sid.process_tgid);
185 			break;
186 		case PROC_EVENT_PTRACE:
187 			Printf("Ptrace process %d (tgid %d), Tracer %d (tgid %d)\n",
188 			       event->event_data.ptrace.process_pid,
189 			       event->event_data.ptrace.process_tgid,
190 			       event->event_data.ptrace.tracer_pid,
191 			       event->event_data.ptrace.tracer_tgid);
192 			break;
193 		case PROC_EVENT_COMM:
194 			Printf("Comm process %d (tgid %d) comm %s\n",
195 			       event->event_data.comm.process_pid,
196 			       event->event_data.comm.process_tgid,
197 			       event->event_data.comm.comm);
198 			break;
199 		case PROC_EVENT_COREDUMP:
200 			Printf("Coredump process %d (tgid %d) parent %d, (tgid %d)\n",
201 			       event->event_data.coredump.process_pid,
202 			       event->event_data.coredump.process_tgid,
203 			       event->event_data.coredump.parent_pid,
204 			       event->event_data.coredump.parent_tgid);
205 			break;
206 		default:
207 			break;
208 		}
209 	}
210 	return 0;
211 }
212 
213 int handle_events(int epoll_fd, struct proc_event *pev)
214 {
215 	char buff[CONNECTOR_MAX_MSG_SIZE];
216 	struct epoll_event ev[MAX_EVENTS];
217 	int i, event_count = 0, err = 0;
218 
219 	event_count = epoll_wait(epoll_fd, ev, MAX_EVENTS, -1);
220 	if (event_count < 0) {
221 		ret_errno = errno;
222 		if (ret_errno != EINTR)
223 			perror("epoll_wait failed");
224 		return -3;
225 	}
226 	for (i = 0; i < event_count; i++) {
227 		if (!(ev[i].events & EPOLLIN))
228 			continue;
229 		if (recv(ev[i].data.fd, buff, sizeof(buff), 0) == -1) {
230 			ret_errno = errno;
231 			perror("recv failed");
232 			return -3;
233 		}
234 		err = handle_packet(buff, ev[i].data.fd, pev);
235 		if (err < 0)
236 			return err;
237 	}
238 	return 0;
239 }
240 
241 int main(int argc, char *argv[])
242 {
243 	int epoll_fd, err;
244 	struct proc_event proc_ev;
245 	struct proc_input input;
246 
247 	signal(SIGINT, sigint);
248 
249 	if (argc > 2) {
250 		printf("Expected 0(assume no-filter) or 1 argument(-f)\n");
251 		exit(KSFT_SKIP);
252 	}
253 
254 	if (argc == 2) {
255 		if (strcmp(argv[1], "-f") == 0) {
256 			filter = 1;
257 		} else {
258 			printf("Valid option : -f (for filter feature)\n");
259 			exit(KSFT_SKIP);
260 		}
261 	}
262 
263 	if (filter) {
264 		input.event_type = PROC_EVENT_NONZERO_EXIT;
265 		input.mcast_op = PROC_CN_MCAST_LISTEN;
266 		err = register_proc_netlink(&epoll_fd, (void*)&input);
267 	} else {
268 		enum proc_cn_mcast_op op = PROC_CN_MCAST_LISTEN;
269 		err = register_proc_netlink(&epoll_fd, (void*)&op);
270 	}
271 
272 	if (err < 0) {
273 		if (err == -2)
274 			close(nl_sock);
275 		if (err == -3) {
276 			close(nl_sock);
277 			close(epoll_fd);
278 		}
279 		exit(1);
280 	}
281 
282 	while (!interrupted) {
283 		err = handle_events(epoll_fd, &proc_ev);
284 		if (err < 0) {
285 			if (ret_errno == EINTR)
286 				continue;
287 			if (err == -2)
288 				close(nl_sock);
289 			if (err == -3) {
290 				close(nl_sock);
291 				close(epoll_fd);
292 			}
293 			exit(1);
294 		}
295 	}
296 
297 	if (filter) {
298 		input.mcast_op = PROC_CN_MCAST_IGNORE;
299 		send_message((void*)&input);
300 	} else {
301 		enum proc_cn_mcast_op op = PROC_CN_MCAST_IGNORE;
302 		send_message((void*)&op);
303 	}
304 
305 	close(epoll_fd);
306 	close(nl_sock);
307 
308 	printf("Done total count: %d\n", tcount);
309 	exit(0);
310 }
311