1 /*
2 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
3 * AFFILIATES. All rights reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include "MctpRequester.hpp"
8
9 #include <linux/mctp.h>
10 #include <sys/socket.h>
11
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <boost/container/devector.hpp>
19 #include <phosphor-logging/lg2.hpp>
20
21 #include <cerrno>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <functional>
26 #include <memory>
27 #include <span>
28 #include <utility>
29
30 using namespace std::literals;
31
32 namespace mctp
33 {
34
Requester(boost::asio::io_context & ctx)35 Requester::Requester(boost::asio::io_context& ctx) :
36 mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
37 expiryTimer(ctx)
38 {}
39
processRecvMsg(const std::span<const uint8_t> reqMsg,const std::span<uint8_t> respMsg,const boost::system::error_code & ec,const size_t)40 void Requester::processRecvMsg(
41 const std::span<const uint8_t> reqMsg, const std::span<uint8_t> respMsg,
42 const boost::system::error_code& ec, const size_t /*length*/)
43 {
44 const auto* respAddr =
45 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
46 reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
47
48 uint8_t eid = respAddr->smctp_addr.s_addr;
49
50 if (!completionCallbacks.contains(eid))
51 {
52 lg2::error(
53 "MctpRequester failed to get the callback for the EID: {EID}",
54 "EID", static_cast<int>(eid));
55 return;
56 }
57
58 auto& callback = completionCallbacks.at(eid);
59
60 if (respAddr->smctp_type != msgType)
61 {
62 lg2::error("MctpRequester: Message type mismatch");
63 callback(EPROTO);
64 return;
65 }
66
67 expiryTimer.cancel();
68
69 if (ec)
70 {
71 lg2::error(
72 "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
73 "EC", ec.value(), "ER", ec.message());
74 callback(EIO);
75 return;
76 }
77
78 if (respMsg.size() > sizeof(ocp::accelerator_management::BindingPciVid))
79 {
80 const auto* reqHdr =
81 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
82 reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
83 reqMsg.data());
84
85 uint8_t reqInstanceId = reqHdr->instance_id &
86 ocp::accelerator_management::instanceIdBitMask;
87 const auto* respHdr =
88 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
89 reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
90 respMsg.data());
91
92 uint8_t respInstanceId = respHdr->instance_id &
93 ocp::accelerator_management::instanceIdBitMask;
94
95 if (reqInstanceId != respInstanceId)
96 {
97 lg2::error(
98 "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
99 "REQ", static_cast<int>(reqInstanceId), "RESP",
100 static_cast<int>(respInstanceId));
101 callback(EPROTO);
102 return;
103 }
104 }
105
106 callback(0);
107 }
108
handleSendMsgCompletion(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,const boost::system::error_code & ec,size_t)109 void Requester::handleSendMsgCompletion(
110 uint8_t eid, const std::span<const uint8_t> reqMsg,
111 std::span<uint8_t> respMsg, const boost::system::error_code& ec,
112 size_t /* length */)
113 {
114 if (!completionCallbacks.contains(eid))
115 {
116 lg2::error(
117 "MctpRequester failed to get the callback for the EID: {EID}",
118 "EID", static_cast<int>(eid));
119 return;
120 }
121
122 auto& callback = completionCallbacks.at(eid);
123
124 if (ec)
125 {
126 lg2::error(
127 "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
128 "EC", ec.value(), "ER", ec.message());
129 callback(EIO);
130 return;
131 }
132
133 expiryTimer.expires_after(2s);
134
135 expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) {
136 if (ec != boost::asio::error::operation_aborted)
137 {
138 auto& callback = completionCallbacks.at(eid);
139 callback(ETIME);
140 }
141 });
142
143 mctpSocket.async_receive_from(
144 boost::asio::mutable_buffer(respMsg), recvEndPoint,
145 std::bind_front(&Requester::processRecvMsg, this, reqMsg, respMsg));
146 }
147
sendRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)148 void Requester::sendRecvMsg(uint8_t eid, const std::span<const uint8_t> reqMsg,
149 std::span<uint8_t> respMsg,
150 std::move_only_function<void(int)> callback)
151 {
152 if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
153 {
154 lg2::error("MctpRequester: Message too small");
155 callback(EPROTO);
156 return;
157 }
158
159 completionCallbacks[eid] = std::move(callback);
160
161 struct sockaddr_mctp addr{};
162 addr.smctp_family = AF_MCTP;
163 addr.smctp_addr.s_addr = eid;
164 addr.smctp_type = msgType;
165 addr.smctp_tag = MCTP_TAG_OWNER;
166
167 sendEndPoint = {&addr, sizeof(addr)};
168
169 mctpSocket.async_send_to(
170 boost::asio::const_buffer(reqMsg), sendEndPoint,
171 std::bind_front(&Requester::handleSendMsgCompletion, this, eid, reqMsg,
172 respMsg));
173 }
174
sendRecvMsg(uint8_t eid,std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)175 void QueuingRequester::sendRecvMsg(uint8_t eid, std::span<const uint8_t> reqMsg,
176 std::span<uint8_t> respMsg,
177 std::move_only_function<void(int)> callback)
178 {
179 auto reqCtx =
180 std::make_unique<RequestContext>(reqMsg, respMsg, std::move(callback));
181
182 // Add request to queue
183 auto& queue = requestContextQueues[eid];
184 queue.push_back(std::move(reqCtx));
185
186 if (queue.size() == 1)
187 {
188 processQueue(eid);
189 }
190 }
191
handleResult(uint8_t eid,int result)192 void QueuingRequester::handleResult(uint8_t eid, int result)
193 {
194 auto& queue = requestContextQueues[eid];
195 const auto& reqCtx = queue.front();
196
197 reqCtx->callback(result); // Call the original callback
198
199 queue.pop_front();
200
201 processQueue(eid);
202 }
203
processQueue(uint8_t eid)204 void QueuingRequester::processQueue(uint8_t eid)
205 {
206 auto& queue = requestContextQueues[eid];
207
208 if (queue.empty())
209 {
210 return;
211 }
212
213 const auto& reqCtx = queue.front();
214
215 requester.sendRecvMsg(
216 eid, reqCtx->reqMsg, reqCtx->respMsg,
217 std::bind_front(&QueuingRequester::handleResult, this, eid));
218 }
219
220 } // namespace mctp
221