xref: /openbmc/dbus-sensors/src/nvidia-gpu/MctpRequester.cpp (revision 6ef897399d8c2fb9f005b0c0d0ba3f9bd115965f)
1 /*
2  * SPDX-FileCopyrightText: Copyright OpenBMC Authors
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #include "MctpRequester.hpp"
7 
8 #include "MctpAsioEndpoint.hpp"
9 
10 #include <sys/socket.h>
11 
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <boost/container/devector.hpp>
19 #include <phosphor-logging/lg2.hpp>
20 
21 #include <bit>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <expected>
26 #include <format>
27 #include <functional>
28 #include <optional>
29 #include <span>
30 #include <stdexcept>
31 #include <system_error>
32 #include <utility>
33 
34 using namespace std::literals;
35 
36 namespace mctp
37 {
38 
getHeaderFromBuffer(std::span<const uint8_t> buffer)39 static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
40     std::span<const uint8_t> buffer)
41 {
42     if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
43     {
44         return nullptr;
45     }
46 
47     return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
48         buffer.data());
49 }
50 
getIid(std::span<const uint8_t> buffer)51 static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
52 {
53     const ocp::accelerator_management::BindingPciVid* header =
54         getHeaderFromBuffer(buffer);
55     if (header == nullptr)
56     {
57         return std::nullopt;
58     }
59     return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
60 }
61 
getRequestBit(std::span<const uint8_t> buffer)62 static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
63 {
64     const ocp::accelerator_management::BindingPciVid* header =
65         getHeaderFromBuffer(buffer);
66     if (header == nullptr)
67     {
68         return std::nullopt;
69     }
70     return header->instance_id & ocp::accelerator_management::requestBitMask;
71 }
72 
MctpRequester(boost::asio::io_context & ctx)73 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
74     io{ctx},
75     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
76 {
77     startReceive();
78 }
79 
startReceive()80 void MctpRequester::startReceive()
81 {
82     mctpSocket.async_receive_from(
83         boost::asio::buffer(buffer), recvEndPoint.endpoint,
84         std::bind_front(&MctpRequester::processRecvMsg, this));
85 }
86 
/**
 * Classify an error that is about to be reported to a request.
 *
 * Timeouts and unreachable hosts are passed to the caller as ordinary
 * results. Any other error is treated as unrecoverable, and the requester
 * throws to let main() restart the service (see handleResult() and
 * processRecvMsg()).
 *
 * @param ec The error to classify.
 * @return true when @p ec is set and is neither timed_out nor
 *         host_unreachable.
 */
static bool isFatalError(const std::error_code& ec)
{
    return ec &&
           (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
}

/**
 * Completion handler for the requester's single outstanding receive.
 *
 * A received datagram is checked in order: sender endpoint, message type,
 * header, then request bit. It is then matched by EID and instance ID
 * against the request currently in flight. On a match, that request is
 * completed through handleResult(). Datagrams that fail a check are logged
 * and dropped.
 *
 * Whatever happens to the datagram, this function re-arms the receive
 * exactly once, so one receive is always outstanding. handleResult() must
 * not re-arm, because it is also reached from the send-error, timeout and
 * processQueue() paths while a receive is already pending. Re-arming there
 * would stack concurrent receives that share `buffer` and `recvEndPoint`.
 *
 * @param ec     Outcome of async_receive_from.
 * @param length Number of bytes written into `buffer`.
 * @throws std::runtime_error on a receive error that isFatalError() deems
 *         unrecoverable, or when handleResult() throws.
 */
void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
                                   const size_t length)
{
    if (ec == boost::asio::error::operation_aborted)
    {
        // The socket was closed or cancelled. Re-arming would only fail
        // again, so stop the receive path here.
        return;
    }

    if (ec)
    {
        lg2::error(
            "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());

        // recvEndPoint is not guaranteed to describe a sender after a failed
        // receive, so the error cannot be attributed to a specific request.
        // In-flight requests are left to their expiry timers. Unrecoverable
        // errors follow handleResult()'s policy of bailing out with an
        // exception, which also avoids spinning on a persistently failing
        // socket.
        if (isFatalError(static_cast<std::error_code>(ec)))
        {
            throw std::runtime_error(std::format(
                "MCTP socket encountered a fatal receive error: {}",
                ec.message()));
        }
        startReceive();
        return;
    }

    // Validate the datagram and hand it to its request. Every `return`
    // inside this lambda drops the packet. Either way, control continues to
    // the startReceive() call below.
    const auto dispatch = [this, length]() {
        std::optional<uint8_t> expectedEid = recvEndPoint.eid();
        std::optional<uint8_t> receivedMsgType = recvEndPoint.type();

        if (!expectedEid || !receivedMsgType)
        {
            // We were handed an endpoint that can't be treated as an MCTP
            // endpoint. This is probably a kernel bug, so yell about it.
            lg2::error("MctpRequester: invalid endpoint");
            return;
        }

        if (*receivedMsgType != ocp::accelerator_management::messageType)
        {
            // A message type this handler doesn't support.
            lg2::error(
                "MctpRequester: Message type mismatch. We received {MSG}",
                "MSG", *receivedMsgType);
            return;
        }

        const uint8_t eid = *expectedEid;

        // NOTE(review): asio's POSIX datagram receive does not appear to
        // report MSG_TRUNC. A response larger than `buffer` would then arrive
        // here silently truncated instead of as an error. Confirm `buffer` is
        // sized for the largest expected response.
        std::span<const uint8_t> responseBuffer{buffer.data(), length};

        std::optional<uint8_t> optionalIid = getIid(responseBuffer);
        std::optional<bool> isRq = getRequestBit(responseBuffer);
        if (!optionalIid || !isRq)
        {
            // We received something from the device but can't parse the
            // iid byte. Drop it and rely on the timer to notify the client.
            lg2::error("MctpRequester: Unable to decode message from eid {EID}",
                       "EID", eid);
            return;
        }

        if (*isRq)
        {
            // A request from a downstream device. This isn't supported, so
            // drop it and keep the timer for our own request running.
            return;
        }

        const uint8_t iid = *optionalIid;

        auto it = requestContextQueues.find(eid);
        if (it == requestContextQueues.end() || it->second.queue.empty())
        {
            // A response from a device we never talked to, or one arriving
            // while nothing is in flight (e.g. after its request timed out
            // and the queue drained). There is no request to complete.
            lg2::error("Unable to match request to response");
            return;
        }

        if (iid != it->second.iid)
        {
            // The iid doesn't match the one we sent. Drop the packet and
            // keep the timer running.
            lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}",
                       "IID", iid, "EID", eid, "E_IID", it->second.iid);
            return;
        }

        handleResult(eid, std::error_code{}, responseBuffer);
    };

    dispatch();

    // Re-arm only after dispatch: the next receive may write into `buffer`
    // immediately, and the callback has finished with responseBuffer by now.
    startReceive();
}

