/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

#include "MctpRequester.hpp"

#include <linux/mctp.h>
#include <sys/socket.h>

#include <OcpMctpVdm.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/generic/datagram_protocol.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/system/error_code.hpp>
#include <phosphor-logging/lg2.hpp>

#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <span>
#include <utility>

using namespace std::literals;

namespace mctp
{

namespace
{

/**
 * @brief Complete the outstanding request registered for @p eid, at most once.
 *
 * The callback is moved out of @p callbacks and its entry erased *before* it
 * is invoked. This has two effects:
 *  - any later completion for the same request (e.g. an expiry timer whose
 *    handler was already queued) finds no entry and does nothing;
 *  - the callback may safely re-enter Requester::sendRecvMsg(), which stores
 *    a new callback under the same EID, without overwriting the function
 *    object that is currently executing.
 *
 * @param callbacks Map from EID to pending completion callback.
 * @param eid       Endpoint whose request is being completed.
 * @param result    Value passed to the callback (0 or an errno value).
 * @return true if a callback was registered and has been invoked.
 */
template <typename CallbackMap>
bool completeRequest(CallbackMap& callbacks, uint8_t eid, int result)
{
    auto it = callbacks.find(eid);
    if (it == callbacks.end())
    {
        return false;
    }

    auto callback = std::move(it->second);
    callbacks.erase(it);
    callback(result);
    return true;
}

} // namespace

Requester::Requester(boost::asio::io_context& ctx) :
    mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
    expiryTimer(ctx)
{}

/**
 * Validate a successfully received response and complete the request of the
 * EID that sent it.
 *
 * Completes with EPROTO on a message-type mismatch, on a response shorter than
 * the OCP binding header, or on an instance-ID mismatch with the request.
 * Otherwise it completes with 0.
 */
void Requester::processRecvMsg(
    const std::span<const uint8_t> reqMsg, const std::span<uint8_t> respMsg,
    const boost::system::error_code& ec, const size_t length)
{
    if (ec)
    {
        // Receive errors are reported by the completion handler installed in
        // handleSendMsgCompletion(), which knows the requested EID. A failed
        // receive does not refresh recvEndPoint, so the sender address read
        // below would be stale.
        return;
    }

    const auto* respAddr =
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
        reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
    const uint8_t eid = respAddr->smctp_addr.s_addr;

    if (!completionCallbacks.contains(eid))
    {
        lg2::error(
            "MctpRequester failed to get the callback for the EID: {EID}",
            "EID", static_cast<int>(eid));
        return;
    }

    int result = 0;
    if (respAddr->smctp_type != msgType)
    {
        lg2::error("MctpRequester: Message type mismatch");
        result = EPROTO;
    }
    else if (length < sizeof(ocp::accelerator_management::BindingPciVid))
    {
        // Use the received byte count, not the buffer capacity: bytes past
        // `length` in respMsg are left over from earlier use of the buffer.
        lg2::error("MctpRequester: Response too small");
        result = EPROTO;
    }
    else
    {
        // sendRecvMsg() guarantees reqMsg holds at least a full binding header.
        const auto* reqHdr =
            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
            reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
                reqMsg.data());
        const uint8_t reqInstanceId =
            reqHdr->instance_id &
            ocp::accelerator_management::instanceIdBitMask;

        const auto* respHdr =
            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
            reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
                respMsg.data());
        const uint8_t respInstanceId =
            respHdr->instance_id &
            ocp::accelerator_management::instanceIdBitMask;

        if (reqInstanceId != respInstanceId)
        {
            lg2::error(
                "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
                "REQ", static_cast<int>(reqInstanceId), "RESP",
                static_cast<int>(respInstanceId));
            result = EPROTO;
        }
    }

    // Every path that completes the request first disarms its timeout.
    expiryTimer.cancel();
    completeRequest(completionCallbacks, eid, result);
}

/**
 * Continuation of sendRecvMsg() once the request datagram has been sent.
 *
 * If the send failed, the request completes with EIO. Otherwise a 2 s expiry
 * timer is armed (completing with ETIME) and the response receive is started.
 *
 * NOTE(review): expiryTimer and recvEndPoint are single members shared by all
 * requests. QueuingRequester serialises requests per EID only, so requests to
 * different EIDs can be in flight together. In that case, re-arming the timer
 * here aborts the other request's timeout, and the overlapping receives share
 * recvEndPoint. Confirm whether callers can have several EIDs in flight at once.
 */
