1 /*
2 * SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 #include "MctpRequester.hpp"
7
8 #include <linux/mctp.h>
9 #include <sys/socket.h>
10
11 #include <OcpMctpVdm.hpp>
12 #include <boost/asio/buffer.hpp>
13 #include <boost/asio/error.hpp>
14 #include <boost/asio/generic/datagram_protocol.hpp>
15 #include <boost/asio/io_context.hpp>
16 #include <boost/asio/steady_timer.hpp>
17 #include <boost/container/devector.hpp>
18 #include <phosphor-logging/lg2.hpp>
19
20 #include <bit>
21 #include <cstddef>
22 #include <cstdint>
23 #include <cstring>
24 #include <expected>
25 #include <format>
26 #include <functional>
27 #include <optional>
28 #include <span>
29 #include <stdexcept>
30 #include <system_error>
31 #include <utility>
32
33 using namespace std::literals;
34
35 namespace mctp
36 {
37
getHeaderFromBuffer(std::span<const uint8_t> buffer)38 static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
39 std::span<const uint8_t> buffer)
40 {
41 if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
42 {
43 return nullptr;
44 }
45
46 return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
47 buffer.data());
48 }
49
getIid(std::span<const uint8_t> buffer)50 static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
51 {
52 const ocp::accelerator_management::BindingPciVid* header =
53 getHeaderFromBuffer(buffer);
54 if (header == nullptr)
55 {
56 return std::nullopt;
57 }
58 return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
59 }
60
getRequestBit(std::span<const uint8_t> buffer)61 static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
62 {
63 const ocp::accelerator_management::BindingPciVid* header =
64 getHeaderFromBuffer(buffer);
65 if (header == nullptr)
66 {
67 return std::nullopt;
68 }
69 return header->instance_id & ocp::accelerator_management::requestBitMask;
70 }
71
MctpRequester(boost::asio::io_context & ctx)72 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
73 io{ctx},
74 mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
75 {
76 startReceive();
77 }
78
startReceive()79 void MctpRequester::startReceive()
80 {
81 mctpSocket.async_receive_from(
82 boost::asio::buffer(buffer), recvEndPoint.endpoint,
83 std::bind_front(&MctpRequester::processRecvMsg, this));
84 }
85
/**
 * @brief Decide whether an error should tear the requester down.
 *
 * Timeouts and unreachable hosts are reported to the caller and the requester
 * keeps running. Any other non-zero error code is treated as fatal.
 */