/**
 * Completion handler for async_send_to of the request at the head of @p eid's
 * queue.
 *
 * On a send error, the request is failed immediately through handleResult().
 * On success, the per-EID expiry timer is armed for 2s. If it fires while the
 * same request is still in flight, the request is failed with
 * std::errc::timed_out.
 *
 * @param eid Endpoint the request was sent to.
 * @param ec  Outcome of the send.
 */
void MctpRequester::handleSendMsgCompletion(
    uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        handleResult(eid, static_cast<std::error_code>(ec), {});
        return;
    }

    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        // Something very bad has happened here: we've sent something to a
        // device that we have no record of. Yell loudly and bail.
        lg2::error(
            "MctpRequester completed send for an EID that we have no record of");
        return;
    }

    // Record which request this timer guards. processQueue() advances the iid
    // for every new request, so a differing iid later means the guarded
    // request has already been completed.
    const uint8_t armedIid = it->second.iid;

    boost::asio::steady_timer& expiryTimer = it->second.timer;
    expiryTimer.expires_after(2s);

    expiryTimer.async_wait([this, eid,
                            armedIid](const boost::system::error_code& timerEc) {
        if (timerEc == boost::asio::error::operation_aborted)
        {
            return;
        }

        // cancel() cannot recall a completion that was already queued, so
        // this can run after the request completed. Only time out the request
        // the timer was armed for.
        auto entry = requestContextQueues.find(eid);
        if (entry == requestContextQueues.end() ||
            entry->second.queue.empty() || entry->second.iid != armedIid)
        {
            return;
        }

        lg2::error("Operation timed out on eid {EID}", "EID", eid);
        handleResult(eid, std::make_error_code(std::errc::timed_out), {});
    });
}

/**
 * Queue a request for @p eid. @p callback is invoked exactly once with the
 * outcome.
 *
 * Requests to the same EID are serialized: only the head of the per-EID queue
 * is in flight, and handleResult() starts the next one when it finishes.
 *
 * @param eid      Destination endpoint ID.
 * @param reqMsg   Request bytes. Handed to RequestContext, which presumably
 *                 copies them, since the span is not kept past this call.
 *                 TODO confirm against the RequestContext definition.
 * @param callback Receives an error code and, on success, the response bytes.
 *                 The span is only valid during the call.
 */
