xref: /openbmc/qemu/backends/spdm-socket.c (revision bc419a1c)
1 /* SPDX-License-Identifier: BSD-3-Clause */
2 /*
3  * QEMU SPDM socket support
4  *
5  * This is based on:
6  * https://github.com/DMTF/spdm-emu/blob/07c0a838bcc1c6207c656ac75885c0603e344b6f/spdm_emu/spdm_emu_common/command.c
7  * but has been re-written to match QEMU style
8  *
9  * Copyright (c) 2021, DMTF. All rights reserved.
10  * Copyright (c) 2023. Western Digital Corporation or its affiliates.
11  */
12 
13 #include "qemu/osdep.h"
14 #include "sysemu/spdm-socket.h"
15 #include "qapi/error.h"
16 
/*
 * Read exactly @number_of_bytes bytes from @socket into @buffer.
 *
 * recv() on a stream socket may return fewer bytes than requested, so keep
 * reading until the whole request has been satisfied.
 *
 * Returns true on success; false if the peer closed the connection (recv()
 * returned 0) or recv() failed with an error other than EINTR.
 */
static bool read_bytes(const int socket, uint8_t *buffer,
                       size_t number_of_bytes)
{
    size_t number_received = 0;

    while (number_received < number_of_bytes) {
        ssize_t result = recv(socket, buffer + number_received,
                              number_of_bytes - number_received, 0);
        if (result < 0 && errno == EINTR) {
            /* Interrupted by a signal before any data arrived: retry. */
            continue;
        }
        if (result <= 0) {
            /* 0 is an orderly shutdown by the peer; < 0 is a real error. */
            return false;
        }
        number_received += (size_t)result;
    }
    return true;
}
32 }
33 
/*
 * Read one 32-bit word in network (big-endian) byte order from @socket and
 * store it at @data in host byte order.
 *
 * Returns false if the underlying read fails. In that case *data is not
 * byte-swapped and may hold a partially received value.
 */
static bool read_data32(const int socket, uint32_t *data)
{
    if (!read_bytes(socket, (uint8_t *)data, sizeof(*data))) {
        return false;
    }

    *data = ntohl(*data);
    return true;
}
44 }
45 
/*
 * Read a length-prefixed payload from @socket.
 *
 * Wire format: a 32-bit network-order length, followed by that many bytes.
 * The payload is stored in @buffer, whose capacity is @max_buffer_length.
 * If @bytes_received is non-NULL, it receives the payload length once the
 * length has been accepted.
 *
 * Returns false if the length cannot be read, exceeds @max_buffer_length,
 * or the payload cannot be read in full. When the length is too large, the
 * payload bytes are left unread on the socket.
 */
static bool read_multiple_bytes(const int socket, uint8_t *buffer,
                                uint32_t *bytes_received,
                                uint32_t max_buffer_length)
{
    uint32_t payload_len;

    if (!read_data32(socket, &payload_len) ||
        payload_len > max_buffer_length) {
        return false;
    }

    if (bytes_received != NULL) {
        *bytes_received = payload_len;
    }

    /* An empty payload needs no further I/O. */
    return payload_len == 0 || read_bytes(socket, buffer, payload_len);
}
71 }
72 
/*
 * Receive one framed message from the SPDM responder on @socket.
 *
 * Wire format (all integers are 32-bit, network byte order):
 *   command | transport type | payload length | payload bytes
 *
 * @transport_type: transport type the reply is expected to carry
 * @command: receives the command word sent by the peer
 * @receive_buffer: destination for the payload
 * @bytes_to_receive: in: capacity of @receive_buffer;
 *                    out (on success): number of payload bytes received
 *
 * Returns false on I/O failure, if the payload is larger than the buffer, or
 * if the peer replied with a transport type other than @transport_type.
 */
static bool receive_platform_data(const int socket,
                                  uint32_t transport_type,
                                  uint32_t *command,
                                  uint8_t *receive_buffer,
                                  uint32_t *bytes_to_receive)
{
    uint32_t response;
    uint32_t received_transport_type;
    uint32_t bytes_received = 0;

    if (!read_data32(socket, &response)) {
        return false;
    }
    *command = response;

    /* Read into a local so the expected value in @transport_type survives. */
    if (!read_data32(socket, &received_transport_type)) {
        return false;
    }

    if (!read_multiple_bytes(socket, receive_buffer, &bytes_received,
                             *bytes_to_receive)) {
        return false;
    }

    /*
     * Check the transport type only after the payload has been consumed, so
     * the stream stays aligned on message boundaries for the next exchange.
     */
    if (received_transport_type != transport_type) {
        return false;
    }

    *bytes_to_receive = bytes_received;
    return true;
}
103 }
104 
/*
 * Write exactly @number_of_bytes bytes from @buffer to @socket.
 *
 * send() may accept only part of the buffer, so keep sending the remainder
 * until everything has been written.
 *
 * Returns true on success; false if send() fails with an error other than
 * EINTR, or if it makes no progress on a non-empty request.
 */