static bool isFatalError(const std::error_code& ec)
{
    return ec &&
           (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
}

/**
 * @brief Completion handler for the socket's single outstanding receive.
 *
 * Receive ownership: only this function (and the constructor, through
 * startReceive()) arms async_receive_from, so exactly one receive is
 * outstanding at any time. Every path that does not throw re-arms exactly
 * once, and only *after* the datagram in `buffer` has been consumed. The
 * ordering matters because asio may perform a freshly started receive
 * speculatively and overwrite `buffer`.
 *
 * Datagrams that cannot be matched to an outstanding request are dropped.
 * The request's expiry timer still guarantees the client gets notified.
 *
 * @param ec     Result of the receive.
 * @param length Number of bytes placed in `buffer`.
 * @throws std::runtime_error on a fatal receive error (see isFatalError()).
 */
void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
                                   const size_t length)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        const std::error_code error = static_cast<std::error_code>(ec);

        // NOTE(review): after a failed receive the endpoint probably still
        // describes the previous datagram, so this attribution is only
        // best-effort. It is kept because the original code did the same.
        if (std::optional<uint8_t> eid = recvEndPoint.eid())
        {
            // Throws for fatal errors when a request is outstanding for eid.
            handleResult(*eid, error, {});
        }

        if (isFatalError(error))
        {
            // No outstanding request absorbed the error. Do not keep
            // re-arming a receive on a broken socket.
            throw std::runtime_error(std::format(
                "MctpRequester receive failed: {}", error.message()));
        }

        startReceive();
        return;
    }

    // NOTE: for datagram sockets recvfrom() silently truncates oversized
    // messages, and asio does not report that as an error.
    const std::span<const uint8_t> responseBuffer{buffer.data(), length};

    // Returns the eid whose head request this datagram answers, or
    // std::nullopt if the datagram must be dropped.
    auto matchResponse = [this,
                          responseBuffer]() -> std::optional<uint8_t> {
        std::optional<uint8_t> expectedEid = recvEndPoint.eid();
        std::optional<uint8_t> receivedMsgType = recvEndPoint.type();
        if (!expectedEid || !receivedMsgType)
        {
            // The kernel handed us an endpoint that can't be treated as an
            // MCTP endpoint. Yell about it and drop the datagram.
            lg2::error("MctpRequester: invalid endpoint");
            return std::nullopt;
        }

        if (*receivedMsgType != msgType)
        {
            // A message type this handler doesn't support.
            lg2::error(
                "MctpRequester: Message type mismatch. We received {MSG}",
                "MSG", *receivedMsgType);
            return std::nullopt;
        }

        const uint8_t eid = *expectedEid;

        std::optional<uint8_t> optionalIid = getIid(responseBuffer);
        std::optional<bool> isRq = getRequestBit(responseBuffer);
        if (!optionalIid || !isRq)
        {
            // Too short to parse the IID byte. Rely on the timer to notify
            // the client.
            lg2::error("MctpRequester: Unable to decode message from eid {EID}",
                       "EID", eid);
            return std::nullopt;
        }

        if (*isRq)
        {
            // Requests from downstream devices are not supported. Drop it
            // and keep the timer running.
            return std::nullopt;
        }

        auto it = requestContextQueues.find(eid);
        if (it == requestContextQueues.end())
        {
            // A response from a device we've never talked to.
            lg2::error("Unable to match request to response");
            return std::nullopt;
        }

        if (it->second.queue.empty())
        {
            // A late response (e.g. after a timeout) with nothing outstanding.
            // Without this check handleResult() would have no head request.
            lg2::error("MctpRequester: no outstanding request for eid {EID}",
                       "EID", eid);
            return std::nullopt;
        }

        if (*optionalIid != it->second.iid)
        {
            // Not the IID we sent for the head request.
            lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}",
                       "IID", *optionalIid, "EID", eid, "E_IID",
                       it->second.iid);
            return std::nullopt;
        }

        return eid;
    };

    if (std::optional<uint8_t> eid = matchResponse())
    {
        handleResult(*eid, std::error_code{}, responseBuffer);
    }

    // `buffer` has been consumed (or the datagram dropped). Now it is safe
    // to receive the next datagram into it.
    startReceive();
}

/**
 * @brief Completion handler for async_send_to of @p eid's head request.
 *
 * On a send failure, completes the head request with the error. On success,
 * arms the per-eid expiry timer (2s) that reports std::errc::timed_out if
 * no matching response completes the request first.
 */
void MctpRequester::handleSendMsgCompletion(
    uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        handleResult(eid, static_cast<std::error_code>(ec), {});
        return;
    }

    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        // We sent something to a device we have no record of. Yell loudly
        // and bail.
        lg2::error(
            "MctpRequester completed send for an EID that we have no record of");
        return;
    }

    boost::asio::steady_timer& expiryTimer = it->second.timer;
    expiryTimer.expires_after(2s);

    expiryTimer.async_wait([this, eid](const boost::system::error_code& waitEc) {
        if (waitEc == boost::asio::error::operation_aborted)
        {
            return;
        }

        // A timeout handler that was already queued when the request
        // completed cannot be cancelled. handleResult() parks the timer at
        // time_point::max(), so a future expiry means this timeout is stale
        // and must not fail the next request.
        auto entry = requestContextQueues.find(eid);
        if (entry != requestContextQueues.end() &&
            entry->second.timer.expiry() >
                boost::asio::steady_timer::clock_type::now())
        {
            return;
        }

        lg2::error("Operation timed out on eid {EID}", "EID", eid);
        handleResult(eid, std::make_error_code(std::errc::timed_out), {});
    });
}

/**
 * @brief Queue @p reqMsg for @p eid; @p callback receives the response or an
 *        error.
 *
 * Requests to the same eid are serialized. The queue is only kicked when it
 * was idle, because handleResult() drains it otherwise.
 */
