1 #include "common/instance_id.hpp"
2 #include "common/types.hpp"
3 #include "common/utils.hpp"
4 #include "mock_request.hpp"
5 #include "requester/handler.hpp"
6 #include "test/test_instance_id.hpp"
7 
8 #include <libpldm/base.h>
9 #include <libpldm/transport.h>
10 
11 #include <sdbusplus/async.hpp>
12 
13 #include <gmock/gmock.h>
14 #include <gtest/gtest.h>
15 
16 using namespace pldm::requester;
17 using namespace std::chrono;
18 
19 using ::testing::AtLeast;
20 using ::testing::Between;
21 using ::testing::Exactly;
22 using ::testing::NiceMock;
23 using ::testing::Return;
24 
25 class HandlerTest : public testing::Test
26 {
27   protected:
28     HandlerTest() : event(sdeventplus::Event::get_default()), instanceIdDb() {}
29 
30     int fd = 0;
31     mctp_eid_t eid = 0;
32     PldmTransport* pldmTransport = nullptr;
33     sdeventplus::Event event;
34     TestInstanceIdDb instanceIdDb;
35 
36     /** @brief This function runs the sd_event_run in a loop till all the events
37      *         in the testcase are dispatched and exits when there are no events
38      *         for the timeout time.
39      *
40      *  @param[in] timeout - maximum time to wait for an event
41      */
42     void waitEventExpiry(milliseconds timeout)
43     {
44         while (1)
45         {
46             auto sleepTime = duration_cast<microseconds>(timeout);
47             // Returns 0 on timeout
48             if (!sd_event_run(event.get(), sleepTime.count()))
49             {
50                 break;
51             }
52         }
53     }
54 
55   public:
56     bool nullResponse = false;
57     bool validResponse = false;
58     int callbackCount = 0;
59     bool response2 = false;
60 
61     void pldmResponseCallBack(mctp_eid_t /*eid*/, const pldm_msg* response,
62                               size_t respMsgLen)
63     {
64         if (response == nullptr && respMsgLen == 0)
65         {
66             nullResponse = true;
67         }
68         else
69         {
70             validResponse = true;
71         }
72         callbackCount++;
73     }
74 };
75 
76 TEST_F(HandlerTest, singleRequestResponseScenario)
77 {
78     Handler<NiceMock<MockRequest>> reqHandler(
79         pldmTransport, event, instanceIdDb, false, seconds(1), 2,
80         milliseconds(100));
81     pldm::Request request{};
82     auto instanceId = instanceIdDb.next(eid);
83     EXPECT_EQ(instanceId, 0);
84     auto rc = reqHandler.registerRequest(
85         eid, instanceId, 0, 0, std::move(request),
86         std::bind_front(&HandlerTest::pldmResponseCallBack, this));
87     EXPECT_EQ(rc, PLDM_SUCCESS);
88 
89     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
90     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
91     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
92                               response.size());
93 
94     EXPECT_EQ(validResponse, true);
95 }
96 
97 TEST_F(HandlerTest, singleRequestInstanceIdTimerExpired)
98 {
99     Handler<NiceMock<MockRequest>> reqHandler(
100         pldmTransport, event, instanceIdDb, false, seconds(1), 2,
101         milliseconds(100));
102     pldm::Request request{};
103     auto instanceId = instanceIdDb.next(eid);
104     EXPECT_EQ(instanceId, 0);
105     auto rc = reqHandler.registerRequest(
106         eid, instanceId, 0, 0, std::move(request),
107         std::bind_front(&HandlerTest::pldmResponseCallBack, this));
108     EXPECT_EQ(rc, PLDM_SUCCESS);
109 
110     // Waiting for 500ms so that the instance ID expiry callback is invoked
111     waitEventExpiry(milliseconds(500));
112 
113     EXPECT_EQ(nullResponse, true);
114 }
115 
116 TEST_F(HandlerTest, multipleRequestResponseScenario)
117 {
118     Handler<NiceMock<MockRequest>> reqHandler(
119         pldmTransport, event, instanceIdDb, false, seconds(2), 2,
120         milliseconds(100));
121     pldm::Request request{};
122     auto instanceId = instanceIdDb.next(eid);
123     EXPECT_EQ(instanceId, 0);
124     auto rc = reqHandler.registerRequest(
125         eid, instanceId, 0, 0, std::move(request),
126         std::bind_front(&HandlerTest::pldmResponseCallBack, this));
127     EXPECT_EQ(rc, PLDM_SUCCESS);
128 
129     pldm::Request requestNxt{};
130     auto instanceIdNxt = instanceIdDb.next(eid);
131     EXPECT_EQ(instanceIdNxt, 1);
132     rc = reqHandler.registerRequest(
133         eid, instanceIdNxt, 0, 0, std::move(requestNxt),
134         std::bind_front(&HandlerTest::pldmResponseCallBack, this));
135     EXPECT_EQ(rc, PLDM_SUCCESS);
136 
137     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
138     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
139     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
140                               response.size());
141     EXPECT_EQ(validResponse, true);
142     EXPECT_EQ(callbackCount, 1);
143     validResponse = false;
144 
145     // Waiting for 500ms and handle the response for the first request, to
146     // simulate a delayed response for the first request
147     waitEventExpiry(milliseconds(500));
148 
149     reqHandler.handleResponse(eid, instanceIdNxt, 0, 0, responsePtr,
150                               response.size());
151 
152     EXPECT_EQ(validResponse, true);
153     EXPECT_EQ(callbackCount, 2);
154 }
155 
156 TEST_F(HandlerTest, singleRequestResponseScenarioUsingCoroutine)
157 {
158     exec::async_scope scope;
159     Handler<NiceMock<MockRequest>> reqHandler(
160         pldmTransport, event, instanceIdDb, false, seconds(1), 2,
161         milliseconds(100));
162 
163     auto instanceId = instanceIdDb.next(eid);
164     EXPECT_EQ(instanceId, 0);
165 
166     scope.spawn(
167         stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
168             pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
169             const pldm_msg* responseMsg;
170             size_t responseLen;
171             int rc = PLDM_SUCCESS;
172 
173             auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
174             requestPtr->hdr.instance_id = instanceId;
175 
176             try
177             {
178                 std::tie(rc, responseMsg, responseLen) =
179                     co_await reqHandler.sendRecvMsg(eid, std::move(request));
180             }
181             catch (...)
182             {
183                 std::rethrow_exception(std::current_exception());
184             }
185 
186             EXPECT_NE(responseLen, 0);
187 
188             this->pldmResponseCallBack(eid, responseMsg, responseLen);
189 
190             EXPECT_EQ(validResponse, true);
191         }),
192         exec::default_task_context<void>(exec::inline_scheduler{}));
193 
194     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
195     auto mockResponsePtr =
196         reinterpret_cast<const pldm_msg*>(mockResponse.data());
197     reqHandler.handleResponse(eid, instanceId, 0, 0, mockResponsePtr,
198                               mockResponse.size() - sizeof(pldm_msg_hdr));
199 
200     stdexec::sync_wait(scope.on_empty());
201 }
202 
203 TEST_F(HandlerTest, singleRequestCancellationScenarioUsingCoroutine)
204 {
205     exec::async_scope scope;
206     Handler<NiceMock<MockRequest>> reqHandler(
207         pldmTransport, event, instanceIdDb, false, seconds(1), 2,
208         milliseconds(100));
209     auto instanceId = instanceIdDb.next(eid);
210     EXPECT_EQ(instanceId, 0);
211 
212     bool stopped = false;
213 
214     scope.spawn(
215         stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
216             pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
217             pldm::Response response;
218 
219             auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
220             requestPtr->hdr.instance_id = instanceId;
221 
222             co_await reqHandler.sendRecvMsg(eid, std::move(request));
223 
224             EXPECT_TRUE(false); // unreachable
225         }) | stdexec::upon_stopped([&] { stopped = true; }),
226         exec::default_task_context<void>(exec::inline_scheduler{}));
227 
228     scope.request_stop();
229 
230     EXPECT_TRUE(stopped);
231 
232     stdexec::sync_wait(scope.on_empty());
233 }
234 
235 TEST_F(HandlerTest, asyncRequestResponseByCoroutine)
236 {
237     struct _
238     {
239         static exec::task<uint8_t>
240             getTIDTask(Handler<MockRequest>& handler, mctp_eid_t eid,
241                        uint8_t instanceId, uint8_t& tid)
242         {
243             pldm::Request request(sizeof(pldm_msg_hdr), 0);
244             auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
245             const pldm_msg* responseMsg;
246             size_t responseLen;
247 
248             auto rc = encode_get_tid_req(instanceId, requestMsg);
249             EXPECT_EQ(rc, PLDM_SUCCESS);
250 
251             std::tie(rc, responseMsg, responseLen) =
252                 co_await handler.sendRecvMsg(eid, std::move(request));
253             EXPECT_NE(responseLen, 0);
254 
255             uint8_t cc = 0;
256             rc = decode_get_tid_resp(responseMsg, responseLen, &cc, &tid);
257             EXPECT_EQ(rc, PLDM_SUCCESS);
258 
259             co_return cc;
260         }
261     };
262 
263     exec::async_scope scope;
264     Handler<MockRequest> reqHandler(pldmTransport, event, instanceIdDb, false,
265                                     seconds(1), 2, milliseconds(100));
266     auto instanceId = instanceIdDb.next(eid);
267 
268     uint8_t expectedTid = 1;
269 
270     // Execute a coroutine to send getTID command. The coroutine is suspended
271     // until reqHandler.handleResponse() is received.
272     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
273                     uint8_t respTid = 0;
274 
275                     co_await _::getTIDTask(reqHandler, eid, instanceId,
276                                            respTid);
277 
278                     EXPECT_EQ(expectedTid, respTid);
279                 }),
280                 exec::default_task_context<void>(exec::inline_scheduler{}));
281 
282     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES,
283                                 0);
284     auto mockResponseMsg = reinterpret_cast<pldm_msg*>(mockResponse.data());
285 
286     // Compose response message of getTID command
287     encode_get_tid_resp(instanceId, PLDM_SUCCESS, expectedTid, mockResponseMsg);
288 
289     // Send response back to resume getTID coroutine to update respTid by
290     // calling  reqHandler.handleResponse() manually
291     reqHandler.handleResponse(eid, instanceId, PLDM_BASE, PLDM_GET_TID,
292                               mockResponseMsg,
293                               mockResponse.size() - sizeof(pldm_msg_hdr));
294 
295     stdexec::sync_wait(scope.on_empty());
296 }
297