void MctpRequester::sendRecvMsg(
    uint8_t eid, std::span<const uint8_t> reqMsg,
    std::move_only_function<void(const std::error_code&,
                                 std::span<const uint8_t>)>
        callback)
{
    RequestContext reqCtx{reqMsg, std::move(callback)};

    // try_emplace constructs the per-EID state (queue, timer, iid) only the
    // first time this eid is seen.
    auto& queue = requestContextQueues.try_emplace(eid, io).first->second.queue;
    queue.push_back(std::move(reqCtx));

    // If other requests were already queued, this one is started later by
    // handleResult().
    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

/**
 * Complete the request at the head of @p eid's queue and start the next one.
 *
 * Cancels the expiry timer, invokes the request's callback with @p ec and
 * @p buffer, pops the request, then calls processQueue(). It deliberately
 * does not re-arm the receive; processRecvMsg() owns that.
 *
 * @param eid    Endpoint whose head request is being completed.
 * @param ec     Outcome passed to the callback (empty on success).
 * @param buffer Response payload, which may point into the receive buffer.
 *               It is only valid for the duration of the callback.
 * @throws std::runtime_error after invoking the callback when @p ec is fatal
 *         according to isFatalError().
 */
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // Nothing is in flight, so there is no callback to complete. Calling
        // front() here would be undefined behaviour.
        lg2::error("No outstanding request to complete for eid {EID}", "EID",
                   eid);
        return;
    }

    it->second.timer.cancel();

    // Move the callback out before invoking it. The callback may call
    // sendRecvMsg(), whose push_back can reallocate the queue and would
    // otherwise destroy this callable while it is still executing. The element
    // stays queued until pop_front() below, so sendRecvMsg() sees
    // size() > 1 and does not start a second send for this eid.
    auto callback = std::move(queue.front().callback);
    callback(ec, buffer);

    if (isFatalError(ec))
    {
        // Some errors are fatal. Since these are datagrams, we won't get a
        // receive path error message. This daemon services all nvidia iana
        // commands for a given system, so the service is only restarted when
        // the error is unrecoverable, i.e. the client can't reasonably deal
        // with it. In that case, restart and hope we can deal with it then.
        // Since we're fully async, the only reasonable way to bubble this up
        // is to throw an exception and let main deal with it. Alternatively
        // we could call cancel on the io_context, but there's no good way to
        // tell *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    // `it` is not used past the callback: sendRecvMsg() for another eid may
    // have inserted into requestContextQueues. This relies on `queue` (an
    // element reference) surviving insertion, which holds for node-based maps.
    // NOTE(review): confirm requestContextQueues is std::map/unordered_map.
    queue.pop_front();

    processQueue(eid);
}
275 
getNextIid(uint8_t eid)276 std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
277 {
278     auto it = requestContextQueues.find(eid);
279     if (it == requestContextQueues.end())
280     {
281         return std::nullopt;
282     }
283 
284     uint8_t& iid = it->second.iid;
285     ++iid;
286     iid &= ocp::accelerator_management::instanceIdBitMask;
287     return iid;
288 }
289 
/// Overwrite the instance-ID bits of the OCP header at the start of @p buffer.
///
/// Only the bits covered by instanceIdBitMask are replaced. All other bits of
/// instance_id (e.g. the request bit) are preserved.
///
/// @param buffer Mutable message bytes (header first).
/// @param iid    New instance ID; must fit within instanceIdBitMask.
/// @return Nothing on success. std::errc::invalid_argument if @p buffer is
///         too short for a header or @p iid is out of range.
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    using ocp::accelerator_management::BindingPciVid;
    using ocp::accelerator_management::instanceIdBitMask;

    const bool tooShort = buffer.size() < sizeof(BindingPciVid);
    const bool iidOutOfRange = iid > instanceIdBitMask;
    if (tooShort || iidOutOfRange)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* const header = std::bit_cast<BindingPciVid*>(buffer.data());
    header->instance_id &= ~instanceIdBitMask;
    header->instance_id |= iid;
    return {};
}
312 
/// Transmit the request at the head of @p eid's queue, if there is one.
///
/// A fresh instance ID is stamped into the request header in place before
/// sending. Send completion is handled by handleSendMsgCompletion(). If no
/// iid can be obtained or stamped, the request is failed through
/// handleResult().
///
/// @param eid Endpoint whose queue should be serviced.
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        return;
    }

    // NOTE(review): the send below reads straight from the queued reqMsg
    // storage. This assumes that storage stays put if the queue relocates the
    // RequestContext while the send is in flight (true for a heap-backed
    // std::vector). TODO confirm against RequestContext.
    auto& message = pending.front().reqMsg;
    const std::span<uint8_t> request{message.data(), message.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (const std::expected<void, std::error_code> stamped =
            injectIid(request, *nextIid);
        !stamped)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, stamped.error(), {});
        return;
    }

    MctpAsioEndpoint destination(eid, ocp::accelerator_management::messageType);
    mctpSocket.async_send_to(
        boost::asio::const_buffer(request.data(), request.size()),
        destination.endpoint,
        std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid));
}
355 
356 } // namespace mctp
357