1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45             return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                    (std::get<3>(v.second) == std::get<3>(mctpInfo));
47         });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89             return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                    (std::get<3>(v.second) == std::get<3>(mctpInfo));
91         });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
/* Queue a batch of MCTP endpoints for discovery and make sure a discovery
 * task is running to process the queue.
 *
 * A spawned task whose result slot (rcOpt) is still empty is treated as
 * running. discoverMctpTerminusTask() loops while queuedMctpInfos is
 * non-empty, so the newly queued batch is presumably picked up by that
 * task. If the previous task has finished, this function waits for its
 * async scope to drain, drops the handle and spawns a fresh task.
 *
 * @param[in] mctpInfos - batch of MCTP endpoints to discover
 */
void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
{
    queuedMctpInfos.emplace(mctpInfos);
    if (discoverMctpTerminusTaskHandle.has_value())
    {
        auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
        // No return code yet: the task is still in flight and drains the
        // queue itself.
        if (!rcOpt.has_value())
        {
            return;
        }
        // The previous task completed; block until its scope is empty
        // before discarding the handle that owns the scope.
        stdexec::sync_wait(scope.on_empty());
        discoverMctpTerminusTaskHandle.reset();
    }
    auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
    // Spawn the task on an inline scheduler. Its return code is stored into
    // rcOpt, which the check above uses as the "finished" marker. The
    // lambda captures rcOpt by reference, so this relies on the handle
    // staying alive until the task completes.
    scope.spawn(discoverMctpTerminusTask() |
                    stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
                exec::default_task_context<void>(exec::inline_scheduler{}));
}
146 
147 TerminiMapper::iterator
148     TerminusManager::findTerminusPtr(const MctpInfo& mctpInfo)
149 {
150     auto foundIter = std::find_if(
151         termini.begin(), termini.end(), [&](const auto& terminusPair) {
152             auto terminusMctpInfo = toMctpInfo(terminusPair.first);
153             return (terminusMctpInfo &&
154                     (std::get<0>(terminusMctpInfo.value()) ==
155                      std::get<0>(mctpInfo)) &&
156                     (std::get<3>(terminusMctpInfo.value()) ==
157                      std::get<3>(mctpInfo)));
158         });
159 
160     return foundIter;
161 }
162 
163 exec::task<int> TerminusManager::discoverMctpTerminusTask()
164 {
165     std::vector<pldm_tid_t> addedTids;
166     while (!queuedMctpInfos.empty())
167     {
168         if (manager)
169         {
170             co_await manager->beforeDiscoverTerminus();
171         }
172 
173         const MctpInfos& mctpInfos = queuedMctpInfos.front();
174         for (const auto& mctpInfo : mctpInfos)
175         {
176             auto it = findTerminusPtr(mctpInfo);
177             if (it == termini.end())
178             {
179                 co_await initMctpTerminus(mctpInfo);
180             }
181 
182             /* Get TID of initialized terminus */
183             auto tid = toTid(mctpInfo);
184             if (!tid)
185             {
186                 co_return PLDM_ERROR;
187             }
188             addedTids.push_back(tid.value());
189         }
190 
191         if (manager)
192         {
193             co_await manager->afterDiscoverTerminus();
194             for (const auto& tid : addedTids)
195             {
196                 manager->startSensorPolling(tid);
197             }
198         }
199 
200         queuedMctpInfos.pop();
201     }
202 
203     co_return PLDM_SUCCESS;
204 }
205 
206 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
207 {
208     // remove terminus
209     for (const auto& mctpInfo : mctpInfos)
210     {
211         auto it = findTerminusPtr(mctpInfo);
212         if (it == termini.end())
213         {
214             continue;
215         }
216 
217         if (manager)
218         {
219             manager->stopSensorPolling(it->second->getTid());
220         }
221 
222         unmapTid(it->first);
223         termini.erase(it);
224     }
225 }
226 
227 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
228 {
229     mctp_eid_t eid = std::get<0>(mctpInfo);
230     pldm_tid_t tid = 0;
231     bool isMapped = false;
232     auto rc = co_await getTidOverMctp(eid, &tid);
233     if (rc != PLDM_SUCCESS)
234     {
235         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
236         co_return PLDM_ERROR;
237     }
238 
239     if (tid == PLDM_TID_RESERVED)
240     {
241         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
242         co_return PLDM_ERROR;
243     }
244 
245     /* Terminus already has TID */
246     if (tid != PLDM_TID_UNASSIGNED)
247     {
248         /* TID is used by one discovered terminus */
249         auto it = termini.find(tid);
250         if (it != termini.end())
251         {
252             auto terminusMctpInfo = toMctpInfo(it->first);
253             /* The discovered terminus has the same MCTP Info */
254             if (terminusMctpInfo &&
255                 (std::get<0>(terminusMctpInfo.value()) ==
256                  std::get<0>(mctpInfo)) &&
257                 (std::get<3>(terminusMctpInfo.value()) ==
258                  std::get<3>(mctpInfo)))
259             {
260                 co_return PLDM_SUCCESS;
261             }
262             else
263             {
264                 /* ToDo:
265                  * Maybe the terminus supports multiple medium interfaces
266                  * Or the TID is used by other terminus.
267                  * Check the UUID to confirm.
268                  */
269                 isMapped = false;
270             }
271         }
272         /* Use the terminus TID for mapping */
273         else
274         {
275             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
276             if (!mappedTid)
277             {
278                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
279                            "TID", tid);
280                 co_return PLDM_ERROR;
281             }
282             isMapped = true;
283         }
284     }
285 
286     if (!isMapped)
287     {
288         // Assigning a tid. If it has been mapped, mapTid()
289         // returns the tid assigned before.
290         auto mappedTid = mapTid(mctpInfo);
291         if (!mappedTid)
292         {
293             lg2::error("Failed to store Terminus Info for terminus {TID}.",
294                        "TID", tid);
295             co_return PLDM_ERROR;
296         }
297 
298         tid = mappedTid.value();
299         rc = co_await setTidOverMctp(eid, tid);
300         if (rc != PLDM_SUCCESS)
301         {
302             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
303                        rc);
304             unmapTid(tid);
305             co_return rc;
306         }
307 
308         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
309         {
310             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
311                        tid);
312             unmapTid(tid);
313             co_return rc;
314         }
315 
316         if (termini.contains(tid))
317         {
318             // the terminus has been discovered before
319             co_return PLDM_SUCCESS;
320         }
321     }
322     /* Discovery the mapped terminus */
323     uint64_t supportedTypes = 0;
324     rc = co_await getPLDMTypes(tid, supportedTypes);
325     if (rc)
326     {
327         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
328                    "TID", tid, "ERROR", rc);
329         unmapTid(tid);
330         co_return PLDM_ERROR;
331     }
332 
333     try
334     {
335         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
336     }
337     catch (const sdbusplus::exception_t& e)
338     {
339         lg2::error("Failed to create terminus manager for terminus {TID}",
340                    "TID", tid);
341         unmapTid(tid);
342         co_return PLDM_ERROR;
343     }
344 
345     uint8_t type = PLDM_BASE;
346     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
347     std::vector<uint8_t> pldmCmds(size);
348     while ((type < PLDM_MAX_TYPES))
349     {
350         if (!termini[tid]->doesSupportType(type))
351         {
352             type++;
353             continue;
354         }
355         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
356         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
357         if (rc)
358         {
359             lg2::error(
360                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
361                 "TID", tid, "ERROR", rc);
362         }
363 
364         for (size_t i = 0; i < cmds.size(); i++)
365         {
366             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
367             if (idx >= pldmCmds.size())
368             {
369                 lg2::error(
370                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
371                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
372                 continue;
373             }
374             pldmCmds[idx] = cmds[i].byte;
375         }
376         type++;
377     }
378     termini[tid]->setSupportedCommands(pldmCmds);
379 
380     co_return PLDM_SUCCESS;
381 }
382 
383 exec::task<int> TerminusManager::sendRecvPldmMsgOverMctp(
384     mctp_eid_t eid, Request& request, const pldm_msg** responseMsg,
385     size_t* responseLen)
386 {
387     int rc = 0;
388     try
389     {
390         std::tie(rc, *responseMsg, *responseLen) =
391             co_await handler.sendRecvMsg(eid, std::move(request));
392     }
393     catch (const sdbusplus::exception_t& e)
394     {
395         lg2::error(
396             "Send and Receive PLDM message over MCTP throw error - {ERROR}.",
397             "ERROR", e);
398         co_return PLDM_ERROR;
399     }
400     catch (const int& e)
401     {
402         lg2::error(
403             "Send and Receive PLDM message over MCTP throw int error - {ERROR}.",
404             "ERROR", e);
405         co_return PLDM_ERROR;
406     }
407 
408     co_return rc;
409 }
410 
411 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
412 {
413     auto instanceId = instanceIdDb.next(eid);
414     Request request(sizeof(pldm_msg_hdr));
415     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
416     auto rc = encode_get_tid_req(instanceId, requestMsg);
417     if (rc)
418     {
419         instanceIdDb.free(eid, instanceId);
420         lg2::error(
421             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
422             "EID", eid, "RC", rc);
423         co_return rc;
424     }
425 
426     const pldm_msg* responseMsg = nullptr;
427     size_t responseLen = 0;
428     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
429                                           &responseLen);
430     if (rc)
431     {
432         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
433                    "EID", eid, "RC", rc);
434         co_return rc;
435     }
436 
437     uint8_t completionCode = 0;
438     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
439     if (rc)
440     {
441         lg2::error(
442             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
443             "EID", eid, "RC", rc);
444         co_return rc;
445     }
446 
447     if (completionCode != PLDM_SUCCESS)
448     {
449         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
450                    "EID", eid, "CC", completionCode);
451         co_return rc;
452     }
453 
454     co_return completionCode;
455 }
456 
457 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
458 {
459     auto instanceId = instanceIdDb.next(eid);
460     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
461     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
462     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
463     if (rc)
464     {
465         instanceIdDb.free(eid, instanceId);
466         lg2::error(
467             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
468             "EID", eid, "RC", rc);
469         co_return rc;
470     }
471 
472     const pldm_msg* responseMsg = nullptr;
473     size_t responseLen = 0;
474     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
475                                           &responseLen);
476     if (rc)
477     {
478         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
479                    "EID", eid, "RC", rc);
480         co_return rc;
481     }
482 
483     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
484     {
485         lg2::error(
486             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
487             "EID", eid, "RC", rc);
488         co_return PLDM_ERROR_INVALID_LENGTH;
489     }
490 
491     co_return responseMsg->payload[0];
492 }
493 
494 exec::task<int>
495     TerminusManager::getPLDMTypes(pldm_tid_t tid, uint64_t& supportedTypes)
496 {
497     Request request(sizeof(pldm_msg_hdr));
498     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
499     auto rc = encode_get_types_req(0, requestMsg);
500     if (rc)
501     {
502         lg2::error(
503             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
504             "TID", tid, "RC", rc);
505         co_return rc;
506     }
507 
508     const pldm_msg* responseMsg = nullptr;
509     size_t responseLen = 0;
510 
511     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
512     if (rc)
513     {
514         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
515                    "TID", tid, "RC", rc);
516         co_return rc;
517     }
518 
519     uint8_t completionCode = 0;
520     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
521     rc =
522         decode_get_types_resp(responseMsg, responseLen, &completionCode, types);
523     if (rc)
524     {
525         lg2::error(
526             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
527             "TID", tid, "RC", rc);
528         co_return rc;
529     }
530 
531     if (completionCode != PLDM_SUCCESS)
532     {
533         lg2::error(
534             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
535             "TID", tid, "CC", completionCode);
536         co_return rc;
537     }
538     co_return completionCode;
539 }
540 
541 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
542                                                  bitfield8_t* supportedCmds)
543 {
544     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
545     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
546     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
547 
548     auto rc = encode_get_commands_req(0, type, version, requestMsg);
549     if (rc)
550     {
551         lg2::error(
552             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
553             "TID", tid, "RC", rc);
554         co_return rc;
555     }
556 
557     const pldm_msg* responseMsg = nullptr;
558     size_t responseLen = 0;
559 
560     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
561     if (rc)
562     {
563         lg2::error(
564             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
565             "TID", tid, "RC", rc);
566         co_return rc;
567     }
568 
569     /* Process response */
570     uint8_t completionCode = 0;
571     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
572                                   supportedCmds);
573     if (rc)
574     {
575         lg2::error(
576             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
577             "TID", tid, "RC", rc);
578         co_return rc;
579     }
580 
581     if (completionCode != PLDM_SUCCESS)
582     {
583         lg2::error(
584             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
585             "TID", tid, "CC", completionCode);
586         co_return rc;
587     }
588 
589     co_return completionCode;
590 }
591 
592 exec::task<int> TerminusManager::sendRecvPldmMsg(
593     pldm_tid_t tid, Request& request, const pldm_msg** responseMsg,
594     size_t* responseLen)
595 {
596     /**
597      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
598      * tidPool[i] always exist
599      */
600     if (!tidPool[tid])
601     {
602         co_return PLDM_ERROR_NOT_READY;
603     }
604 
605     if (!transportLayerTable.contains(tid))
606     {
607         co_return PLDM_ERROR_NOT_READY;
608     }
609 
610     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
611     {
612         co_return PLDM_ERROR_NOT_READY;
613     }
614 
615     auto mctpInfo = toMctpInfo(tid);
616     if (!mctpInfo.has_value())
617     {
618         co_return PLDM_ERROR_NOT_READY;
619     }
620 
621     auto eid = std::get<0>(mctpInfo.value());
622     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
623     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
624     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
625                                                responseLen);
626 
627     co_return rc;
628 }
629 
630 } // namespace platform_mc
631 } // namespace pldm
632