xref: /openbmc/libpldm/src/transport/mctp-demux.c (revision 4e1ba8a7)
1 #include "mctp-defines.h"
2 #include "base.h"
3 #include "container-of.h"
4 #include "libpldm/pldm.h"
5 #include "libpldm/transport.h"
6 #include "socket.h"
7 #include "transport.h"
8 
9 #include <errno.h>
10 #include <limits.h>
11 #include <poll.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <sys/socket.h>
15 #include <sys/types.h>
16 #include <sys/un.h>
17 #include <unistd.h>
18 
/* Transport name published through pldm_transport.name */
#define MCTP_DEMUX_NAME "libmctp-demux-daemon"
/*
 * MCTP message type byte for PLDM. It is written to the demux socket once
 * after connecting, and it is the second byte of the two-byte prefix on every
 * datagram sent or received over that socket.
 * NOTE(review): this has external linkage; unless another translation unit
 * refers to it, it could be made static.
 */
const uint8_t mctp_msg_type = MCTP_MSG_TYPE_PLDM;
21 
/*
 * State for a PLDM transport carried over the MCTP demux daemon's socket.
 * Created by pldm_transport_mctp_demux_init() or
 * pldm_transport_mctp_demux_init_with_fd(), released by
 * pldm_transport_mctp_demux_destroy().
 */
struct pldm_transport_mctp_demux {
	/* Generic transport handed to callers; transport_to_demux() maps a
	 * pointer to it back to this enclosing struct */
	struct pldm_transport transport;
	/* Socket fd used for all I/O; closed by
	 * pldm_transport_mctp_demux_destroy() */
	int socket;
	/* In the future this probably needs to move to a tid-eid-uuid/network
	 * id mapping for multi mctp networks */
	/* Indexed by EID, each entry holds the TID mapped to that EID.
	 * Entries start at 0 (calloc) and unmap_tid() writes 0 to clear one. */
	pldm_tid_t tid_eid_map[MCTP_MAX_NUM_EID];
	/* Send-buffer bookkeeping for the socket; initialised at creation and
	 * consulted before each send (presumably to grow SO_SNDBUF for large
	 * messages -- confirm against socket.c) */
	struct pldm_socket_sndbuf socket_send_buf;
};

/* Recover the enclosing demux context from its embedded pldm_transport */
#define transport_to_demux(ptr)                                                \
	container_of(ptr, struct pldm_transport_mctp_demux, transport)
