1 /* SPDX-License-Identifier: BSD-3-Clause */ 2 /* 3 * QEMU SPDM socket support 4 * 5 * This is based on: 6 * https://github.com/DMTF/spdm-emu/blob/07c0a838bcc1c6207c656ac75885c0603e344b6f/spdm_emu/spdm_emu_common/command.c 7 * but has been re-written to match QEMU style 8 * 9 * Copyright (c) 2021, DMTF. All rights reserved. 10 * Copyright (c) 2023. Western Digital Corporation or its affiliates. 11 */ 12 13 #include "qemu/osdep.h" 14 #include "sysemu/spdm-socket.h" 15 #include "qapi/error.h" 16 17 static bool read_bytes(const int socket, uint8_t *buffer, 18 size_t number_of_bytes) 19 { 20 ssize_t number_received = 0; 21 ssize_t result; 22 23 while (number_received < number_of_bytes) { 24 result = recv(socket, buffer + number_received, 25 number_of_bytes - number_received, 0); 26 if (result <= 0) { 27 return false; 28 } 29 number_received += result; 30 } 31 return true; 32 } 33 34 static bool read_data32(const int socket, uint32_t *data) 35 { 36 bool result; 37 38 result = read_bytes(socket, (uint8_t *)data, sizeof(uint32_t)); 39 if (!result) { 40 return result; 41 } 42 *data = ntohl(*data); 43 return true; 44 } 45 46 static bool read_multiple_bytes(const int socket, uint8_t *buffer, 47 uint32_t *bytes_received, 48 uint32_t max_buffer_length) 49 { 50 uint32_t length; 51 bool result; 52 53 result = read_data32(socket, &length); 54 if (!result) { 55 return result; 56 } 57 58 if (length > max_buffer_length) { 59 return false; 60 } 61 62 if (bytes_received) { 63 *bytes_received = length; 64 } 65 66 if (length == 0) { 67 return true; 68 } 69 70 return read_bytes(socket, buffer, length); 71 } 72 73 static bool receive_platform_data(const int socket, 74 uint32_t transport_type, 75 uint32_t *command, 76 uint8_t *receive_buffer, 77 uint32_t *bytes_to_receive) 78 { 79 bool result; 80 uint32_t response; 81 uint32_t bytes_received; 82 83 result = read_data32(socket, &response); 84 if (!result) { 85 return result; 86 } 87 *command = response; 88 89 result = read_data32(socket, &transport_type); 90 if (!result) { 91 return result; 92 } 93 94 bytes_received = 0; 95 result = read_multiple_bytes(socket, receive_buffer, &bytes_received, 96 *bytes_to_receive); 97 if (!result) { 98 return result; 99 } 100 *bytes_to_receive = bytes_received; 101 102 return result; 103 } 104 105 static bool write_bytes(const int socket, const uint8_t *buffer, 106 uint32_t number_of_bytes) 107 { 108 ssize_t number_sent = 0; 109 ssize_t result; 110 111 while (number_sent < number_of_bytes) { 112 result = send(socket, buffer + number_sent, 113 number_of_bytes - number_sent, 0); 114 if (result == -1) { 115 return false; 116 } 117 number_sent += result; 118 } 119 return true; 120 } 121 122 static bool write_data32(const int socket, uint32_t data) 123 { 124 data = htonl(data); 125 return write_bytes(socket, (uint8_t *)&data, sizeof(uint32_t)); 126 } 127 128 static bool write_multiple_bytes(const int socket, const uint8_t *buffer, 129 uint32_t bytes_to_send) 130 { 131 bool result; 132 133 result = write_data32(socket, bytes_to_send); 134 if (!result) { 135 return result; 136 } 137 138 return write_bytes(socket, buffer, bytes_to_send); 139 } 140 141 static bool send_platform_data(const int socket, 142 uint32_t transport_type, uint32_t command, 143 const uint8_t *send_buffer, size_t bytes_to_send) 144 { 145 bool result; 146 147 result = write_data32(socket, command); 148 if (!result) { 149 return result; 150 } 151 152 result = write_data32(socket, transport_type); 153 if (!result) { 154 return result; 155 } 156 157 return write_multiple_bytes(socket, send_buffer, bytes_to_send); 158 } 159 160 int spdm_socket_connect(uint16_t port, Error **errp) 161 { 162 int client_socket; 163 struct sockaddr_in server_addr; 164 165 client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); 166 if (client_socket < 0) { 167 error_setg(errp, "cannot create socket: %s", strerror(errno)); 168 return -1; 169 } 170 171 memset((char *)&server_addr, 0, sizeof(server_addr)); 172 server_addr.sin_family = AF_INET; 173 server_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 174 server_addr.sin_port = htons(port); 175 176 177 if (connect(client_socket, (struct sockaddr *)&server_addr, 178 sizeof(server_addr)) < 0) { 179 error_setg(errp, "cannot connect: %s", strerror(errno)); 180 close(client_socket); 181 return -1; 182 } 183 184 return client_socket; 185 } 186 187 uint32_t spdm_socket_rsp(const int socket, uint32_t transport_type, 188 void *req, uint32_t req_len, 189 void *rsp, uint32_t rsp_len) 190 { 191 uint32_t command; 192 bool result; 193 194 result = send_platform_data(socket, transport_type, 195 SPDM_SOCKET_COMMAND_NORMAL, 196 req, req_len); 197 if (!result) { 198 return 0; 199 } 200 201 result = receive_platform_data(socket, transport_type, &command, 202 (uint8_t *)rsp, &rsp_len); 203 if (!result) { 204 return 0; 205 } 206 207 assert(command != 0); 208 209 return rsp_len; 210 } 211 212 void spdm_socket_close(const int socket, uint32_t transport_type) 213 { 214 send_platform_data(socket, transport_type, 215 SPDM_SOCKET_COMMAND_SHUTDOWN, NULL, 0); 216 } 217