xref: /openbmc/dbus-sensors/src/nvidia-gpu/MctpRequester.cpp (revision 560e6af7b1f74e9c020a0f82817f9d926e0c4f72)
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
3  * AFFILIATES. All rights reserved.
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "MctpRequester.hpp"
8 
9 #include <linux/mctp.h>
10 #include <sys/socket.h>
11 
12 #include <OcpMctpVdm.hpp>
13 #include <boost/asio/buffer.hpp>
14 #include <boost/asio/error.hpp>
15 #include <boost/asio/generic/datagram_protocol.hpp>
16 #include <boost/asio/io_context.hpp>
17 #include <boost/asio/steady_timer.hpp>
18 #include <phosphor-logging/lg2.hpp>
19 
20 #include <cerrno>
21 #include <cstddef>
22 #include <cstdint>
23 #include <cstring>
24 #include <functional>
25 #include <span>
26 #include <utility>
27 
28 using namespace std::literals;
29 
30 namespace mctp
31 {
32 
MctpRequester(boost::asio::io_context & ctx)33 MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
34     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
35     expiryTimer(ctx)
36 {}
37 
processRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,const std::span<uint8_t> respMsg,const boost::system::error_code & ec,const size_t)38 void MctpRequester::processRecvMsg(
39     uint8_t eid, const std::span<const uint8_t> reqMsg,
40     const std::span<uint8_t> respMsg, const boost::system::error_code& ec,
41     const size_t /*length*/)
42 {
43     expiryTimer.cancel();
44 
45     if (ec)
46     {
47         lg2::error(
48             "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
49             "EC", ec.value(), "ER", ec.message());
50         completionCallback(EIO);
51         return;
52     }
53 
54     const auto* respAddr =
55         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
56         reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
57 
58     if (respAddr->smctp_type != msgType)
59     {
60         lg2::error("MctpRequester: Message type mismatch");
61         completionCallback(EPROTO);
62         return;
63     }
64 
65     uint8_t respEid = respAddr->smctp_addr.s_addr;
66 
67     if (respEid != eid)
68     {
69         lg2::error(
70             "MctpRequester: EID mismatch - expected={EID}, received={REID}",
71             "EID", eid, "REID", respEid);
72         completionCallback(EPROTO);
73         return;
74     }
75 
76     if (respMsg.size() > sizeof(ocp::accelerator_management::BindingPciVid))
77     {
78         const auto* reqHdr =
79             // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
80             reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
81                 reqMsg.data());
82 
83         uint8_t reqInstanceId = reqHdr->instance_id &
84                                 ocp::accelerator_management::instanceIdBitMask;
85         const auto* respHdr =
86             // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
87             reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
88                 respMsg.data());
89 
90         uint8_t respInstanceId = respHdr->instance_id &
91                                  ocp::accelerator_management::instanceIdBitMask;
92 
93         if (reqInstanceId != respInstanceId)
94         {
95             lg2::error(
96                 "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
97                 "REQ", static_cast<int>(reqInstanceId), "RESP",
98                 static_cast<int>(respInstanceId));
99             completionCallback(EPROTO);
100             return;
101         }
102     }
103 
104     completionCallback(0);
105 }
106 
handleSendMsgCompletion(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,const boost::system::error_code & ec,size_t)107 void MctpRequester::handleSendMsgCompletion(
108     uint8_t eid, const std::span<const uint8_t> reqMsg,
109     std::span<uint8_t> respMsg, const boost::system::error_code& ec,
110     size_t /* length */)
111 {
112     if (ec)
113     {
114         lg2::error(
115             "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
116             "EC", ec.value(), "ER", ec.message());
117         completionCallback(EIO);
118         return;
119     }
120 
121     expiryTimer.expires_after(2s);
122 
123     expiryTimer.async_wait([this](const boost::system::error_code& ec) {
124         if (ec != boost::asio::error::operation_aborted)
125         {
126             completionCallback(ETIME);
127         }
128     });
129 
130     mctpSocket.async_receive_from(
131         boost::asio::mutable_buffer(respMsg), recvEndPoint,
132         std::bind_front(&MctpRequester::processRecvMsg, this, eid, reqMsg,
133                         respMsg));
134 }
135 
sendRecvMsg(uint8_t eid,const std::span<const uint8_t> reqMsg,std::span<uint8_t> respMsg,std::move_only_function<void (int)> callback)136 void MctpRequester::sendRecvMsg(
137     uint8_t eid, const std::span<const uint8_t> reqMsg,
138     std::span<uint8_t> respMsg, std::move_only_function<void(int)> callback)
139 {
140     if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
141     {
142         lg2::error("MctpRequester: Message too small");
143         callback(EPROTO);
144         return;
145     }
146 
147     completionCallback = std::move(callback);
148 
149     struct sockaddr_mctp addr{};
150     addr.smctp_family = AF_MCTP;
151     addr.smctp_addr.s_addr = eid;
152     addr.smctp_type = msgType;
153     addr.smctp_tag = MCTP_TAG_OWNER;
154 
155     sendEndPoint = {&addr, sizeof(addr)};
156 
157     mctpSocket.async_send_to(
158         boost::asio::const_buffer(reqMsg), sendEndPoint,
159         std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid,
160                         reqMsg, respMsg));
161 }
162 } // namespace mctp
163