void Requester::handleSendMsgCompletion(
    uint8_t eid, const std::span<const uint8_t> reqMsg,
    std::span<uint8_t> respMsg, const boost::system::error_code& ec,
    size_t /* length */)
{
    if (!completionCallbacks.contains(eid))
    {
        lg2::error(
            "MctpRequester failed to get the callback for the EID: {EID}",
            "EID", static_cast<int>(eid));
        return;
    }

    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        completeRequest(completionCallbacks, eid, EIO);
        return;
    }

    expiryTimer.expires_after(2s);
    expiryTimer.async_wait(
        [this, eid](const boost::system::error_code& timerEc) {
            if (timerEc == boost::asio::error::operation_aborted)
            {
                return;
            }
            // No-op if a response already completed this request after the
            // timer handler had been queued.
            completeRequest(completionCallbacks, eid, ETIME);
        });

    mctpSocket.async_receive_from(
        boost::asio::mutable_buffer(respMsg), recvEndPoint,
        [this, eid, reqMsg,
         respMsg](const boost::system::error_code& recvEc, size_t length) {
            if (recvEc == boost::asio::error::operation_aborted)
            {
                // The socket was cancelled or closed; there is nothing
                // meaningful to report and `this` may be going away.
                return;
            }

            if (recvEc)
            {
                // Report against the requested EID: recvEndPoint is not
                // updated by a failed receive.
                expiryTimer.cancel();
                lg2::error(
                    "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
                    "EC", recvEc.value(), "ER", recvEc.message());
                completeRequest(completionCallbacks, eid, EIO);
                return;
            }

            processRecvMsg(reqMsg, respMsg, recvEc, length);
        });
}

/**
 * Send @p reqMsg to @p eid and receive the response into @p respMsg.
 *
 * @p callback is invoked once with 0 on success or an errno value: EPROTO,
 * EIO or ETIME. Both spans must stay valid until then. A request shorter than
 * the OCP binding header is rejected immediately with EPROTO.
 */
void Requester::sendRecvMsg(uint8_t eid, const std::span<const uint8_t> reqMsg,
                            std::span<uint8_t> respMsg,
                            std::move_only_function<void(int)> callback)
{
    if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
    {
        lg2::error("MctpRequester: Message too small");
        callback(EPROTO);
        return;
    }

    // NOTE(review): a second request for an EID whose request is still in
    // flight replaces the pending callback, which is then never invoked.
    // QueuingRequester avoids this by serialising requests per EID.
    completionCallbacks[eid] = std::move(callback);

    struct sockaddr_mctp addr{};
    addr.smctp_family = AF_MCTP;
    addr.smctp_addr.s_addr = eid;
    addr.smctp_type = msgType;
    addr.smctp_tag = MCTP_TAG_OWNER;

    sendEndPoint = {&addr, sizeof(addr)};

    mctpSocket.async_send_to(
        boost::asio::const_buffer(reqMsg), sendEndPoint,
        std::bind_front(&Requester::handleSendMsgCompletion, this, eid, reqMsg,
                        respMsg));
}

/**
 * Queue a request for @p eid. It is started immediately if no other request
 * for that EID is pending; otherwise it starts when the earlier ones complete.
 */
void QueuingRequester::sendRecvMsg(uint8_t eid, std::span<const uint8_t> reqMsg,
                                   std::span<uint8_t> respMsg,
                                   std::move_only_function<void(int)> callback)
{
    auto reqCtx =
        std::make_unique<RequestContext>(reqMsg, respMsg, std::move(callback));

    auto& queue = requestContextQueues[eid];
    queue.push_back(std::move(reqCtx));

    // The front of the queue is the in-flight request; only start this one if
    // it is the sole entry.
    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

/**
 * Completion of the in-flight request for @p eid: notify its caller, drop it
 * and start the next queued request, if any.
 */
void QueuingRequester::handleResult(uint8_t eid, int result)
{
    auto& queue = requestContextQueues[eid];
    if (queue.empty())
    {
        lg2::error("MctpRequester: No pending request for the EID: {EID}",
                   "EID", static_cast<int>(eid));
        return;
    }

    // Call the caller's callback while its context is still at the front. A
    // callback that queues a follow-up request for this EID then sees a
    // non-empty queue and does not start it early; processQueue() below does.
    const auto& reqCtx = queue.front();
    reqCtx->callback(result);

    queue.pop_front();
    processQueue(eid);
}

/** Start the request at the front of @p eid's queue, if there is one. */
void QueuingRequester::processQueue(uint8_t eid)
{
    auto& queue = requestContextQueues[eid];
    if (queue.empty())
    {
        return;
    }

    const auto& reqCtx = queue.front();
    requester.sendRecvMsg(
        eid, reqCtx->reqMsg, reqCtx->respMsg,
        std::bind_front(&QueuingRequester::handleResult, this, eid));
}

} // namespace mctp