1 #include "common/instance_id.hpp"
2 #include "common/types.hpp"
3 #include "common/utils.hpp"
4 #include "mock_request.hpp"
5 #include "requester/handler.hpp"
6 #include "test/test_instance_id.hpp"
7 
8 #include <libpldm/base.h>
9 #include <libpldm/transport.h>
10 
11 #include <sdbusplus/async.hpp>
12 
13 #include <gmock/gmock.h>
14 #include <gtest/gtest.h>
15 
16 using namespace pldm::requester;
17 using namespace std::chrono;
18 
19 using ::testing::AtLeast;
20 using ::testing::Between;
21 using ::testing::Exactly;
22 using ::testing::NiceMock;
23 using ::testing::Return;
24 
25 class HandlerTest : public testing::Test
26 {
27   protected:
28     HandlerTest() : event(sdeventplus::Event::get_default()), instanceIdDb() {}
29 
30     int fd = 0;
31     mctp_eid_t eid = 0;
32     PldmTransport* pldmTransport = nullptr;
33     sdeventplus::Event event;
34     TestInstanceIdDb instanceIdDb;
35 
36     /** @brief This function runs the sd_event_run in a loop till all the events
37      *         in the testcase are dispatched and exits when there are no events
38      *         for the timeout time.
39      *
40      *  @param[in] timeout - maximum time to wait for an event
41      */
42     void waitEventExpiry(milliseconds timeout)
43     {
44         while (1)
45         {
46             auto sleepTime = duration_cast<microseconds>(timeout);
47             // Returns 0 on timeout
48             if (!sd_event_run(event.get(), sleepTime.count()))
49             {
50                 break;
51             }
52         }
53     }
54 
55   public:
56     bool nullResponse = false;
57     bool validResponse = false;
58     int callbackCount = 0;
59     bool response2 = false;
60 
61     void pldmResponseCallBack(mctp_eid_t /*eid*/, const pldm_msg* response,
62                               size_t respMsgLen)
63     {
64         if (response == nullptr && respMsgLen == 0)
65         {
66             nullResponse = true;
67         }
68         else
69         {
70             validResponse = true;
71         }
72         callbackCount++;
73     }
74 };
75 
76 TEST_F(HandlerTest, singleRequestResponseScenario)
77 {
78     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
79                                               instanceIdDb, false, seconds(1),
80                                               2, milliseconds(100));
81     pldm::Request request{};
82     auto instanceId = instanceIdDb.next(eid);
83     EXPECT_EQ(instanceId, 0);
84     auto rc = reqHandler.registerRequest(
85         eid, instanceId, 0, 0, std::move(request),
86         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
87     EXPECT_EQ(rc, PLDM_SUCCESS);
88 
89     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
90     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
91     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
92                               response.size());
93 
94     EXPECT_EQ(validResponse, true);
95 }
96 
97 TEST_F(HandlerTest, singleRequestInstanceIdTimerExpired)
98 {
99     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
100                                               instanceIdDb, false, seconds(1),
101                                               2, milliseconds(100));
102     pldm::Request request{};
103     auto instanceId = instanceIdDb.next(eid);
104     EXPECT_EQ(instanceId, 0);
105     auto rc = reqHandler.registerRequest(
106         eid, instanceId, 0, 0, std::move(request),
107         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
108     EXPECT_EQ(rc, PLDM_SUCCESS);
109 
110     // Waiting for 500ms so that the instance ID expiry callback is invoked
111     waitEventExpiry(milliseconds(500));
112 
113     EXPECT_EQ(nullResponse, true);
114 }
115 
116 TEST_F(HandlerTest, multipleRequestResponseScenario)
117 {
118     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
119                                               instanceIdDb, false, seconds(2),
120                                               2, milliseconds(100));
121     pldm::Request request{};
122     auto instanceId = instanceIdDb.next(eid);
123     EXPECT_EQ(instanceId, 0);
124     auto rc = reqHandler.registerRequest(
125         eid, instanceId, 0, 0, std::move(request),
126         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
127     EXPECT_EQ(rc, PLDM_SUCCESS);
128 
129     pldm::Request requestNxt{};
130     auto instanceIdNxt = instanceIdDb.next(eid);
131     EXPECT_EQ(instanceIdNxt, 1);
132     rc = reqHandler.registerRequest(
133         eid, instanceIdNxt, 0, 0, std::move(requestNxt),
134         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
135     EXPECT_EQ(rc, PLDM_SUCCESS);
136 
137     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
138     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
139     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
140                               response.size());
141     EXPECT_EQ(validResponse, true);
142     EXPECT_EQ(callbackCount, 1);
143     validResponse = false;
144 
145     // Waiting for 500ms and handle the response for the first request, to
146     // simulate a delayed response for the first request
147     waitEventExpiry(milliseconds(500));
148 
149     reqHandler.handleResponse(eid, instanceIdNxt, 0, 0, responsePtr,
150                               response.size());
151 
152     EXPECT_EQ(validResponse, true);
153     EXPECT_EQ(callbackCount, 2);
154 }
155 
156 TEST_F(HandlerTest, singleRequestResponseScenarioUsingCoroutine)
157 {
158     exec::async_scope scope;
159     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
160                                               instanceIdDb, false, seconds(1),
161                                               2, milliseconds(100));
162 
163     auto instanceId = instanceIdDb.next(eid);
164     EXPECT_EQ(instanceId, 0);
165 
166     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
167         pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
168         const pldm_msg* responseMsg;
169         size_t responseLen;
170         int rc = PLDM_SUCCESS;
171 
172         auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
173         requestPtr->hdr.instance_id = instanceId;
174 
175         try
176         {
177             std::tie(rc, responseMsg, responseLen) =
178                 co_await reqHandler.sendRecvMsg(eid, std::move(request));
179         }
180         catch (...)
181         {
182             std::rethrow_exception(std::current_exception());
183         }
184 
185         EXPECT_NE(responseLen, 0);
186 
187         this->pldmResponseCallBack(eid, responseMsg, responseLen);
188 
189         EXPECT_EQ(validResponse, true);
190     }),
191                 exec::default_task_context<void>());
192 
193     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
194     auto mockResponsePtr =
195         reinterpret_cast<const pldm_msg*>(mockResponse.data());
196     reqHandler.handleResponse(eid, instanceId, 0, 0, mockResponsePtr,
197                               mockResponse.size() - sizeof(pldm_msg_hdr));
198 
199     stdexec::sync_wait(scope.on_empty());
200 }
201 
202 TEST_F(HandlerTest, singleRequestCancellationScenarioUsingCoroutine)
203 {
204     exec::async_scope scope;
205     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
206                                               instanceIdDb, false, seconds(1),
207                                               2, milliseconds(100));
208     auto instanceId = instanceIdDb.next(eid);
209     EXPECT_EQ(instanceId, 0);
210 
211     bool stopped = false;
212 
213     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
214         pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
215         pldm::Response response;
216 
217         auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
218         requestPtr->hdr.instance_id = instanceId;
219 
220         co_await reqHandler.sendRecvMsg(eid, std::move(request));
221 
222         EXPECT_TRUE(false); // unreachable
223     }) | stdexec::upon_stopped([&] { stopped = true; }),
224                 exec::default_task_context<void>());
225 
226     scope.request_stop();
227 
228     EXPECT_TRUE(stopped);
229 
230     stdexec::sync_wait(scope.on_empty());
231 }
232 
233 TEST_F(HandlerTest, asyncRequestResponseByCoroutine)
234 {
235     struct _
236     {
237         static exec::task<uint8_t> getTIDTask(Handler<MockRequest>& handler,
238                                               mctp_eid_t eid,
239                                               uint8_t instanceId, uint8_t& tid)
240         {
241             pldm::Request request(sizeof(pldm_msg_hdr), 0);
242             auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
243             const pldm_msg* responseMsg;
244             size_t responseLen;
245 
246             auto rc = encode_get_tid_req(instanceId, requestMsg);
247             EXPECT_EQ(rc, PLDM_SUCCESS);
248 
249             std::tie(rc, responseMsg, responseLen) =
250                 co_await handler.sendRecvMsg(eid, std::move(request));
251             EXPECT_NE(responseLen, 0);
252 
253             uint8_t cc = 0;
254             rc = decode_get_tid_resp(responseMsg, responseLen, &cc, &tid);
255             EXPECT_EQ(rc, PLDM_SUCCESS);
256 
257             co_return cc;
258         }
259     };
260 
261     exec::async_scope scope;
262     Handler<MockRequest> reqHandler(pldmTransport, event, instanceIdDb, false,
263                                     seconds(1), 2, milliseconds(100));
264     auto instanceId = instanceIdDb.next(eid);
265 
266     uint8_t expectedTid = 1;
267 
268     // Execute a coroutine to send getTID command. The coroutine is suspended
269     // until reqHandler.handleResponse() is received.
270     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
271         uint8_t respTid = 0;
272 
273         co_await _::getTIDTask(reqHandler, eid, instanceId, respTid);
274 
275         EXPECT_EQ(expectedTid, respTid);
276     }),
277                 exec::default_task_context<void>());
278 
279     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES,
280                                 0);
281     auto mockResponseMsg = reinterpret_cast<pldm_msg*>(mockResponse.data());
282 
283     // Compose response message of getTID command
284     encode_get_tid_resp(instanceId, PLDM_SUCCESS, expectedTid, mockResponseMsg);
285 
286     // Send response back to resume getTID coroutine to update respTid by
287     // calling  reqHandler.handleResponse() manually
288     reqHandler.handleResponse(eid, instanceId, PLDM_BASE, PLDM_GET_TID,
289                               mockResponseMsg,
290                               mockResponse.size() - sizeof(pldm_msg_hdr));
291 
292     stdexec::sync_wait(scope.on_empty());
293 }
294