1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright 2018 Google Inc.
4  * Author: Soheil Hassas Yeganeh (soheil@google.com)
5  *
6  * Simple example on how to use TCP_INQ and TCP_CM_INQ.
7  */
8 #define _GNU_SOURCE
9 
10 #include <error.h>
11 #include <netinet/in.h>
12 #include <netinet/tcp.h>
13 #include <pthread.h>
14 #include <stdio.h>
15 #include <errno.h>
16 #include <stdlib.h>
17 #include <string.h>
18 #include <sys/socket.h>
19 #include <unistd.h>
20 
/*
 * Fallback values for C libraries whose <netinet/tcp.h> does not yet
 * provide the TCP_INQ socket option. The ancillary-data type delivered by
 * recvmsg() reuses the same number as the option.
 */
#ifndef TCP_INQ
#define TCP_INQ 36
#endif

#ifndef TCP_CM_INQ
#define TCP_CM_INQ TCP_INQ
#endif

/*
 * BUF_SIZE:  bytes the server sends per connection (the client reads half).
 * CMSG_SIZE: bytes reserved for recvmsg() control data.
 */
enum {
	BUF_SIZE = 8192,
	CMSG_SIZE = 32,
};

/* Test configuration; main() may override these via -4, -6 and -p. */
static int family = AF_INET6;
static socklen_t addr_len = sizeof(struct sockaddr_in6);
static int port = 4974;
35 
36 static void setup_loopback_addr(int family, struct sockaddr_storage *sockaddr)
37 {
38 	struct sockaddr_in6 *addr6 = (void *) sockaddr;
39 	struct sockaddr_in *addr4 = (void *) sockaddr;
40 
41 	switch (family) {
42 	case PF_INET:
43 		memset(addr4, 0, sizeof(*addr4));
44 		addr4->sin_family = AF_INET;
45 		addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
46 		addr4->sin_port = htons(port);
47 		break;
48 	case PF_INET6:
49 		memset(addr6, 0, sizeof(*addr6));
50 		addr6->sin6_family = AF_INET6;
51 		addr6->sin6_addr = in6addr_loopback;
52 		addr6->sin6_port = htons(port);
53 		break;
54 	default:
55 		error(1, 0, "illegal family");
56 	}
57 }
58 
59 void *start_server(void *arg)
60 {
61 	int server_fd = (int)(unsigned long)arg;
62 	struct sockaddr_in addr;
63 	socklen_t addrlen = sizeof(addr);
64 	char *buf;
65 	int fd;
66 	int r;
67 
68 	buf = malloc(BUF_SIZE);
69 
70 	for (;;) {
71 		fd = accept(server_fd, (struct sockaddr *)&addr, &addrlen);
72 		if (fd == -1) {
73 			perror("accept");
74 			break;
75 		}
76 		do {
77 			r = send(fd, buf, BUF_SIZE, 0);
78 		} while (r < 0 && errno == EINTR);
79 		if (r < 0)
80 			perror("send");
81 		if (r != BUF_SIZE)
82 			fprintf(stderr, "can only send %d bytes\n", r);
83 		/* TCP_INQ can overestimate in-queue by one byte if we send
84 		 * the FIN packet. Sleep for 1 second, so that the client
85 		 * likely invoked recvmsg().
86 		 */
87 		sleep(1);
88 		close(fd);
89 	}
90 
91 	free(buf);
92 	close(server_fd);
93 	pthread_exit(0);
94 }
95 
96 int main(int argc, char *argv[])
97 {
98 	struct sockaddr_storage listen_addr, addr;
99 	int c, one = 1, inq = -1;
100 	pthread_t server_thread;
101 	char cmsgbuf[CMSG_SIZE];
102 	struct iovec iov[1];
103 	struct cmsghdr *cm;
104 	struct msghdr msg;
105 	int server_fd, fd;
106 	char *buf;
107 
108 	while ((c = getopt(argc, argv, "46p:")) != -1) {
109 		switch (c) {
110 		case '4':
111 			family = PF_INET;
112 			addr_len = sizeof(struct sockaddr_in);
113 			break;
114 		case '6':
115 			family = PF_INET6;
116 			addr_len = sizeof(struct sockaddr_in6);
117 			break;
118 		case 'p':
119 			port = atoi(optarg);
120 			break;
121 		}
122 	}
123 
124 	server_fd = socket(family, SOCK_STREAM, 0);
125 	if (server_fd < 0)
126 		error(1, errno, "server socket");
127 	setup_loopback_addr(family, &listen_addr);
128 	if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR,
129 		       &one, sizeof(one)) != 0)
130 		error(1, errno, "setsockopt(SO_REUSEADDR)");
131 	if (bind(server_fd, (const struct sockaddr *)&listen_addr,
132 		 addr_len) == -1)
133 		error(1, errno, "bind");
134 	if (listen(server_fd, 128) == -1)
135 		error(1, errno, "listen");
136 	if (pthread_create(&server_thread, NULL, start_server,
137 			   (void *)(unsigned long)server_fd) != 0)
138 		error(1, errno, "pthread_create");
139 
140 	fd = socket(family, SOCK_STREAM, 0);
141 	if (fd < 0)
142 		error(1, errno, "client socket");
143 	setup_loopback_addr(family, &addr);
144 	if (connect(fd, (const struct sockaddr *)&addr, addr_len) == -1)
145 		error(1, errno, "connect");
146 	if (setsockopt(fd, SOL_TCP, TCP_INQ, &one, sizeof(one)) != 0)
147 		error(1, errno, "setsockopt(TCP_INQ)");
148 
149 	msg.msg_name = NULL;
150 	msg.msg_namelen = 0;
151 	msg.msg_iov = iov;
152 	msg.msg_iovlen = 1;
153 	msg.msg_control = cmsgbuf;
154 	msg.msg_controllen = sizeof(cmsgbuf);
155 	msg.msg_flags = 0;
156 
157 	buf = malloc(BUF_SIZE);
158 	iov[0].iov_base = buf;
159 	iov[0].iov_len = BUF_SIZE / 2;
160 
161 	if (recvmsg(fd, &msg, 0) != iov[0].iov_len)
162 		error(1, errno, "recvmsg");
163 	if (msg.msg_flags & MSG_CTRUNC)
164 		error(1, 0, "control message is truncated");
165 
166 	for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm))
167 		if (cm->cmsg_level == SOL_TCP && cm->cmsg_type == TCP_CM_INQ)
168 			inq = *((int *) CMSG_DATA(cm));
169 
170 	if (inq != BUF_SIZE - iov[0].iov_len) {
171 		fprintf(stderr, "unexpected inq: %d\n", inq);
172 		exit(1);
173 	}
174 
175 	printf("PASSED\n");
176 	free(buf);
177 	close(fd);
178 	return 0;
179 }
180