1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                (std::get<3>(v.second) == std::get<3>(mctpInfo));
47     });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                (std::get<3>(v.second) == std::get<3>(mctpInfo));
91     });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
/**
 * @brief Queue a batch of MCTP endpoints for discovery and make sure a
 *        discovery task exists to process the queue.
 *
 * If a discovery task is still running (its result slot is empty), the batch
 * is only queued. Otherwise the finished task's scope is waited on and the
 * handle is released, then a new discoverMctpTerminusTask() is spawned.
 *
 * @param[in] mctpInfos - MCTP endpoints to discover
 */
void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
{
    queuedMctpInfos.emplace(mctpInfos);
    if (discoverMctpTerminusTaskHandle.has_value())
    {
        auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
        // rcOpt is filled only by the task's completion continuation below.
        // Empty means the task is still running; its
        // `while (!queuedMctpInfos.empty())` loop is expected to pick up the
        // batch queued above.
        if (!rcOpt.has_value())
        {
            return;
        }
        // The previous task has finished: wait for its scope to drain, then
        // drop the handle so a fresh one can be created.
        stdexec::sync_wait(scope.on_empty());
        discoverMctpTerminusTaskHandle.reset();
    }
    auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
    // The continuation captures rcOpt by reference. That reference refers into
    // discoverMctpTerminusTaskHandle, which is only reset after rcOpt has
    // been filled.
    scope.spawn(discoverMctpTerminusTask() |
                    stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
                exec::default_task_context<void>(exec::inline_scheduler{}));
}
146 
147 TerminiMapper::iterator
148     TerminusManager::findTerminusPtr(const MctpInfo& mctpInfo)
149 {
150     auto foundIter = std::find_if(termini.begin(), termini.end(),
151                                   [&](const auto& terminusPair) {
152         auto terminusMctpInfo = toMctpInfo(terminusPair.first);
153         return (
154             terminusMctpInfo &&
155             (std::get<0>(terminusMctpInfo.value()) == std::get<0>(mctpInfo)) &&
156             (std::get<3>(terminusMctpInfo.value()) == std::get<3>(mctpInfo)));
157     });
158 
159     return foundIter;
160 }
161 
/**
 * @brief Coroutine that drains queuedMctpInfos, initializing every endpoint
 *        that does not yet have a discovered terminus.
 *
 * For each queued batch the manager (when set) is notified before and after
 * discovery. A batch is popped only after it has been fully processed, so
 * batches queued while this coroutine is suspended are handled by the same
 * loop.
 *
 * @return PLDM_SUCCESS once the queue is empty.
 */