static bool write_bytes(const int socket, const uint8_t *buffer,
                        uint32_t number_of_bytes)
{
    size_t number_sent = 0;

    while (number_sent < number_of_bytes) {
        ssize_t result = send(socket, buffer + number_sent,
                              number_of_bytes - number_sent, 0);
        if (result < 0 && errno == EINTR) {
            /* Interrupted by a signal before anything was sent: retry. */
            continue;
        }
        if (result <= 0) {
            /*
             * A real error, or zero bytes accepted for a non-empty request.
             * Retrying the latter could spin forever.
             */
            return false;
        }
        number_sent += (size_t)result;
    }
    return true;
}
120 }
121 
/*
 * Write @data to @socket as one 32-bit word in network (big-endian) byte
 * order. Returns false if the write fails.
 */
static bool write_data32(const int socket, uint32_t data)
{
    const uint32_t wire = htonl(data);

    return write_bytes(socket, (const uint8_t *)&wire, sizeof(wire));
}
126 }
127 
/*
 * Write a length-prefixed payload to @socket: a 32-bit network-order
 * @bytes_to_send, followed by @bytes_to_send bytes from @buffer.
 *
 * Returns false as soon as either write fails. The payload is not attempted
 * if the length could not be written.
 */
static bool write_multiple_bytes(const int socket, const uint8_t *buffer,
                                 uint32_t bytes_to_send)
{
    return write_data32(socket, bytes_to_send) &&
           write_bytes(socket, buffer, bytes_to_send);
}
139 }
140 
/*
 * Send one framed message to the SPDM responder on @socket.
 *
 * Wire format (all integers are 32-bit, network byte order):
 *   command | transport type | payload length | payload bytes
 *
 * The length field is only 32 bits wide. A @bytes_to_send that does not fit
 * is rejected up front, rather than silently truncated into a shorter but
 * well-formed message.
 *
 * Returns false if the payload is too large or any write fails.
 */
static bool send_platform_data(const int socket,
                               uint32_t transport_type, uint32_t command,
                               const uint8_t *send_buffer, size_t bytes_to_send)
{
    const uint32_t payload_len = (uint32_t)bytes_to_send;

    if (payload_len != bytes_to_send) {
        return false;
    }

    return write_data32(socket, command) &&
           write_data32(socket, transport_type) &&
           write_multiple_bytes(socket, send_buffer, payload_len);
}
158 }
159 
spdm_socket_connect(uint16_t port,Error ** errp)160 int spdm_socket_connect(uint16_t port, Error **errp)
161 {
162     int client_socket;
163     struct sockaddr_in server_addr;
164 
165     client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
166     if (client_socket < 0) {
167         error_setg(errp, "cannot create socket: %s", strerror(errno));
168         return -1;
169     }
170 
171     memset((char *)&server_addr, 0, sizeof(server_addr));
172     server_addr.sin_family = AF_INET;
173     server_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
174     server_addr.sin_port = htons(port);
175 
176 
177     if (connect(client_socket, (struct sockaddr *)&server_addr,
178                 sizeof(server_addr)) < 0) {
179         error_setg(errp, "cannot connect: %s", strerror(errno));
180         close(client_socket);
181         return -1;
182     }
183 
184     return client_socket;
185 }
186 
spdm_socket_rsp(const int socket,uint32_t transport_type,void * req,uint32_t req_len,void * rsp,uint32_t rsp_len)187 uint32_t spdm_socket_rsp(const int socket, uint32_t transport_type,
188                          void *req, uint32_t req_len,
189                          void *rsp, uint32_t rsp_len)
190 {
191     uint32_t command;
192     bool result;
193 
194     result = send_platform_data(socket, transport_type,
195                                 SPDM_SOCKET_COMMAND_NORMAL,
196                                 req, req_len);
197     if (!result) {
198         return 0;
199     }
200 
201     result = receive_platform_data(socket, transport_type, &command,
202                                    (uint8_t *)rsp, &rsp_len);
203     if (!result) {
204         return 0;
205     }
206 
207     assert(command != 0);
208 
209     return rsp_len;
210 }
211 
spdm_socket_close(const int socket,uint32_t transport_type)212 void spdm_socket_close(const int socket, uint32_t transport_type)
213 {
214     send_platform_data(socket, transport_type,
215                        SPDM_SOCKET_COMMAND_SHUTDOWN, NULL, 0);
216 }
217