33 
34 LIBPLDM_ABI_TESTING
35 struct pldm_transport *
36 pldm_transport_mctp_demux_core(struct pldm_transport_mctp_demux *ctx)
37 {
38 	return &ctx->transport;
39 }
40 
41 static pldm_requester_rc_t pldm_transport_mctp_demux_open(void)
42 {
43 	int fd = -1;
44 	ssize_t rc = -1;
45 
46 	fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
47 	if (fd == -1) {
48 		return fd;
49 	}
50 
51 	const char path[] = "\0mctp-mux";
52 	struct sockaddr_un addr;
53 	addr.sun_family = AF_UNIX;
54 	memcpy(addr.sun_path, path, sizeof(path) - 1);
55 	rc = connect(fd, (struct sockaddr *)&addr,
56 		     sizeof(path) + sizeof(addr.sun_family) - 1);
57 	if (rc == -1) {
58 		return PLDM_REQUESTER_OPEN_FAIL;
59 	}
60 	rc = write(fd, &mctp_msg_type, sizeof(mctp_msg_type));
61 	if (rc == -1) {
62 		return PLDM_REQUESTER_OPEN_FAIL;
63 	}
64 
65 	return fd;
66 }
67 
68 LIBPLDM_ABI_TESTING
69 int pldm_transport_mctp_demux_init_pollfd(struct pldm_transport *t,
70 					  struct pollfd *pollfd)
71 {
72 	struct pldm_transport_mctp_demux *ctx = transport_to_demux(t);
73 	pollfd->fd = ctx->socket;
74 	pollfd->events = POLLIN;
75 	return 0;
76 }
77 
78 static int
79 pldm_transport_mctp_demux_get_eid(struct pldm_transport_mctp_demux *ctx,
80 				  pldm_tid_t tid, mctp_eid_t *eid)
81 {
82 	int i;
83 	for (i = 0; i < MCTP_MAX_NUM_EID; i++) {
84 		if (ctx->tid_eid_map[i] == tid) {
85 			*eid = i;
86 			return 0;
87 		}
88 	}
89 	*eid = -1;
90 	return -1;
91 }
92 
93 LIBPLDM_ABI_TESTING
94 int pldm_transport_mctp_demux_map_tid(struct pldm_transport_mctp_demux *ctx,
95 				      pldm_tid_t tid, mctp_eid_t eid)
96 {
97 	ctx->tid_eid_map[eid] = tid;
98 
99 	return 0;
100 }
101 
102 LIBPLDM_ABI_TESTING
103 int pldm_transport_mctp_demux_unmap_tid(struct pldm_transport_mctp_demux *ctx,
104 					__attribute__((unused)) pldm_tid_t tid,
105 					mctp_eid_t eid)
106 {
107 	ctx->tid_eid_map[eid] = 0;
108 
109 	return 0;
110 }
111 
112 static pldm_requester_rc_t
113 pldm_transport_mctp_demux_recv(struct pldm_transport *t, pldm_tid_t tid,
114 			       void **pldm_resp_msg, size_t *resp_msg_len)
115 {
116 	struct pldm_transport_mctp_demux *demux = transport_to_demux(t);
117 	mctp_eid_t eid = 0;
118 	int rc = pldm_transport_mctp_demux_get_eid(demux, tid, &eid);
119 	if (rc) {
120 		return PLDM_REQUESTER_RECV_FAIL;
121 	}
122 
123 	ssize_t min_len = sizeof(eid) + sizeof(mctp_msg_type) +
124 			  sizeof(struct pldm_msg_hdr);
125 	ssize_t length = recv(demux->socket, NULL, 0, MSG_PEEK | MSG_TRUNC);
126 	if (length <= 0) {
127 		return PLDM_REQUESTER_RECV_FAIL;
128 	}
129 	uint8_t *buf = malloc(length);
130 	if (buf == NULL) {
131 		return PLDM_REQUESTER_RECV_FAIL;
132 	}
133 	if (length < min_len) {
134 		/* read and discard */
135 		recv(demux->socket, buf, length, 0);
136 		free(buf);
137 		return PLDM_REQUESTER_INVALID_RECV_LEN;
138 	}
139 	struct iovec iov[2];
140 	uint8_t mctp_prefix[2];
141 	size_t mctp_prefix_len = 2;
142 	size_t pldm_len = length - mctp_prefix_len;
143 	iov[0].iov_len = mctp_prefix_len;
144 	iov[0].iov_base = mctp_prefix;
145 	iov[1].iov_len = pldm_len;
146 	iov[1].iov_base = buf;
147 	struct msghdr msg = { 0 };
148 	msg.msg_iov = iov;
149 	msg.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
150 	ssize_t bytes = recvmsg(demux->socket, &msg, 0);
151 	if (length != bytes) {
152 		free(buf);
153 		return PLDM_REQUESTER_INVALID_RECV_LEN;
154 	}
155 	if ((mctp_prefix[0] != eid) || (mctp_prefix[1] != mctp_msg_type)) {
156 		free(buf);
157 		return PLDM_REQUESTER_NOT_PLDM_MSG;
158 	}
159 	*pldm_resp_msg = buf;
160 	*resp_msg_len = pldm_len;
161 	return PLDM_REQUESTER_SUCCESS;
162 }
163 
164 static pldm_requester_rc_t
165 pldm_transport_mctp_demux_send(struct pldm_transport *t, pldm_tid_t tid,
166 			       const void *pldm_req_msg, size_t req_msg_len)
167 {
168 	struct pldm_transport_mctp_demux *demux = transport_to_demux(t);
169 	mctp_eid_t eid = 0;
170 	if (pldm_transport_mctp_demux_get_eid(demux, tid, &eid)) {
171 		return PLDM_REQUESTER_SEND_FAIL;
172 	}
173 
174 	uint8_t hdr[2] = { eid, mctp_msg_type };
175 
176 	struct iovec iov[2];
177 	iov[0].iov_base = hdr;
178 	iov[0].iov_len = sizeof(hdr);
179 	iov[1].iov_base = (uint8_t *)pldm_req_msg;
180 	iov[1].iov_len = req_msg_len;
181 
182 	struct msghdr msg = { 0 };
183 	msg.msg_iov = iov;
184 	msg.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
185 
186 	if (req_msg_len > INT_MAX ||
187 	    pldm_socket_sndbuf_accomodate(&(demux->socket_send_buf),
188 					  (int)req_msg_len)) {
189 		return PLDM_REQUESTER_SEND_FAIL;
190 	}
191 
192 	ssize_t rc = sendmsg(demux->socket, &msg, 0);
193 	if (rc == -1) {
194 		return PLDM_REQUESTER_SEND_FAIL;
195 	}
196 	return PLDM_REQUESTER_SUCCESS;
197 }
198 
199 LIBPLDM_ABI_TESTING
200 int pldm_transport_mctp_demux_init(struct pldm_transport_mctp_demux **ctx)
201 {
202 	if (!ctx || *ctx) {
203 		return -EINVAL;
204 	}
205 
206 	struct pldm_transport_mctp_demux *demux =
207 		calloc(1, sizeof(struct pldm_transport_mctp_demux));
208 	if (!demux) {
209 		return -ENOMEM;
210 	}
211 
212 	demux->transport.name = MCTP_DEMUX_NAME;
213 	demux->transport.version = 1;
214 	demux->transport.recv = pldm_transport_mctp_demux_recv;
215 	demux->transport.send = pldm_transport_mctp_demux_send;
216 	demux->transport.init_pollfd = pldm_transport_mctp_demux_init_pollfd;
217 	demux->socket = pldm_transport_mctp_demux_open();
218 	if (demux->socket == -1) {
219 		free(demux);
220 		return -1;
221 	}
222 
223 	if (pldm_socket_sndbuf_init(&demux->socket_send_buf, demux->socket)) {
224 		close(demux->socket);
225 		free(demux);
226 		return -1;
227 	}
228 
229 	*ctx = demux;
230 	return 0;
231 }
232 
233 LIBPLDM_ABI_TESTING
234 void pldm_transport_mctp_demux_destroy(struct pldm_transport_mctp_demux *ctx)
235 {
236 	if (!ctx) {
237 		return;
238 	}
239 	close(ctx->socket);
240 	free(ctx);
241 }
242 
243 /* Temporary for old API */
244 LIBPLDM_ABI_TESTING
245 struct pldm_transport_mctp_demux *
246 pldm_transport_mctp_demux_init_with_fd(int mctp_fd)
247 {
248 	struct pldm_transport_mctp_demux *demux =
249 		calloc(1, sizeof(struct pldm_transport_mctp_demux));
250 	if (!demux) {
251 		return NULL;
252 	}
253 
254 	demux->transport.name = MCTP_DEMUX_NAME;
255 	demux->transport.version = 1;
256 	demux->transport.recv = pldm_transport_mctp_demux_recv;
257 	demux->transport.send = pldm_transport_mctp_demux_send;
258 	demux->transport.init_pollfd = pldm_transport_mctp_demux_init_pollfd;
259 	/* dup is so we can call pldm_transport_mctp_demux_destroy which closes
260 	 * the socket, without closing the fd that is being used by the consumer
261 	 */
262 	demux->socket = dup(mctp_fd);
263 	if (demux->socket == -1) {
264 		free(demux);
265 		return NULL;
266 	}
267 
268 	if (pldm_socket_sndbuf_init(&demux->socket_send_buf, demux->socket)) {
269 		close(demux->socket);
270 		free(demux);
271 		return NULL;
272 	}
273 
274 	return demux;
275 }
276 
277 LIBPLDM_ABI_TESTING
278 int pldm_transport_mctp_demux_get_socket_fd(
279 	struct pldm_transport_mctp_demux *ctx)
280 {
281 	if (ctx) {
282 		return ctx->socket;
283 	}
284 
285 	return -1;
286 }
287