exec::task<int> TerminusManager::discoverMctpTerminusTask()
{
    while (!queuedMctpInfos.empty())
    {
        if (manager)
        {
            co_await manager->beforeDiscoverTerminus();
        }

        // NOTE(review): this reference is held across co_await suspensions
        // while discoverMctpTerminus() may push new batches. That is safe only
        // if pushing does not invalidate references to existing elements
        // (true for std::queue over std::deque) - confirm the container type.
        const MctpInfos& mctpInfos = queuedMctpInfos.front();
        for (const auto& mctpInfo : mctpInfos)
        {
            auto it = findTerminusPtr(mctpInfo);
            // Only endpoints without a discovered terminus are initialized.
            // initMctpTerminus() logs its own failures; its result is not used.
            if (it == termini.end())
            {
                co_await initMctpTerminus(mctpInfo);
            }
        }

        if (manager)
        {
            co_await manager->afterDiscoverTerminus();
        }

        queuedMctpInfos.pop();
    }

    co_return PLDM_SUCCESS;
}
191 
192 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
193 {
194     // remove terminus
195     for (const auto& mctpInfo : mctpInfos)
196     {
197         auto it = findTerminusPtr(mctpInfo);
198         if (it == termini.end())
199         {
200             continue;
201         }
202 
203         unmapTid(it->first);
204         termini.erase(it);
205     }
206 }
207 
208 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
209 {
210     mctp_eid_t eid = std::get<0>(mctpInfo);
211     pldm_tid_t tid = 0;
212     bool isMapped = false;
213     auto rc = co_await getTidOverMctp(eid, &tid);
214     if (rc != PLDM_SUCCESS)
215     {
216         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
217         co_return PLDM_ERROR;
218     }
219 
220     if (tid == PLDM_TID_RESERVED)
221     {
222         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
223         co_return PLDM_ERROR;
224     }
225 
226     /* Terminus already has TID */
227     if (tid != PLDM_TID_UNASSIGNED)
228     {
229         /* TID is used by one discovered terminus */
230         auto it = termini.find(tid);
231         if (it != termini.end())
232         {
233             auto terminusMctpInfo = toMctpInfo(it->first);
234             /* The discovered terminus has the same MCTP Info */
235             if (terminusMctpInfo &&
236                 (std::get<0>(terminusMctpInfo.value()) ==
237                  std::get<0>(mctpInfo)) &&
238                 (std::get<3>(terminusMctpInfo.value()) ==
239                  std::get<3>(mctpInfo)))
240             {
241                 co_return PLDM_SUCCESS;
242             }
243             else
244             {
245                 /* ToDo:
246                  * Maybe the terminus supports multiple medium interfaces
247                  * Or the TID is used by other terminus.
248                  * Check the UUID to confirm.
249                  */
250                 isMapped = false;
251             }
252         }
253         /* Use the terminus TID for mapping */
254         else
255         {
256             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
257             if (!mappedTid)
258             {
259                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
260                            "TID", tid);
261                 co_return PLDM_ERROR;
262             }
263             isMapped = true;
264         }
265     }
266 
267     if (!isMapped)
268     {
269         // Assigning a tid. If it has been mapped, mapTid()
270         // returns the tid assigned before.
271         auto mappedTid = mapTid(mctpInfo);
272         if (!mappedTid)
273         {
274             lg2::error("Failed to store Terminus Info for terminus {TID}.",
275                        "TID", tid);
276             co_return PLDM_ERROR;
277         }
278 
279         tid = mappedTid.value();
280         rc = co_await setTidOverMctp(eid, tid);
281         if (rc != PLDM_SUCCESS)
282         {
283             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
284                        rc);
285             unmapTid(tid);
286             co_return rc;
287         }
288 
289         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
290         {
291             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
292                        tid);
293             unmapTid(tid);
294             co_return rc;
295         }
296 
297         if (termini.contains(tid))
298         {
299             // the terminus has been discovered before
300             co_return PLDM_SUCCESS;
301         }
302     }
303     /* Discovery the mapped terminus */
304     uint64_t supportedTypes = 0;
305     rc = co_await getPLDMTypes(tid, supportedTypes);
306     if (rc)
307     {
308         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
309                    "TID", tid, "ERROR", rc);
310         co_return PLDM_ERROR;
311     }
312 
313     try
314     {
315         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
316     }
317     catch (const sdbusplus::exception_t& e)
318     {
319         lg2::error("Failed to create terminus manager for terminus {TID}",
320                    "TID", tid);
321         co_return PLDM_ERROR;
322     }
323 
324     uint8_t type = PLDM_BASE;
325     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
326     std::vector<uint8_t> pldmCmds(size);
327     while ((type < PLDM_MAX_TYPES))
328     {
329         if (!termini[tid]->doesSupportType(type))
330         {
331             type++;
332             continue;
333         }
334         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
335         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
336         if (rc)
337         {
338             lg2::error(
339                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
340                 "TID", tid, "ERROR", rc);
341         }
342 
343         for (size_t i = 0; i < cmds.size(); i++)
344         {
345             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
346             if (idx >= pldmCmds.size())
347             {
348                 lg2::error(
349                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
350                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
351                 continue;
352             }
353             pldmCmds[idx] = cmds[i].byte;
354         }
355         type++;
356     }
357     termini[tid]->setSupportedCommands(pldmCmds);
358 
359     co_return PLDM_SUCCESS;
360 }
361 
362 exec::task<int>
363     TerminusManager::sendRecvPldmMsgOverMctp(mctp_eid_t eid, Request& request,
364                                              const pldm_msg** responseMsg,
365                                              size_t* responseLen)
366 {
367     int rc = 0;
368     try
369     {
370         std::tie(rc, *responseMsg, *responseLen) =
371             co_await handler.sendRecvMsg(eid, std::move(request));
372     }
373     catch (const sdbusplus::exception_t& e)
374     {
375         lg2::error(
376             "Send and Receive PLDM message over MCTP throw error - {ERROR}.",
377             "ERROR", e);
378         co_return PLDM_ERROR;
379     }
380     catch (const int& e)
381     {
382         lg2::error(
383             "Send and Receive PLDM message over MCTP throw int error - {ERROR}.",
384             "ERROR", e);
385         co_return PLDM_ERROR;
386     }
387 
388     co_return rc;
389 }
390 
391 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
392 {
393     auto instanceId = instanceIdDb.next(eid);
394     Request request(sizeof(pldm_msg_hdr));
395     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
396     auto rc = encode_get_tid_req(instanceId, requestMsg);
397     if (rc)
398     {
399         instanceIdDb.free(eid, instanceId);
400         lg2::error(
401             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
402             "EID", eid, "RC", rc);
403         co_return rc;
404     }
405 
406     const pldm_msg* responseMsg = nullptr;
407     size_t responseLen = 0;
408     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
409                                           &responseLen);
410     if (rc)
411     {
412         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
413                    "EID", eid, "RC", rc);
414         co_return rc;
415     }
416 
417     uint8_t completionCode = 0;
418     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
419     if (rc)
420     {
421         lg2::error(
422             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
423             "EID", eid, "RC", rc);
424         co_return rc;
425     }
426 
427     if (completionCode != PLDM_SUCCESS)
428     {
429         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
430                    "EID", eid, "CC", completionCode);
431         co_return rc;
432     }
433 
434     co_return completionCode;
435 }
436 
437 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
438 {
439     auto instanceId = instanceIdDb.next(eid);
440     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
441     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
442     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
443     if (rc)
444     {
445         instanceIdDb.free(eid, instanceId);
446         lg2::error(
447             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
448             "EID", eid, "RC", rc);
449         co_return rc;
450     }
451 
452     const pldm_msg* responseMsg = nullptr;
453     size_t responseLen = 0;
454     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
455                                           &responseLen);
456     if (rc)
457     {
458         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
459                    "EID", eid, "RC", rc);
460         co_return rc;
461     }
462 
463     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
464     {
465         lg2::error(
466             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
467             "EID", eid, "RC", rc);
468         co_return PLDM_ERROR_INVALID_LENGTH;
469     }
470 
471     co_return responseMsg->payload[0];
472 }
473 
474 exec::task<int> TerminusManager::getPLDMTypes(pldm_tid_t tid,
475                                               uint64_t& supportedTypes)
476 {
477     Request request(sizeof(pldm_msg_hdr));
478     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
479     auto rc = encode_get_types_req(0, requestMsg);
480     if (rc)
481     {
482         lg2::error(
483             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
484             "TID", tid, "RC", rc);
485         co_return rc;
486     }
487 
488     const pldm_msg* responseMsg = nullptr;
489     size_t responseLen = 0;
490 
491     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
492     if (rc)
493     {
494         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
495                    "TID", tid, "RC", rc);
496         co_return rc;
497     }
498 
499     uint8_t completionCode = 0;
500     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
501     rc = decode_get_types_resp(responseMsg, responseLen, &completionCode,
502                                types);
503     if (rc)
504     {
505         lg2::error(
506             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
507             "TID", tid, "RC", rc);
508         co_return rc;
509     }
510 
511     if (completionCode != PLDM_SUCCESS)
512     {
513         lg2::error(
514             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
515             "TID", tid, "CC", completionCode);
516         co_return rc;
517     }
518     co_return completionCode;
519 }
520 
521 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
522                                                  bitfield8_t* supportedCmds)
523 {
524     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
525     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
526     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
527 
528     auto rc = encode_get_commands_req(0, type, version, requestMsg);
529     if (rc)
530     {
531         lg2::error(
532             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
533             "TID", tid, "RC", rc);
534         co_return rc;
535     }
536 
537     const pldm_msg* responseMsg = nullptr;
538     size_t responseLen = 0;
539 
540     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
541     if (rc)
542     {
543         lg2::error(
544             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
545             "TID", tid, "RC", rc);
546         co_return rc;
547     }
548 
549     /* Process response */
550     uint8_t completionCode = 0;
551     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
552                                   supportedCmds);
553     if (rc)
554     {
555         lg2::error(
556             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
557             "TID", tid, "RC", rc);
558         co_return rc;
559     }
560 
561     if (completionCode != PLDM_SUCCESS)
562     {
563         lg2::error(
564             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
565             "TID", tid, "CC", completionCode);
566         co_return rc;
567     }
568 
569     co_return completionCode;
570 }
571 
572 exec::task<int> TerminusManager::sendRecvPldmMsg(pldm_tid_t tid,
573                                                  Request& request,
574                                                  const pldm_msg** responseMsg,
575                                                  size_t* responseLen)
576 {
577     /**
578      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
579      * tidPool[i] always exist
580      */
581     if (!tidPool[tid])
582     {
583         co_return PLDM_ERROR_NOT_READY;
584     }
585 
586     if (!transportLayerTable.contains(tid))
587     {
588         co_return PLDM_ERROR_NOT_READY;
589     }
590 
591     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
592     {
593         co_return PLDM_ERROR_NOT_READY;
594     }
595 
596     auto mctpInfo = toMctpInfo(tid);
597     if (!mctpInfo.has_value())
598     {
599         co_return PLDM_ERROR_NOT_READY;
600     }
601 
602     auto eid = std::get<0>(mctpInfo.value());
603     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
604     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
605     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
606                                                responseLen);
607 
608     co_return rc;
609 }
610 
611 } // namespace platform_mc
612 } // namespace pldm
613