xref: /openbmc/dbus-sensors/src/nvidia-gpu/MctpRequester.cpp (revision 1e47dcf2577670a56f7552261f35492bdcfae637)
1560e6af7SHarshit Aghera /*
2b5e823f7SEd Tanous  * SPDX-FileCopyrightText: Copyright OpenBMC Authors
3560e6af7SHarshit Aghera  * SPDX-License-Identifier: Apache-2.0
4560e6af7SHarshit Aghera  */
5560e6af7SHarshit Aghera 
6560e6af7SHarshit Aghera #include "MctpRequester.hpp"
7560e6af7SHarshit Aghera 
86ef89739SEd Tanous #include "MctpAsioEndpoint.hpp"
96ef89739SEd Tanous 
10560e6af7SHarshit Aghera #include <sys/socket.h>
11560e6af7SHarshit Aghera 
12560e6af7SHarshit Aghera #include <OcpMctpVdm.hpp>
13560e6af7SHarshit Aghera #include <boost/asio/buffer.hpp>
14560e6af7SHarshit Aghera #include <boost/asio/error.hpp>
15560e6af7SHarshit Aghera #include <boost/asio/generic/datagram_protocol.hpp>
16560e6af7SHarshit Aghera #include <boost/asio/io_context.hpp>
17560e6af7SHarshit Aghera #include <boost/asio/steady_timer.hpp>
18ed0af21cSAditya Kurdunkar #include <boost/container/devector.hpp>
19560e6af7SHarshit Aghera #include <phosphor-logging/lg2.hpp>
20560e6af7SHarshit Aghera 
21d0125c9cSMarc Olberding #include <bit>
22560e6af7SHarshit Aghera #include <cstddef>
23560e6af7SHarshit Aghera #include <cstdint>
24560e6af7SHarshit Aghera #include <cstring>
25d0125c9cSMarc Olberding #include <expected>
26d0125c9cSMarc Olberding #include <format>
27560e6af7SHarshit Aghera #include <functional>
28d0125c9cSMarc Olberding #include <optional>
29560e6af7SHarshit Aghera #include <span>
30d0125c9cSMarc Olberding #include <stdexcept>
31d0125c9cSMarc Olberding #include <system_error>
32560e6af7SHarshit Aghera #include <utility>
33560e6af7SHarshit Aghera 
34560e6af7SHarshit Aghera using namespace std::literals;
35560e6af7SHarshit Aghera 
36560e6af7SHarshit Aghera namespace mctp
37560e6af7SHarshit Aghera {
38560e6af7SHarshit Aghera 
getHeaderFromBuffer(std::span<const uint8_t> buffer)39d0125c9cSMarc Olberding static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
40d0125c9cSMarc Olberding     std::span<const uint8_t> buffer)
41560e6af7SHarshit Aghera {
42d0125c9cSMarc Olberding     if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
43ed0af21cSAditya Kurdunkar     {
44d0125c9cSMarc Olberding         return nullptr;
45d0125c9cSMarc Olberding     }
46d0125c9cSMarc Olberding 
47d0125c9cSMarc Olberding     return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
48d0125c9cSMarc Olberding         buffer.data());
49d0125c9cSMarc Olberding }
50d0125c9cSMarc Olberding 
getIid(std::span<const uint8_t> buffer)51d0125c9cSMarc Olberding static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
52d0125c9cSMarc Olberding {
53d0125c9cSMarc Olberding     const ocp::accelerator_management::BindingPciVid* header =
54d0125c9cSMarc Olberding         getHeaderFromBuffer(buffer);
55d0125c9cSMarc Olberding     if (header == nullptr)
56d0125c9cSMarc Olberding     {
57d0125c9cSMarc Olberding         return std::nullopt;
58d0125c9cSMarc Olberding     }
59d0125c9cSMarc Olberding     return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
60d0125c9cSMarc Olberding }
61d0125c9cSMarc Olberding 
getRequestBit(std::span<const uint8_t> buffer)62d0125c9cSMarc Olberding static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
63d0125c9cSMarc Olberding {
64d0125c9cSMarc Olberding     const ocp::accelerator_management::BindingPciVid* header =
65d0125c9cSMarc Olberding         getHeaderFromBuffer(buffer);
66d0125c9cSMarc Olberding     if (header == nullptr)
67d0125c9cSMarc Olberding     {
68d0125c9cSMarc Olberding         return std::nullopt;
69d0125c9cSMarc Olberding     }
70d0125c9cSMarc Olberding     return header->instance_id & ocp::accelerator_management::requestBitMask;
71d0125c9cSMarc Olberding }
72d0125c9cSMarc Olberding 
MctpRequester(boost::asio::io_context & ctx)73d0125c9cSMarc Olberding MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
74d0125c9cSMarc Olberding     io{ctx},
75d0125c9cSMarc Olberding     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
76d0125c9cSMarc Olberding {
77d0125c9cSMarc Olberding     startReceive();
78d0125c9cSMarc Olberding }
79d0125c9cSMarc Olberding 
startReceive()80d0125c9cSMarc Olberding void MctpRequester::startReceive()
81d0125c9cSMarc Olberding {
82d0125c9cSMarc Olberding     mctpSocket.async_receive_from(
83d0125c9cSMarc Olberding         boost::asio::buffer(buffer), recvEndPoint.endpoint,
84d0125c9cSMarc Olberding         std::bind_front(&MctpRequester::processRecvMsg, this));
85d0125c9cSMarc Olberding }
86d0125c9cSMarc Olberding 
processRecvMsg(const boost::system::error_code & ec,const size_t length)87d0125c9cSMarc Olberding void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
88d0125c9cSMarc Olberding                                    const size_t length)
89d0125c9cSMarc Olberding {
90d0125c9cSMarc Olberding     std::optional<uint8_t> expectedEid = recvEndPoint.eid();
91d0125c9cSMarc Olberding     std::optional<uint8_t> receivedMsgType = recvEndPoint.type();
92d0125c9cSMarc Olberding 
93d0125c9cSMarc Olberding     if (!expectedEid || !receivedMsgType)
94d0125c9cSMarc Olberding     {
95d0125c9cSMarc Olberding         // we were handed an endpoint that can't be treated as an MCTP endpoint
96d0125c9cSMarc Olberding         // This is probably a kernel bug...yell about it and rebind.
97d0125c9cSMarc Olberding         lg2::error("MctpRequester: invalid endpoint");
98ed0af21cSAditya Kurdunkar         return;
99ed0af21cSAditya Kurdunkar     }
100ed0af21cSAditya Kurdunkar 
1016ef89739SEd Tanous     if (*receivedMsgType != ocp::accelerator_management::messageType)
102ed0af21cSAditya Kurdunkar     {
103d0125c9cSMarc Olberding         // we received a message that this handler doesn't support
104d0125c9cSMarc Olberding         // drop it on the floor and rebind receive_from
105d0125c9cSMarc Olberding         lg2::error("MctpRequester: Message type mismatch. We received {MSG}",
106d0125c9cSMarc Olberding                    "MSG", *receivedMsgType);
107ed0af21cSAditya Kurdunkar         return;
108ed0af21cSAditya Kurdunkar     }
109ed0af21cSAditya Kurdunkar 
110d0125c9cSMarc Olberding     uint8_t eid = *expectedEid;
111560e6af7SHarshit Aghera 
112560e6af7SHarshit Aghera     if (ec)
113560e6af7SHarshit Aghera     {
114560e6af7SHarshit Aghera         lg2::error(
115560e6af7SHarshit Aghera             "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
116560e6af7SHarshit Aghera             "EC", ec.value(), "ER", ec.message());
117d0125c9cSMarc Olberding         handleResult(eid, static_cast<std::error_code>(ec), {});
118560e6af7SHarshit Aghera         return;
119560e6af7SHarshit Aghera     }
120560e6af7SHarshit Aghera 
121d0125c9cSMarc Olberding     // if the received length was greater than our buffer, we would've truncated
122d0125c9cSMarc Olberding     // and gotten an error code in asio
123d0125c9cSMarc Olberding     std::span<const uint8_t> responseBuffer{buffer.data(), length};
124d0125c9cSMarc Olberding 
125d0125c9cSMarc Olberding     std::optional<uint8_t> optionalIid = getIid(responseBuffer);
126d0125c9cSMarc Olberding     std::optional<bool> isRq = getRequestBit(responseBuffer);
127d0125c9cSMarc Olberding     if (!optionalIid || !isRq)
128560e6af7SHarshit Aghera     {
129d0125c9cSMarc Olberding         // we received something from the device,
130d0125c9cSMarc Olberding         // but we aren't able to parse iid byte
131d0125c9cSMarc Olberding         // drop this packet on the floor
132d0125c9cSMarc Olberding         // and rely on the timer to notify the client
133d0125c9cSMarc Olberding         lg2::error("MctpRequester: Unable to decode message from eid {EID}",
134d0125c9cSMarc Olberding                    "EID", eid);
135ed0af21cSAditya Kurdunkar         return;
136ed0af21cSAditya Kurdunkar     }
137ed0af21cSAditya Kurdunkar 
138d0125c9cSMarc Olberding     if (isRq.value())
139d0125c9cSMarc Olberding     {
140d0125c9cSMarc Olberding         // we received a request from a downstream device.
141d0125c9cSMarc Olberding         // We don't currently support this, drop the packet
142d0125c9cSMarc Olberding         // on the floor and rebind receive, keep the timer running
143d0125c9cSMarc Olberding         return;
144d0125c9cSMarc Olberding     }
145ed0af21cSAditya Kurdunkar 
146d0125c9cSMarc Olberding     uint8_t iid = *optionalIid;
147d0125c9cSMarc Olberding 
148d0125c9cSMarc Olberding     auto it = requestContextQueues.find(eid);
149d0125c9cSMarc Olberding     if (it == requestContextQueues.end())
150d0125c9cSMarc Olberding     {
151d0125c9cSMarc Olberding         // something very bad has happened here
152d0125c9cSMarc Olberding         // we've received a packet that is a response
153d0125c9cSMarc Olberding         // from a device we've never talked to
154d0125c9cSMarc Olberding         // do our best and rebind receive and keep the timer running
155d0125c9cSMarc Olberding         lg2::error("Unable to match request to response");
156d0125c9cSMarc Olberding         return;
157d0125c9cSMarc Olberding     }
158d0125c9cSMarc Olberding 
159d0125c9cSMarc Olberding     if (iid != it->second.iid)
160d0125c9cSMarc Olberding     {
161d0125c9cSMarc Olberding         // we received an iid that doesn't match the one we sent
162d0125c9cSMarc Olberding         // rebind async_receive_from and drop this packet on the floor
163d0125c9cSMarc Olberding         lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}", "IID",
164d0125c9cSMarc Olberding                    iid, "EID", eid, "E_IID", it->second.iid);
165d0125c9cSMarc Olberding         return;
166d0125c9cSMarc Olberding     }
167d0125c9cSMarc Olberding 
168d0125c9cSMarc Olberding     handleResult(eid, std::error_code{}, responseBuffer);
169d0125c9cSMarc Olberding }
170d0125c9cSMarc Olberding 
handleSendMsgCompletion(uint8_t eid,const boost::system::error_code & ec,size_t)171d0125c9cSMarc Olberding void MctpRequester::handleSendMsgCompletion(
172d0125c9cSMarc Olberding     uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
173d0125c9cSMarc Olberding {
174560e6af7SHarshit Aghera     if (ec)
175560e6af7SHarshit Aghera     {
176560e6af7SHarshit Aghera         lg2::error(
177560e6af7SHarshit Aghera             "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
178560e6af7SHarshit Aghera             "EC", ec.value(), "ER", ec.message());
179d0125c9cSMarc Olberding         handleResult(eid, static_cast<std::error_code>(ec), {});
180560e6af7SHarshit Aghera         return;
181560e6af7SHarshit Aghera     }
182560e6af7SHarshit Aghera 
183d0125c9cSMarc Olberding     auto it = requestContextQueues.find(eid);
184d0125c9cSMarc Olberding     if (it == requestContextQueues.end())
185d0125c9cSMarc Olberding     {
186d0125c9cSMarc Olberding         // something very bad has happened here,
187d0125c9cSMarc Olberding         // we've sent something to a device that we have
188d0125c9cSMarc Olberding         // no record of. yell loudly and bail
189d0125c9cSMarc Olberding         lg2::error(
190d0125c9cSMarc Olberding             "MctpRequester completed send for an EID that we have no record of");
191d0125c9cSMarc Olberding         return;
192d0125c9cSMarc Olberding     }
193d0125c9cSMarc Olberding 
194d0125c9cSMarc Olberding     boost::asio::steady_timer& expiryTimer = it->second.timer;
195560e6af7SHarshit Aghera     expiryTimer.expires_after(2s);
196560e6af7SHarshit Aghera 
197ed0af21cSAditya Kurdunkar     expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) {
198560e6af7SHarshit Aghera         if (ec != boost::asio::error::operation_aborted)
199560e6af7SHarshit Aghera         {
200d0125c9cSMarc Olberding             lg2::error("Operation timed out on eid {EID}", "EID", eid);
201d0125c9cSMarc Olberding             handleResult(eid, std::make_error_code(std::errc::timed_out), {});
202560e6af7SHarshit Aghera         }
203560e6af7SHarshit Aghera     });
204560e6af7SHarshit Aghera }
205560e6af7SHarshit Aghera 
sendRecvMsg(uint8_t eid,std::span<const uint8_t> reqMsg,std::move_only_function<void (const std::error_code &,std::span<const uint8_t>)> callback)206d0125c9cSMarc Olberding void MctpRequester::sendRecvMsg(
207d0125c9cSMarc Olberding     uint8_t eid, std::span<const uint8_t> reqMsg,
208d0125c9cSMarc Olberding     std::move_only_function<void(const std::error_code&,
209d0125c9cSMarc Olberding                                  std::span<const uint8_t>)>
210d0125c9cSMarc Olberding         callback)
211560e6af7SHarshit Aghera {
212d0125c9cSMarc Olberding     RequestContext reqCtx{reqMsg, std::move(callback)};
213d0125c9cSMarc Olberding 
214d0125c9cSMarc Olberding     // try_emplace only affects the result if the key does not already exist
215d0125c9cSMarc Olberding     auto [it, inserted] = requestContextQueues.try_emplace(eid, io);
216d0125c9cSMarc Olberding     (void)inserted;
217d0125c9cSMarc Olberding 
218d0125c9cSMarc Olberding     auto& queue = it->second.queue;
219d0125c9cSMarc Olberding     queue.push_back(std::move(reqCtx));
220d0125c9cSMarc Olberding 
221d0125c9cSMarc Olberding     if (queue.size() == 1)
222560e6af7SHarshit Aghera     {
223d0125c9cSMarc Olberding         processQueue(eid);
224d0125c9cSMarc Olberding     }
225d0125c9cSMarc Olberding }
226d0125c9cSMarc Olberding 
// Classify an error code: success, std::errc::timed_out and
// std::errc::host_unreachable are treated as recoverable; every other error
// is reported as fatal.
static bool isFatalError(const std::error_code& ec)
{
    if (!ec)
    {
        return false;
    }

    const bool recoverable = ec == std::errc::timed_out ||
                             ec == std::errc::host_unreachable;
    return !recoverable;
}
232d0125c9cSMarc Olberding 
// Complete the request at the head of the queue for `eid`.
//
// Cancels the response timer, invokes the request's callback with the
// outcome, re-arms the socket receive, removes the request and starts the
// next queued one.
//
// eid:    endpoint whose in-flight request finished.
// ec:     outcome; a default-constructed code means success.
// buffer: response bytes (empty on error).
//
// Throws std::runtime_error when isFatalError(ec) is true, after the callback
// has been invoked.
//
// NOTE(review): this always calls startReceive(), including when invoked from
// the timer or send-failure paths where no receive has just completed. That
// looks like it can leave more than one receive pending on the shared buffer
// - confirm against the class design.
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");

        startReceive();
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // A late or duplicate response (its iid still matches the last one
        // sent) or a stale timer can arrive after the request was already
        // completed. There is nothing to complete; front() on an empty queue
        // would be undefined behaviour.
        lg2::error("MctpRequester: no pending request for eid {EID}", "EID",
                   eid);

        startReceive();
        return;
    }

    it->second.timer.cancel();

    // Move the callback out of the queue element before running it. The
    // callback may call sendRecvMsg() for this eid, which can reallocate the
    // queue; the callable must not live inside storage that can move while it
    // executes. The element itself stays queued until pop_front() below so a
    // re-entrant sendRecvMsg() still sees the queue as busy.
    auto callback = std::move(queue.front().callback);
    callback(ec, buffer); // Call the original callback

    if (isFatalError(ec))
    {
        // some errors are fatal, since these are datagrams,
        // we won't get a receive path error message.
        // and since this daemon services all nvidia iana commands
        // for a given system, we should only restart the service if its
        // unrecoverable, i.e. if we get error codes that the client
        // can't reasonably deal with. If that's the cause, restart
        // and hope that we can deal with it then.
        // since we're fully async, the only reasonable way to bubble
        // this issue up is to chuck an exception and let main deal with it.
        // alternatively we could call cancel on the io_context, but there's
        // not a great way to figure *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    startReceive();

    queue.pop_front();

    processQueue(eid);
}
275d0125c9cSMarc Olberding 
getNextIid(uint8_t eid)276d0125c9cSMarc Olberding std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
277d0125c9cSMarc Olberding {
278d0125c9cSMarc Olberding     auto it = requestContextQueues.find(eid);
279d0125c9cSMarc Olberding     if (it == requestContextQueues.end())
280d0125c9cSMarc Olberding     {
281d0125c9cSMarc Olberding         return std::nullopt;
282d0125c9cSMarc Olberding     }
283d0125c9cSMarc Olberding 
284d0125c9cSMarc Olberding     uint8_t& iid = it->second.iid;
285d0125c9cSMarc Olberding     ++iid;
286d0125c9cSMarc Olberding     iid &= ocp::accelerator_management::instanceIdBitMask;
287d0125c9cSMarc Olberding     return iid;
288d0125c9cSMarc Olberding }
289d0125c9cSMarc Olberding 
// Write `iid` into the instance-id bits of the header at the start of
// `buffer`, leaving the other bits of instance_id untouched.
//
// Returns std::errc::invalid_argument when the buffer is too short to hold a
// BindingPciVid header or when `iid` exceeds instanceIdBitMask.
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    using Header = ocp::accelerator_management::BindingPciVid;

    const bool tooShort = buffer.size() < sizeof(Header);
    const bool iidOutOfRange =
        iid > ocp::accelerator_management::instanceIdBitMask;
    if (tooShort || iidOutOfRange)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* hdr = std::bit_cast<Header*>(buffer.data());

    // Clear the old instance-id bits first, then merge in the new value.
    hdr->instance_id &= ~ocp::accelerator_management::instanceIdBitMask;
    hdr->instance_id |= iid;
    return {};
}
312d0125c9cSMarc Olberding 
// Send the request at the head of the queue for `eid`, if there is one.
//
// A fresh instance id is obtained from getNextIid() and written into the
// request header before the message is handed to the socket. Failures to
// obtain or inject the id are reported through handleResult(). Send
// completion is delivered to handleSendMsgCompletion().
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        // Nothing left to send for this endpoint.
        return;
    }

    auto& current = pending.front();
    std::span<uint8_t> payload{current.reqMsg.data(), current.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (const std::expected<void, std::error_code> injected =
            injectIid(payload, *nextIid);
        !injected)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    MctpAsioEndpoint sendEndPoint(eid,
                                  ocp::accelerator_management::messageType);
    mctpSocket.async_send_to(
        boost::asio::const_buffer(payload.data(), payload.size()),
        sendEndPoint.endpoint,
        [this, eid](const boost::system::error_code& ec, size_t length) {
            handleSendMsgCompletion(eid, ec, length);
        });
}
355ed0af21cSAditya Kurdunkar 
356560e6af7SHarshit Aghera } // namespace mctp
357