1 #include "common/instance_id.hpp"
2 #include "common/types.hpp"
3 #include "common/utils.hpp"
4 #include "mock_request.hpp"
5 #include "requester/handler.hpp"
6 #include "test/test_instance_id.hpp"
7 
8 #include <libpldm/base.h>
9 #include <libpldm/transport.h>
10 
11 #include <sdbusplus/async.hpp>
12 
13 #include <gmock/gmock.h>
14 #include <gtest/gtest.h>
15 
16 using namespace pldm::requester;
17 using namespace std::chrono;
18 
19 using ::testing::AtLeast;
20 using ::testing::Between;
21 using ::testing::Exactly;
22 using ::testing::NiceMock;
23 using ::testing::Return;
24 
25 class HandlerTest : public testing::Test
26 {
27   protected:
28     HandlerTest() : event(sdeventplus::Event::get_default()), instanceIdDb() {}
29 
30     int fd = 0;
31     mctp_eid_t eid = 0;
32     PldmTransport* pldmTransport = nullptr;
33     sdeventplus::Event event;
34     TestInstanceIdDb instanceIdDb;
35 
36     /** @brief This function runs the sd_event_run in a loop till all the events
37      *         in the testcase are dispatched and exits when there are no events
38      *         for the timeout time.
39      *
40      *  @param[in] timeout - maximum time to wait for an event
41      */
42     void waitEventExpiry(milliseconds timeout)
43     {
44         while (1)
45         {
46             auto sleepTime = duration_cast<microseconds>(timeout);
47             // Returns 0 on timeout
48             if (!sd_event_run(event.get(), sleepTime.count()))
49             {
50                 break;
51             }
52         }
53     }
54 
55   public:
56     bool nullResponse = false;
57     bool validResponse = false;
58     int callbackCount = 0;
59     bool response2 = false;
60 
61     void pldmResponseCallBack(mctp_eid_t /*eid*/, const pldm_msg* response,
62                               size_t respMsgLen)
63     {
64         if (response == nullptr && respMsgLen == 0)
65         {
66             nullResponse = true;
67         }
68         else
69         {
70             validResponse = true;
71         }
72         callbackCount++;
73     }
74 };
75 
76 TEST_F(HandlerTest, singleRequestResponseScenario)
77 {
78     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
79                                               instanceIdDb, false, seconds(1),
80                                               2, milliseconds(100));
81     pldm::Request request{};
82     auto instanceId = instanceIdDb.next(eid);
83     EXPECT_EQ(instanceId, 0);
84     auto rc = reqHandler.registerRequest(
85         eid, instanceId, 0, 0, std::move(request),
86         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
87     EXPECT_EQ(rc, PLDM_SUCCESS);
88 
89     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
90     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
91     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
92                               response.size());
93 
94     EXPECT_EQ(validResponse, true);
95 }
96 
97 TEST_F(HandlerTest, singleRequestInstanceIdTimerExpired)
98 {
99     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
100                                               instanceIdDb, false, seconds(1),
101                                               2, milliseconds(100));
102     pldm::Request request{};
103     auto instanceId = instanceIdDb.next(eid);
104     EXPECT_EQ(instanceId, 0);
105     auto rc = reqHandler.registerRequest(
106         eid, instanceId, 0, 0, std::move(request),
107         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
108     EXPECT_EQ(rc, PLDM_SUCCESS);
109 
110     // Waiting for 500ms so that the instance ID expiry callback is invoked
111     waitEventExpiry(milliseconds(500));
112 
113     EXPECT_EQ(nullResponse, true);
114 }
115 
116 TEST_F(HandlerTest, multipleRequestResponseScenario)
117 {
118     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
119                                               instanceIdDb, false, seconds(2),
120                                               2, milliseconds(100));
121     pldm::Request request{};
122     auto instanceId = instanceIdDb.next(eid);
123     EXPECT_EQ(instanceId, 0);
124     auto rc = reqHandler.registerRequest(
125         eid, instanceId, 0, 0, std::move(request),
126         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
127     EXPECT_EQ(rc, PLDM_SUCCESS);
128 
129     pldm::Request requestNxt{};
130     auto instanceIdNxt = instanceIdDb.next(eid);
131     EXPECT_EQ(instanceIdNxt, 1);
132     rc = reqHandler.registerRequest(
133         eid, instanceIdNxt, 0, 0, std::move(requestNxt),
134         std::move(std::bind_front(&HandlerTest::pldmResponseCallBack, this)));
135     EXPECT_EQ(rc, PLDM_SUCCESS);
136 
137     pldm::Response response(sizeof(pldm_msg_hdr) + sizeof(uint8_t));
138     auto responsePtr = reinterpret_cast<const pldm_msg*>(response.data());
139     reqHandler.handleResponse(eid, instanceId, 0, 0, responsePtr,
140                               response.size());
141     EXPECT_EQ(validResponse, true);
142     EXPECT_EQ(callbackCount, 1);
143     validResponse = false;
144 
145     // Waiting for 500ms and handle the response for the first request, to
146     // simulate a delayed response for the first request
147     waitEventExpiry(milliseconds(500));
148 
149     reqHandler.handleResponse(eid, instanceIdNxt, 0, 0, responsePtr,
150                               response.size());
151 
152     EXPECT_EQ(validResponse, true);
153     EXPECT_EQ(callbackCount, 2);
154 }
155 
156 TEST_F(HandlerTest, singleRequestResponseScenarioUsingCoroutine)
157 {
158     exec::async_scope scope;
159     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
160                                               instanceIdDb, false, seconds(1),
161                                               2, milliseconds(100));
162 
163     auto instanceId = instanceIdDb.next(eid);
164     EXPECT_EQ(instanceId, 0);
165 
166     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
167         pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
168         const pldm_msg* responseMsg;
169         size_t responseLen;
170 
171         auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
172         requestPtr->hdr.instance_id = instanceId;
173 
174         try
175         {
176             std::tie(responseMsg, responseLen) =
177                 co_await reqHandler.sendRecvMsg(eid, std::move(request));
178         }
179         catch (...)
180         {
181             std::rethrow_exception(std::current_exception());
182         }
183 
184         EXPECT_NE(responseLen, 0);
185 
186         this->pldmResponseCallBack(eid, responseMsg, responseLen);
187 
188         EXPECT_EQ(validResponse, true);
189     }),
190                 exec::default_task_context<void>());
191 
192     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
193     auto mockResponsePtr =
194         reinterpret_cast<const pldm_msg*>(mockResponse.data());
195     reqHandler.handleResponse(eid, instanceId, 0, 0, mockResponsePtr,
196                               mockResponse.size() - sizeof(pldm_msg_hdr));
197 
198     stdexec::sync_wait(scope.on_empty());
199 }
200 
201 TEST_F(HandlerTest, singleRequestCancellationScenarioUsingCoroutine)
202 {
203     exec::async_scope scope;
204     Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
205                                               instanceIdDb, false, seconds(1),
206                                               2, milliseconds(100));
207     auto instanceId = instanceIdDb.next(eid);
208     EXPECT_EQ(instanceId, 0);
209 
210     bool stopped = false;
211 
212     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
213         pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
214         pldm::Response response;
215 
216         auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
217         requestPtr->hdr.instance_id = instanceId;
218 
219         co_await reqHandler.sendRecvMsg(eid, std::move(request));
220 
221         EXPECT_TRUE(false); // unreachable
222     }) | stdexec::upon_stopped([&] { stopped = true; }),
223                 exec::default_task_context<void>());
224 
225     scope.request_stop();
226 
227     EXPECT_TRUE(stopped);
228 
229     stdexec::sync_wait(scope.on_empty());
230 }
231 
232 TEST_F(HandlerTest, asyncRequestResponseByCoroutine)
233 {
234     struct _
235     {
236         static exec::task<uint8_t> getTIDTask(Handler<MockRequest>& handler,
237                                               mctp_eid_t eid,
238                                               uint8_t instanceId, uint8_t& tid)
239         {
240             pldm::Request request(sizeof(pldm_msg_hdr), 0);
241             auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
242             const pldm_msg* responseMsg;
243             size_t responseLen;
244 
245             auto rc = encode_get_tid_req(instanceId, requestMsg);
246             EXPECT_EQ(rc, PLDM_SUCCESS);
247 
248             std::tie(responseMsg, responseLen) =
249                 co_await handler.sendRecvMsg(eid, std::move(request));
250             EXPECT_NE(responseLen, 0);
251 
252             uint8_t cc = 0;
253             rc = decode_get_tid_resp(responseMsg, responseLen, &cc, &tid);
254             EXPECT_EQ(rc, PLDM_SUCCESS);
255 
256             co_return cc;
257         }
258     };
259 
260     exec::async_scope scope;
261     Handler<MockRequest> reqHandler(pldmTransport, event, instanceIdDb, false,
262                                     seconds(1), 2, milliseconds(100));
263     auto instanceId = instanceIdDb.next(eid);
264 
265     uint8_t expectedTid = 1;
266 
267     // Execute a coroutine to send getTID command. The coroutine is suspended
268     // until reqHandler.handleResponse() is received.
269     scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
270         uint8_t respTid = 0;
271 
272         co_await _::getTIDTask(reqHandler, eid, instanceId, respTid);
273 
274         EXPECT_EQ(expectedTid, respTid);
275     }),
276                 exec::default_task_context<void>());
277 
278     pldm::Response mockResponse(sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES,
279                                 0);
280     auto mockResponseMsg = reinterpret_cast<pldm_msg*>(mockResponse.data());
281 
282     // Compose response message of getTID command
283     encode_get_tid_resp(instanceId, PLDM_SUCCESS, expectedTid, mockResponseMsg);
284 
285     // Send response back to resume getTID coroutine to update respTid by
286     // calling  reqHandler.handleResponse() manually
287     reqHandler.handleResponse(eid, instanceId, PLDM_BASE, PLDM_GET_TID,
288                               mockResponseMsg,
289                               mockResponse.size() - sizeof(pldm_msg_hdr));
290 
291     stdexec::sync_wait(scope.on_empty());
292 }
293