// xref: /openbmc/dbus-sensors/src/nvidia-gpu/MctpRequester.cpp (revision ed0af21ca6092a7ceb19660b69b76a5c304efcb0)
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
3  * AFFILIATES. All rights reserved.
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "MctpRequester.hpp"
8 
9 #include <linux/mctp.h>
10 #include <sys/socket.h>
11 
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <boost/container/devector.hpp>
19 #include <phosphor-logging/lg2.hpp>
20 
21 #include <cerrno>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <functional>
26 #include <memory>
27 #include <span>
28 #include <utility>
29 
30 using namespace std::literals;
31 
32 namespace mctp
33 {
34 
Requester(boost::asio::io_context & ctx)35 Requester::Requester(boost::asio::io_context& ctx) :
36     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
37     expiryTimer(ctx)
38 {}
39 
processRecvMsg(const std::span<const uint8_t> reqMsg,const std::span<uint8_t> respMsg,const boost::system::error_code & ec,const size_t)40 void Requester::processRecvMsg(
41     const std::span<const uint8_t> reqMsg, const std::span<uint8_t> respMsg,
42     const boost::system::error_code& ec, const size_t /*length*/)
43 {
44     const auto* respAddr =
45         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
46         reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
47 
48     uint8_t eid = respAddr->smctp_addr.s_addr;
49 
50     if (!completionCallbacks.contains(eid))
51     {
52         lg2::error(
53             "MctpRequester failed to get the callback for the EID: {EID}",
54             "EID", static_cast<int>(eid));
55         return;
56     }
57 
58     auto& callback = completionCallbacks.at(eid);
59 
60     if (respAddr->smctp_type != msgType)
61     {
62         lg2::error("MctpRequester: Message type mismatch");
63         callback(EPROTO);
64         return;
65     }
66 
67     expiryTimer.cancel();
68 
69     if (ec)
70     {
71         lg2::error(
72             "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
73             "EC", ec.value(), "ER", ec.message());
74         callback(EIO);
75         return;
76     }
77 
78     if (respMsg.size() > sizeof(ocp::accelerator_management::BindingPciVid))
79     {
80         const auto* reqHdr =
81             // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
82             reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
83                 reqMsg.data());
84 
85         uint8_t reqInstanceId = reqHdr->instance_id &
86                                 ocp::accelerator_management::instanceIdBitMask;
87         const auto* respHdr =
88             // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
89             reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
90                 respMsg.data());
91 
92         uint8_t respInstanceId = respHdr->instance_id &
93                                  ocp::accelerator_management::instanceIdBitMask;
94 
95         if (reqInstanceId != respInstanceId)
96         {
97             lg2::error(
98                 "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
99                 "REQ", static_cast<int>(reqInstanceId), "RESP",
100                 static_cast<int>(respInstanceId));
101             callback(EPROTO);
102             return;
103         }
104     }
105 
106     callback(0);
107 }
108 
handleSendMsgCompletion(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,const boost::system::error_code & ec,size_t)109 void Requester::handleSendMsgCompletion(
110     uint8_t eid, const std::span<const uint8_t> reqMsg,
111     std::span<uint8_t> respMsg, const boost::system::error_code& ec,
112     size_t /* length */)
113 {
114     if (!completionCallbacks.contains(eid))
115     {
116         lg2::error(
117             "MctpRequester failed to get the callback for the EID: {EID}",
118             "EID", static_cast<int>(eid));
119         return;
120     }
121 
122     auto& callback = completionCallbacks.at(eid);
123 
124     if (ec)
125     {
126         lg2::error(
127             "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
128             "EC", ec.value(), "ER", ec.message());
129         callback(EIO);
130         return;
131     }
132 
133     expiryTimer.expires_after(2s);
134 
135     expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) {
136         if (ec != boost::asio::error::operation_aborted)
137         {
138             auto& callback = completionCallbacks.at(eid);
139             callback(ETIME);
140         }
141     });
142 
143     mctpSocket.async_receive_from(
144         boost::asio::mutable_buffer(respMsg), recvEndPoint,
145         std::bind_front(&Requester::processRecvMsg, this, reqMsg, respMsg));
146 }
147 
sendRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)148 void Requester::sendRecvMsg(uint8_t eid, const std::span<const uint8_t> reqMsg,
149                             std::span<uint8_t> respMsg,
150                             std::move_only_function<void(int)> callback)
151 {
152     if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
153     {
154         lg2::error("MctpRequester: Message too small");
155         callback(EPROTO);
156         return;
157     }
158 
159     completionCallbacks[eid] = std::move(callback);
160 
161     struct sockaddr_mctp addr{};
162     addr.smctp_family = AF_MCTP;
163     addr.smctp_addr.s_addr = eid;
164     addr.smctp_type = msgType;
165     addr.smctp_tag = MCTP_TAG_OWNER;
166 
167     sendEndPoint = {&addr, sizeof(addr)};
168 
169     mctpSocket.async_send_to(
170         boost::asio::const_buffer(reqMsg), sendEndPoint,
171         std::bind_front(&Requester::handleSendMsgCompletion, this, eid, reqMsg,
172                         respMsg));
173 }
174 
sendRecvMsg(uint8_t eid,std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)175 void QueuingRequester::sendRecvMsg(uint8_t eid, std::span<const uint8_t> reqMsg,
176                                    std::span<uint8_t> respMsg,
177                                    std::move_only_function<void(int)> callback)
178 {
179     auto reqCtx =
180         std::make_unique<RequestContext>(reqMsg, respMsg, std::move(callback));
181 
182     // Add request to queue
183     auto& queue = requestContextQueues[eid];
184     queue.push_back(std::move(reqCtx));
185 
186     if (queue.size() == 1)
187     {
188         processQueue(eid);
189     }
190 }
191 
handleResult(uint8_t eid,int result)192 void QueuingRequester::handleResult(uint8_t eid, int result)
193 {
194     auto& queue = requestContextQueues[eid];
195     const auto& reqCtx = queue.front();
196 
197     reqCtx->callback(result); // Call the original callback
198 
199     queue.pop_front();
200 
201     processQueue(eid);
202 }
203 
processQueue(uint8_t eid)204 void QueuingRequester::processQueue(uint8_t eid)
205 {
206     auto& queue = requestContextQueues[eid];
207 
208     if (queue.empty())
209     {
210         return;
211     }
212 
213     const auto& reqCtx = queue.front();
214 
215     requester.sendRecvMsg(
216         eid, reqCtx->reqMsg, reqCtx->respMsg,
217         std::bind_front(&QueuingRequester::handleResult, this, eid));
218 }
219 
220 } // namespace mctp
221