1 /*
2 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
3 * AFFILIATES. All rights reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include "MctpRequester.hpp"
8
9 #include <linux/mctp.h>
10 #include <sys/socket.h>
11
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <phosphor-logging/lg2.hpp>
19
20 #include <cerrno>
21 #include <cstddef>
22 #include <cstdint>
23 #include <cstring>
24 #include <functional>
25 #include <span>
26 #include <utility>
27
28 using namespace std::literals;
29
30 namespace mctp
31 {
32
MctpRequester(boost::asio::io_context & ctx)33 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
34 mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
35 expiryTimer(ctx)
36 {}
37
processRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,const std::span<uint8_t> respMsg,const boost::system::error_code & ec,const size_t)38 void MctpRequester::processRecvMsg(
39 uint8_t eid, const std::span<const uint8_t> reqMsg,
40 const std::span<uint8_t> respMsg, const boost::system::error_code& ec,
41 const size_t /*length*/)
42 {
43 expiryTimer.cancel();
44
45 if (ec)
46 {
47 lg2::error(
48 "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
49 "EC", ec.value(), "ER", ec.message());
50 completionCallback(EIO);
51 return;
52 }
53
54 const auto* respAddr =
55 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
56 reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
57
58 if (respAddr->smctp_type != msgType)
59 {
60 lg2::error("MctpRequester: Message type mismatch");
61 completionCallback(EPROTO);
62 return;
63 }
64
65 uint8_t respEid = respAddr->smctp_addr.s_addr;
66
67 if (respEid != eid)
68 {
69 lg2::error(
70 "MctpRequester: EID mismatch - expected={EID}, received={REID}",
71 "EID", eid, "REID", respEid);
72 completionCallback(EPROTO);
73 return;
74 }
75
76 if (respMsg.size() > sizeof(ocp::accelerator_management::BindingPciVid))
77 {
78 const auto* reqHdr =
79 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
80 reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
81 reqMsg.data());
82
83 uint8_t reqInstanceId = reqHdr->instance_id &
84 ocp::accelerator_management::instanceIdBitMask;
85 const auto* respHdr =
86 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
87 reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
88 respMsg.data());
89
90 uint8_t respInstanceId = respHdr->instance_id &
91 ocp::accelerator_management::instanceIdBitMask;
92
93 if (reqInstanceId != respInstanceId)
94 {
95 lg2::error(
96 "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
97 "REQ", static_cast<int>(reqInstanceId), "RESP",
98 static_cast<int>(respInstanceId));
99 completionCallback(EPROTO);
100 return;
101 }
102 }
103
104 completionCallback(0);
105 }
106
handleSendMsgCompletion(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,const boost::system::error_code & ec,size_t)107 void MctpRequester::handleSendMsgCompletion(
108 uint8_t eid, const std::span<const uint8_t> reqMsg,
109 std::span<uint8_t> respMsg, const boost::system::error_code& ec,
110 size_t /* length */)
111 {
112 if (ec)
113 {
114 lg2::error(
115 "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
116 "EC", ec.value(), "ER", ec.message());
117 completionCallback(EIO);
118 return;
119 }
120
121 expiryTimer.expires_after(2s);
122
123 expiryTimer.async_wait([this](const boost::system::error_code& ec) {
124 if (ec != boost::asio::error::operation_aborted)
125 {
126 completionCallback(ETIME);
127 }
128 });
129
130 mctpSocket.async_receive_from(
131 boost::asio::mutable_buffer(respMsg), recvEndPoint,
132 std::bind_front(&MctpRequester::processRecvMsg, this, eid, reqMsg,
133 respMsg));
134 }
135
sendRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)136 void MctpRequester::sendRecvMsg(
137 uint8_t eid, const std::span<const uint8_t> reqMsg,
138 std::span<uint8_t> respMsg, std::move_only_function<void(int)> callback)
139 {
140 if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
141 {
142 lg2::error("MctpRequester: Message too small");
143 callback(EPROTO);
144 return;
145 }
146
147 completionCallback = std::move(callback);
148
149 struct sockaddr_mctp addr{};
150 addr.smctp_family = AF_MCTP;
151 addr.smctp_addr.s_addr = eid;
152 addr.smctp_type = msgType;
153 addr.smctp_tag = MCTP_TAG_OWNER;
154
155 sendEndPoint = {&addr, sizeof(addr)};
156
157 mctpSocket.async_send_to(
158 boost::asio::const_buffer(reqMsg), sendEndPoint,
159 std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid,
160 reqMsg, respMsg));
161 }
162 } // namespace mctp
163