xref: /openbmc/dbus-sensors/src/nvidia-gpu/MctpRequester.cpp (revision 779d84f06b79625060d65af42b2fdd4e141ade48)
/*
 * SPDX-FileCopyrightText: Copyright OpenBMC Authors
 * SPDX-License-Identifier: Apache-2.0
 */
5 
6 #include "MctpRequester.hpp"
7 
8 #include <linux/mctp.h>
9 #include <sys/socket.h>
10 
11 #include <OcpMctpVdm.hpp>
12 #include <boost/asio/buffer.hpp>
13 #include <boost/asio/error.hpp>
14 #include <boost/asio/generic/datagram_protocol.hpp>
15 #include <boost/asio/io_context.hpp>
16 #include <boost/asio/steady_timer.hpp>
17 #include <boost/container/devector.hpp>
18 #include <phosphor-logging/lg2.hpp>
19 
20 #include <bit>
21 #include <cstddef>
22 #include <cstdint>
23 #include <cstring>
24 #include <expected>
25 #include <format>
26 #include <functional>
27 #include <optional>
28 #include <span>
29 #include <stdexcept>
30 #include <system_error>
31 #include <utility>
32 
33 using namespace std::literals;
34 
35 namespace mctp
36 {
37 
getHeaderFromBuffer(std::span<const uint8_t> buffer)38 static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
39     std::span<const uint8_t> buffer)
40 {
41     if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
42     {
43         return nullptr;
44     }
45 
46     return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
47         buffer.data());
48 }
49 
getIid(std::span<const uint8_t> buffer)50 static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
51 {
52     const ocp::accelerator_management::BindingPciVid* header =
53         getHeaderFromBuffer(buffer);
54     if (header == nullptr)
55     {
56         return std::nullopt;
57     }
58     return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
59 }
60 
getRequestBit(std::span<const uint8_t> buffer)61 static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
62 {
63     const ocp::accelerator_management::BindingPciVid* header =
64         getHeaderFromBuffer(buffer);
65     if (header == nullptr)
66     {
67         return std::nullopt;
68     }
69     return header->instance_id & ocp::accelerator_management::requestBitMask;
70 }
71 
MctpRequester(boost::asio::io_context & ctx)72 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
73     io{ctx},
74     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
75 {
76     startReceive();
77 }
78 
startReceive()79 void MctpRequester::startReceive()
80 {
81     mctpSocket.async_receive_from(
82         boost::asio::buffer(buffer), recvEndPoint.endpoint,
83         std::bind_front(&MctpRequester::processRecvMsg, this));
84 }
85 
// Completion handler for the single outstanding async_receive_from.
//
// Validates the datagram that landed in `buffer`/recvEndPoint. If it is the
// response to the request currently in flight for its EID, it is handed to
// handleResult().
//
// This function is the only place (besides the constructor) that arms a
// receive. Every path re-arms once dispatch is done, except two: cancellation,
// and a fatal error thrown out of handleResult(). This keeps exactly one
// receive pending on the socket at any time.
//
// ec:     result of the receive operation.
// length: number of bytes written into `buffer`.
void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
                                   const size_t length)
{
    if (ec == boost::asio::error::operation_aborted)
    {
        // The pending receive was cancelled (socket closed/cancelled).
        // Re-arming would only be cancelled again, so stop here.
        return;
    }

    // Each early return below drops the datagram. Wrapping the dispatch in a
    // lambda lets the single startReceive() after it cover all of them.
    const auto dispatch = [this, &ec, length]() {
        std::optional<uint8_t> expectedEid = recvEndPoint.eid();
        std::optional<uint8_t> receivedMsgType = recvEndPoint.type();

        if (!expectedEid || !receivedMsgType)
        {
            // We were handed an endpoint that can't be treated as an MCTP
            // endpoint. This is probably a kernel bug; yell about it and
            // rebind.
            lg2::error("MctpRequester: invalid endpoint");
            return;
        }

        if (*receivedMsgType != msgType)
        {
            // We received a message type this handler doesn't support.
            // Drop it on the floor and rebind.
            lg2::error(
                "MctpRequester: Message type mismatch. We received {MSG}",
                "MSG", *receivedMsgType);
            return;
        }

        uint8_t eid = *expectedEid;

        if (ec)
        {
            lg2::error(
                "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
                "EC", ec.value(), "ER", ec.message());
            // handleResult() throws for errors isFatalError() deems fatal.
            // In that case the re-arm below is skipped.
            handleResult(eid, static_cast<std::error_code>(ec), {});
            return;
        }

        // NOTE(review): the original comment claimed asio reports an error
        // when a datagram exceeds `buffer`. On Linux, datagram truncation
        // appears to be silent (length is capped at buffer.size()) — confirm.
        std::span<const uint8_t> responseBuffer{buffer.data(), length};

        std::optional<uint8_t> optionalIid = getIid(responseBuffer);
        std::optional<bool> isRq = getRequestBit(responseBuffer);
        if (!optionalIid || !isRq)
        {
            // We received something from the device but couldn't parse the
            // iid byte. Drop it and rely on the timer to notify the client.
            lg2::error("MctpRequester: Unable to decode message from eid {EID}",
                       "EID", eid);
            return;
        }

        if (*isRq)
        {
            // A request from a downstream device. This is not supported:
            // drop it and keep the timer running.
            return;
        }

        uint8_t iid = *optionalIid;

        auto it = requestContextQueues.find(eid);
        if (it == requestContextQueues.end())
        {
            // A response from a device we've never talked to. Drop it and
            // keep the timer running.
            lg2::error("Unable to match request to response");
            return;
        }

        if (it->second.queue.empty())
        {
            // Nothing is in flight for this EID. One case is a response that
            // arrives after its request already timed out, since the iid is
            // only advanced when the next request is sent.
            lg2::error("MctpRequester: no pending request for eid {EID}", "EID",
                       eid);
            return;
        }

        if (iid != it->second.iid)
        {
            // The iid doesn't match the one we sent. Drop this packet.
            lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}",
                       "IID", iid, "EID", eid, "E_IID", it->second.iid);
            return;
        }

        handleResult(eid, std::error_code{}, responseBuffer);
    };

    dispatch();

    // Re-arm only after dispatch. The client callback reads `buffer` in place,
    // and the new receive reuses that same buffer.
    startReceive();
}

