1 /*
2 * SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 #include "MctpRequester.hpp"
7
8 #include "MctpAsioEndpoint.hpp"
9
10 #include <sys/socket.h>
11
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <boost/container/devector.hpp>
19 #include <phosphor-logging/lg2.hpp>
20
21 #include <bit>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <expected>
26 #include <format>
27 #include <functional>
28 #include <optional>
29 #include <span>
30 #include <stdexcept>
31 #include <system_error>
32 #include <utility>
33
34 using namespace std::literals;
35
36 namespace mctp
37 {
38
getHeaderFromBuffer(std::span<const uint8_t> buffer)39 static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
40 std::span<const uint8_t> buffer)
41 {
42 if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
43 {
44 return nullptr;
45 }
46
47 return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
48 buffer.data());
49 }
50
getIid(std::span<const uint8_t> buffer)51 static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
52 {
53 const ocp::accelerator_management::BindingPciVid* header =
54 getHeaderFromBuffer(buffer);
55 if (header == nullptr)
56 {
57 return std::nullopt;
58 }
59 return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
60 }
61
getRequestBit(std::span<const uint8_t> buffer)62 static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
63 {
64 const ocp::accelerator_management::BindingPciVid* header =
65 getHeaderFromBuffer(buffer);
66 if (header == nullptr)
67 {
68 return std::nullopt;
69 }
70 return header->instance_id & ocp::accelerator_management::requestBitMask;
71 }
72
MctpRequester(boost::asio::io_context & ctx)73 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
74 io{ctx},
75 mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
76 {
77 startReceive();
78 }
79
startReceive()80 void MctpRequester::startReceive()
81 {
82 mctpSocket.async_receive_from(
83 boost::asio::buffer(buffer), recvEndPoint.endpoint,
84 std::bind_front(&MctpRequester::processRecvMsg, this));
85 }
86
// Completion handler for the socket's single outstanding async_receive_from.
//
// Validates the sender endpoint and message type, decodes the OCP
// accelerator-management header, matches the datagram against the request at
// the head of the sender's queue (by instance id) and hands it to
// handleResult(). Datagrams that cannot be matched are logged and dropped.
//
// Receive-loop invariant: exactly one receive is outstanding. It is armed by
// the constructor and re-armed here after the datagram has been processed, on
// every path except
//   - operation_aborted (socket closed/cancelled), and
//   - an exception thrown out of handleResult() for a fatal error.
// handleResult() never re-arms, so timeouts and send failures do not stack
// extra receives into the shared buffer.
void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
                                   const size_t length)
{
    if (ec == boost::asio::error::operation_aborted)
    {
        // The socket was closed or its operations were cancelled (e.g. while
        // the requester is being torn down). Touch no state and do not re-arm.
        return;
    }

    // Processes the datagram held in `buffer`. Every early return drops the
    // datagram; the receive is re-armed by the caller afterwards.
    auto handleDatagram = [this, &ec, length]() {
        std::optional<uint8_t> expectedEid = recvEndPoint.eid();
        std::optional<uint8_t> receivedMsgType = recvEndPoint.type();

        if (!expectedEid || !receivedMsgType)
        {
            // We were handed an endpoint that can't be treated as an MCTP
            // endpoint. This is probably a kernel bug...yell about it.
            lg2::error("MctpRequester: invalid endpoint");
            return;
        }

        if (*receivedMsgType != ocp::accelerator_management::messageType)
        {
            // We received a message that this handler doesn't support;
            // drop it on the floor.
            lg2::error(
                "MctpRequester: Message type mismatch. We received {MSG}",
                "MSG", *receivedMsgType);
            return;
        }

        uint8_t eid = *expectedEid;

        if (ec)
        {
            lg2::error(
                "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
                "EC", ec.value(), "ER", ec.message());
            handleResult(eid, static_cast<std::error_code>(ec), {});
            return;
        }

        // NOTE(review): datagram receives larger than `buffer` may be
        // truncated silently by the kernel; confirm whether asio reports
        // truncation as an error on this platform.
        std::span<const uint8_t> responseBuffer{buffer.data(), length};

        std::optional<uint8_t> optionalIid = getIid(responseBuffer);
        std::optional<bool> isRq = getRequestBit(responseBuffer);
        if (!optionalIid || !isRq)
        {
            // We received something from the device but can't parse the iid
            // byte. Drop it and rely on the timer to notify the client.
            lg2::error("MctpRequester: Unable to decode message from eid {EID}",
                       "EID", eid);
            return;
        }

        if (*isRq)
        {
            // A request from a downstream device. Not supported: drop it and
            // keep the timer running.
            return;
        }

        uint8_t iid = *optionalIid;

        auto it = requestContextQueues.find(eid);
        if (it == requestContextQueues.end())
        {
            // A response from a device we've never talked to.
            lg2::error("Unable to match request to response");
            return;
        }

        if (it->second.queue.empty())
        {
            // Nothing is outstanding for this eid, e.g. a late response to a
            // request that already timed out. Its iid can still match, since
            // the iid only advances when the next request is sent.
            lg2::error(
                "MctpRequester: response from eid {EID} with no outstanding request",
                "EID", eid);
            return;
        }

        if (iid != it->second.iid)
        {
            // Not the iid we sent; drop this packet on the floor.
            lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}",
                       "IID", iid, "EID", eid, "E_IID", it->second.iid);
            return;
        }

        handleResult(eid, std::error_code{}, responseBuffer);
    };

    handleDatagram();

    // Re-arm only now. The span handed to a callback above points into
    // `buffer`, and the next receive writes into that same buffer.
    startReceive();
}

// Completion handler for async_send_to of the request at the head of `eid`'s
// queue. On failure the request is completed with the send error. On success
// the per-eid expiry timer is started (2s); if it elapses before a matching
// response arrives, the request is completed with std::errc::timed_out.
void MctpRequester::handleSendMsgCompletion(
    uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        handleResult(eid, static_cast<std::error_code>(ec), {});
        return;
    }

    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        // We've sent something to a device that we have no record of.
        // Yell loudly and bail.
        lg2::error(
            "MctpRequester completed send for an EID that we have no record of");
        return;
    }

    boost::asio::steady_timer& expiryTimer = it->second.timer;
    expiryTimer.expires_after(2s);

    expiryTimer.async_wait([this, eid](const boost::system::error_code& waitEc) {
        if (waitEc == boost::asio::error::operation_aborted)
        {
            return;
        }

        auto timerIt = requestContextQueues.find(eid);
        if (timerIt == requestContextQueues.end())
        {
            return;
        }

        // cancel()/expires_*() cannot recall a completion that asio has
        // already queued. handleResult() pushes the expiry out to
        // time_point::max(), and the next send re-arms it into the future.
        // An expiry that still lies ahead therefore means this wakeup is
        // stale: it belongs to a request that has already completed.
        if (timerIt->second.timer.expiry() >
            boost::asio::steady_timer::clock_type::now())
        {
            return;
        }

        lg2::error("Operation timed out on eid {EID}", "EID", eid);
        handleResult(eid, std::make_error_code(std::errc::timed_out), {});
    });
}

// Queues `reqMsg` for `eid`. Requests to the same eid are serialized: only
// the head of the queue is in flight, and the next one is sent once the head
// completes. `callback` receives either the response bytes or an error. The
// response span points into the receive buffer and must not be retained past
// the call.
void MctpRequester::sendRecvMsg(
    uint8_t eid, std::span<const uint8_t> reqMsg,
    std::move_only_function<void(const std::error_code&,
                                 std::span<const uint8_t>)>
        callback)
{
    RequestContext reqCtx{reqMsg, std::move(callback)};

    // try_emplace only constructs a new per-eid context (from io) when the
    // eid has not been seen before.
    auto& queue = requestContextQueues.try_emplace(eid, io).first->second.queue;
    queue.push_back(std::move(reqCtx));

    // If other requests were already queued, the one in flight will start
    // this one when it completes.
    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

// Errors other than a timeout or an unreachable host are treated as
// unrecoverable (see handleResult).
static bool isFatalError(const std::error_code& ec)
{
    return ec &&
           (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
}

// Completes the request at the head of `eid`'s queue with `ec`/`buffer`, then
// starts the next queued request for that eid.
//
// The callback is moved out before it is invoked. It may call sendRecvMsg()
// for the same eid, and if the queue is a contiguous container (this file
// includes boost::container::devector) that push_back could relocate the
// callable while it runs. The entry stays at the head of the queue until the
// callback returns, so a reentrant sendRecvMsg() only enqueues
// (size() > 1) instead of sending immediately.
//
// Throws std::runtime_error, after notifying the callback, for errors that
// isFatalError() classifies as fatal. Never re-arms the socket receive;
// processRecvMsg() owns that.
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // e.g. a stale timer wakeup or a receive error for an idle eid.
        lg2::error("MctpRequester: no outstanding request for eid {EID}",
                   "EID", eid);
        return;
    }

    // expires_at() cancels any pending wait and moves the expiry out of reach,
    // so an expiry handler that was already queued recognizes itself as stale.
    it->second.timer.expires_at(boost::asio::steady_timer::time_point::max());

    auto callback = std::move(queue.front().callback);
    callback(ec, buffer); // Call the original callback

    if (isFatalError(ec))
    {
        // Some errors are fatal. Since these are datagrams, we won't get a
        // receive-path error message, and since this daemon services all
        // nvidia iana commands for a given system, we should only restart the
        // service if it's unrecoverable, i.e. if we get error codes that the
        // client can't reasonably deal with. If that's the case, restart and
        // hope that we can deal with it then. Since we're fully async, the
        // only reasonable way to bubble this up is to throw and let main deal
        // with it. Alternatively we could call cancel on the io_context, but
        // there's not a great way to figure out *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    // Look the eid up again: the callback may have inserted into
    // requestContextQueues, which can invalidate `it` depending on the map type.
    auto after = requestContextQueues.find(eid);
    if (after != requestContextQueues.end() && !after->second.queue.empty())
    {
        after->second.queue.pop_front();
    }

    processQueue(eid);
}
275
getNextIid(uint8_t eid)276 std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
277 {
278 auto it = requestContextQueues.find(eid);
279 if (it == requestContextQueues.end())
280 {
281 return std::nullopt;
282 }
283
284 uint8_t& iid = it->second.iid;
285 ++iid;
286 iid &= ocp::accelerator_management::instanceIdBitMask;
287 return iid;
288 }
289
// Overwrites the instance-id bits of the header at the start of `buffer` with
// `iid`, leaving the other bits of instance_id untouched.
// Fails with std::errc::invalid_argument if `buffer` is shorter than a header
// or `iid` does not fit within instanceIdBitMask.
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    namespace am = ocp::accelerator_management;

    const bool tooShort = buffer.size() < sizeof(am::BindingPciVid);
    const bool iidTooLarge = iid > am::instanceIdBitMask;
    if (tooShort || iidTooLarge)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* hdr = std::bit_cast<am::BindingPciVid*>(buffer.data());

    // Clear the previous iid bits, then set the new ones.
    hdr->instance_id &= ~am::instanceIdBitMask;
    hdr->instance_id |= iid;
    return {};
}
312
// Sends the request at the head of `eid`'s queue, if any. It stamps the next
// instance id into the request header, then posts async_send_to; completion
// goes to handleSendMsgCompletion. Failures to assign an iid are reported to
// the request through handleResult().
void MctpRequester::processQueue(uint8_t eid)
{
    const auto queueIt = requestContextQueues.find(eid);
    if (queueIt == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = queueIt->second.queue;
    if (pending.empty())
    {
        return;
    }

    auto& head = pending.front();
    std::span<uint8_t> request{head.reqMsg.data(), head.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (auto stamped = injectIid(request, *nextIid); !stamped)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, stamped.error(), {});
        return;
    }

    MctpAsioEndpoint destination(eid, ocp::accelerator_management::messageType);
    mctpSocket.async_send_to(
        boost::asio::const_buffer(request.data(), request.size()),
        destination.endpoint,
        [this, eid](const boost::system::error_code& ec, size_t length) {
            handleSendMsgCompletion(eid, ec, length);
        });
}
355
356 } // namespace mctp
357