void MctpRequester::sendRecvMsg(
    uint8_t eid, std::span<const uint8_t> reqMsg,
    std::move_only_function<void(const std::error_code&,
                                 std::span<const uint8_t>)>
        callback)
{
    RequestContext reqCtx{reqMsg, std::move(callback)};

    // try_emplace only constructs the per-eid state if it does not exist yet.
    auto& queue = requestContextQueues.try_emplace(eid, io).first->second.queue;
    queue.push_back(std::move(reqCtx));

    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

/**
 * @brief Complete the head request of @p eid's queue and start the next one.
 *
 * Does not touch the receive path; processRecvMsg() owns re-arming it.
 *
 * @param eid    Endpoint whose head request is being completed.
 * @param ec     Result to report; empty on success.
 * @param buffer Response bytes (empty on error). On the success path this
 *               aliases the receive buffer, so it is only valid during the
 *               callback.
 * @throws std::runtime_error if @p ec is fatal (see isFatalError()), after
 *         the callback has run.
 */
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // E.g. a stale timeout or receive error with nothing outstanding.
        lg2::error("MctpRequester: no outstanding request for eid {EID}",
                   "EID", eid);
        return;
    }

    // Park the timer far in the future. This cancels a pending wait and lets
    // an already-queued timeout handler recognise that it is stale.
    it->second.timer.expires_at(boost::asio::steady_timer::time_point::max());

    // Move the callback out before invoking it. The callback may call
    // sendRecvMsg() for this eid, and the resulting push_back can reallocate
    // the devector and move the element whose callback is running. The
    // (now empty) head stays queued until pop_front so that a re-entrant
    // sendRecvMsg() does not see an idle queue and send out of order.
    auto callback = std::move(queue.front().callback);
    callback(ec, buffer);

    if (isFatalError(ec))
    {
        // Some errors are fatal. Since these are datagrams, there is no
        // receive-path error to rely on, and since this daemon services all
        // nvidia iana commands for the system, restarting the service is the
        // only reasonable recovery for errors the client can't deal with.
        // Being fully async, the way to bubble this up is to throw and let
        // main deal with it.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    // Only `queue` (a reference to the mapped value) is used after the
    // callback; `it` may have been invalidated by a re-entrant try_emplace.
    queue.pop_front();

    processQueue(eid);
}
274
getNextIid(uint8_t eid)275 std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
276 {
277 auto it = requestContextQueues.find(eid);
278 if (it == requestContextQueues.end())
279 {
280 return std::nullopt;
281 }
282
283 uint8_t& iid = it->second.iid;
284 ++iid;
285 iid &= ocp::accelerator_management::instanceIdBitMask;
286 return iid;
287 }
288
/**
 * @brief Overwrite the IID bits of the header at the start of @p buffer.
 *
 * Bits of instance_id outside instanceIdBitMask are preserved.
 *
 * @param buffer Outgoing message. It must be at least one header long.
 * @param iid    New instance id. It must fit in instanceIdBitMask.
 * @return Empty on success. std::errc::invalid_argument if @p buffer is too
 *         short or @p iid is out of range.
 */
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    namespace am = ocp::accelerator_management;

    const bool tooShort = buffer.size() < sizeof(am::BindingPciVid);
    const bool iidTooLarge = iid > am::instanceIdBitMask;
    if (tooShort || iidTooLarge)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* header = std::bit_cast<am::BindingPciVid*>(buffer.data());

    // Clear the old IID field, then drop the new one in.
    header->instance_id &= ~am::instanceIdBitMask;
    header->instance_id |= iid;
    return {};
}
311
/**
 * @brief Send the request at the head of @p eid's queue, if any.
 *
 * Stamps the head request with the next IID and issues async_send_to to
 * @p eid with msgType and MCTP_TAG_OWNER. Completion goes to
 * handleSendMsgCompletion(). If an IID cannot be obtained or injected, the
 * head request is completed through handleResult() instead.
 */
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        return;
    }

    auto& head = pending.front();
    std::span<uint8_t> payload{head.reqMsg.data(), head.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (auto injected = injectIid(payload, nextIid.value()); !injected)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    sockaddr_mctp destination{};
    destination.smctp_family = AF_MCTP;
    destination.smctp_addr.s_addr = eid;
    destination.smctp_type = msgType;
    destination.smctp_tag = MCTP_TAG_OWNER;

    const boost::asio::generic::datagram_protocol::endpoint target{
        &destination, sizeof(destination)};

    mctpSocket.async_send_to(
        boost::asio::const_buffer(payload.data(), payload.size()), target,
        [this, eid](const boost::system::error_code& ec, std::size_t length) {
            handleSendMsgCompletion(eid, ec, length);
        });
}
359
360 } // namespace mctp
361