// Completion handler for async_send_to of the request at the head of eid's
// queue.
//
// On a send error, the result is reported straight away. On success, the
// per-EID timer is (re)armed for 2s. If it expires before the response is
// handled, the request completes with std::errc::timed_out.
void MctpRequester::handleSendMsgCompletion(
    uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        handleResult(eid, static_cast<std::error_code>(ec), {});
        return;
    }

    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        // We've sent something to a device we have no record of.
        // Yell loudly and bail.
        lg2::error(
            "MctpRequester completed send for an EID that we have no record of");
        return;
    }

    boost::asio::steady_timer& expiryTimer = it->second.timer;
    expiryTimer.expires_after(2s);

    // NOTE(review): a timer completion that was already queued when
    // handleResult() cancelled the timer still runs with a success code.
    // handleResult() tolerates an empty queue for this reason. A per-request
    // generation counter would close the race completely.
    expiryTimer.async_wait(
        [this, eid](const boost::system::error_code& timerEc) {
            if (timerEc != boost::asio::error::operation_aborted)
            {
                lg2::error("Operation timed out on eid {EID}", "EID", eid);
                handleResult(eid, std::make_error_code(std::errc::timed_out),
                             {});
            }
        });
}

// Queue `reqMsg` for `eid`. `callback` is later invoked from handleResult()
// with either the response payload or an error. Requests to one EID are
// serviced one at a time in FIFO order, and the queue front is the one in
// flight.
void MctpRequester::sendRecvMsg(
    uint8_t eid, std::span<const uint8_t> reqMsg,
    std::move_only_function<void(const std::error_code&,
                                 std::span<const uint8_t>)>
        callback)
{
    RequestContext reqCtx{reqMsg, std::move(callback)};

    // try_emplace only constructs a new per-EID entry if none exists yet.
    auto it = requestContextQueues.try_emplace(eid, io).first;

    auto& queue = it->second.queue;
    queue.push_back(std::move(reqCtx));

    // If this is the only entry, the EID was idle, so start servicing now.
    // Otherwise handleResult() will get to it after the current one.
    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

// Timeouts and unreachable endpoints are reported to the client only.
// Any other error is treated as unrecoverable by handleResult().
static bool isFatalError(const std::error_code& ec)
{
    return ec &&
           (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
}

// Complete the in-flight (front) request for `eid`, then start the next one.
// The request's timer is cancelled and its callback is invoked with
// (ec, buffer). Fatal errors are rethrown as std::runtime_error after the
// callback has run.
//
// This does NOT arm a receive. processRecvMsg() owns that, so calls from the
// timer, send-error and processQueue() paths no longer stack up extra
// pending receives.
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // Nothing in flight, so there is no callback to complete. Calling
        // front() here would be undefined behaviour.
        lg2::error("MctpRequester: no pending request for eid {EID}", "EID",
                   eid);
        return;
    }

    auto& reqCtx = queue.front();

    it->second.timer.cancel();

    // The entry stays at the front while the callback runs. A reentrant
    // sendRecvMsg() for this EID then only enqueues, instead of starting a
    // second concurrent send.
    reqCtx.callback(ec, buffer);

    if (isFatalError(ec))
    {
        // Some errors are fatal. Since these are datagrams, we won't get a
        // receive-path error message. This daemon services all nvidia iana
        // commands for a given system, so we should only restart the service
        // if the error is unrecoverable, i.e. one the client can't reasonably
        // deal with. If that's the cause, restart and hope we can deal with it
        // then. Since we're fully async, the only reasonable way to bubble
        // this up is to throw and let main deal with it.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    queue.pop_front();

    processQueue(eid);
}
274 
getNextIid(uint8_t eid)275 std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
276 {
277     auto it = requestContextQueues.find(eid);
278     if (it == requestContextQueues.end())
279     {
280         return std::nullopt;
281     }
282 
283     uint8_t& iid = it->second.iid;
284     ++iid;
285     iid &= ocp::accelerator_management::instanceIdBitMask;
286     return iid;
287 }
288 
// Overwrite the instance-ID bits of the binding header at the start of
// `buffer` with `iid`. Bits outside instanceIdBitMask are left as they were.
// Fails with std::errc::invalid_argument in two cases: the buffer cannot hold
// a header, or `iid` does not fit in the instance-ID mask.
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    using ocp::accelerator_management::BindingPciVid;
    using ocp::accelerator_management::instanceIdBitMask;

    const bool tooShort = buffer.size() < sizeof(BindingPciVid);
    const bool iidTooLarge = iid > instanceIdBitMask;
    if (tooShort || iidTooLarge)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* hdr = std::bit_cast<BindingPciVid*>(buffer.data());

    // Clear the old instance ID, then OR in the new one.
    hdr->instance_id &= ~instanceIdBitMask;
    hdr->instance_id |= iid;
    return {};
}
311 
// Send the request at the head of eid's queue, if any. First a fresh
// instance ID is stamped into the request's header. Then it is sent as a
// tag-owner MCTP datagram of type msgType, and completion goes to
// handleSendMsgCompletion(). Failing to obtain or inject the iid is reported
// through handleResult().
void MctpRequester::processQueue(uint8_t eid)
{
    auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        return;
    }

    auto& head = pending.front();
    std::span<uint8_t> payload{head.reqMsg.data(), head.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (std::expected<void, std::error_code> injected =
            injectIid(payload, *nextIid);
        !injected)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    sockaddr_mctp dest{};
    dest.smctp_family = AF_MCTP;
    dest.smctp_addr.s_addr = eid;
    dest.smctp_type = msgType;
    dest.smctp_tag = MCTP_TAG_OWNER;
    const boost::asio::generic::datagram_protocol::endpoint target{
        &dest, sizeof(dest)};

    mctpSocket.async_send_to(
        boost::asio::const_buffer(payload.data(), payload.size()), target,
        [this, eid](const boost::system::error_code& ec, std::size_t length) {
            handleSendMsgCompletion(eid, ec, length);
        });
}
359 
360 } // namespace mctp
361