1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45             return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                    (std::get<3>(v.second) == std::get<3>(mctpInfo));
47         });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89             return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                    (std::get<3>(v.second) == std::get<3>(mctpInfo));
91         });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
/**
 * @brief Queue a batch of MCTP endpoints for discovery and make sure a
 *        discovery task is running to process the queue.
 *
 * The batch is always appended to queuedMctpInfos. If a discovery task is
 * still in flight (its result slot rcOpt is empty), nothing else is done:
 * the running discoverMctpTerminusTask() loops until queuedMctpInfos is
 * empty. Otherwise a finished task (if any) is waited on and cleared, and
 * a new discoverMctpTerminusTask() is spawned on an inline scheduler.
 *
 * @param[in] mctpInfos - MCTP endpoints to discover
 */
void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
{
    queuedMctpInfos.emplace(mctpInfos);
    if (discoverMctpTerminusTaskHandle.has_value())
    {
        auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
        // rcOpt is only filled once the spawned task has produced its return
        // code; while it is empty the task is still running.
        if (!rcOpt.has_value())
        {
            return;
        }
        // Previous task finished: wait until its scope has no outstanding
        // work, then drop the old handle before creating a new one.
        stdexec::sync_wait(scope.on_empty());
        discoverMctpTerminusTaskHandle.reset();
    }
    auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
    // The continuation stores the task's return code into rcOpt, which lives
    // inside the member handle. NOTE(review): the by-reference capture relies
    // on the handle outliving the spawned work — confirm.
    scope.spawn(discoverMctpTerminusTask() |
                    stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
                exec::default_task_context<void>(exec::inline_scheduler{}));
}
146 
147 TerminiMapper::iterator
148     TerminusManager::findTerminusPtr(const MctpInfo& mctpInfo)
149 {
150     auto foundIter = std::find_if(
151         termini.begin(), termini.end(), [&](const auto& terminusPair) {
152             auto terminusMctpInfo = toMctpInfo(terminusPair.first);
153             return (terminusMctpInfo &&
154                     (std::get<0>(terminusMctpInfo.value()) ==
155                      std::get<0>(mctpInfo)) &&
156                     (std::get<3>(terminusMctpInfo.value()) ==
157                      std::get<3>(mctpInfo)));
158         });
159 
160     return foundIter;
161 }
162 
163 exec::task<int> TerminusManager::discoverMctpTerminusTask()
164 {
165     std::vector<pldm_tid_t> addedTids;
166     while (!queuedMctpInfos.empty())
167     {
168         if (manager)
169         {
170             co_await manager->beforeDiscoverTerminus();
171         }
172 
173         const MctpInfos& mctpInfos = queuedMctpInfos.front();
174         for (const auto& mctpInfo : mctpInfos)
175         {
176             auto it = findTerminusPtr(mctpInfo);
177             if (it == termini.end())
178             {
179                 co_await initMctpTerminus(mctpInfo);
180             }
181 
182             /* Get TID of initialized terminus */
183             auto tid = toTid(mctpInfo);
184             if (!tid)
185             {
186                 co_return PLDM_ERROR;
187             }
188             addedTids.push_back(tid.value());
189         }
190 
191         if (manager)
192         {
193             co_await manager->afterDiscoverTerminus();
194             for (const auto& tid : addedTids)
195             {
196                 manager->startSensorPolling(tid);
197             }
198         }
199 
200         queuedMctpInfos.pop();
201     }
202 
203     co_return PLDM_SUCCESS;
204 }
205 
206 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
207 {
208     // remove terminus
209     for (const auto& mctpInfo : mctpInfos)
210     {
211         auto it = findTerminusPtr(mctpInfo);
212         if (it == termini.end())
213         {
214             continue;
215         }
216 
217         if (manager)
218         {
219             manager->stopSensorPolling(it->second->getTid());
220         }
221 
222         unmapTid(it->first);
223         termini.erase(it);
224     }
225 }
226 
227 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
228 {
229     mctp_eid_t eid = std::get<0>(mctpInfo);
230     pldm_tid_t tid = 0;
231     bool isMapped = false;
232     auto rc = co_await getTidOverMctp(eid, &tid);
233     if (rc != PLDM_SUCCESS)
234     {
235         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
236         co_return PLDM_ERROR;
237     }
238 
239     if (tid == PLDM_TID_RESERVED)
240     {
241         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
242         co_return PLDM_ERROR;
243     }
244 
245     /* Terminus already has TID */
246     if (tid != PLDM_TID_UNASSIGNED)
247     {
248         /* TID is used by one discovered terminus */
249         auto it = termini.find(tid);
250         if (it != termini.end())
251         {
252             auto terminusMctpInfo = toMctpInfo(it->first);
253             /* The discovered terminus has the same MCTP Info */
254             if (terminusMctpInfo &&
255                 (std::get<0>(terminusMctpInfo.value()) ==
256                  std::get<0>(mctpInfo)) &&
257                 (std::get<3>(terminusMctpInfo.value()) ==
258                  std::get<3>(mctpInfo)))
259             {
260                 co_return PLDM_SUCCESS;
261             }
262             else
263             {
264                 /* ToDo:
265                  * Maybe the terminus supports multiple medium interfaces
266                  * Or the TID is used by other terminus.
267                  * Check the UUID to confirm.
268                  */
269                 isMapped = false;
270             }
271         }
272         /* Use the terminus TID for mapping */
273         else
274         {
275             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
276             if (!mappedTid)
277             {
278                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
279                            "TID", tid);
280                 co_return PLDM_ERROR;
281             }
282             isMapped = true;
283         }
284     }
285 
286     if (!isMapped)
287     {
288         // Assigning a tid. If it has been mapped, mapTid()
289         // returns the tid assigned before.
290         auto mappedTid = mapTid(mctpInfo);
291         if (!mappedTid)
292         {
293             lg2::error("Failed to store Terminus Info for terminus {TID}.",
294                        "TID", tid);
295             co_return PLDM_ERROR;
296         }
297 
298         tid = mappedTid.value();
299         rc = co_await setTidOverMctp(eid, tid);
300         if (rc != PLDM_SUCCESS)
301         {
302             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
303                        rc);
304             unmapTid(tid);
305             co_return rc;
306         }
307 
308         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
309         {
310             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
311                        tid);
312             unmapTid(tid);
313             co_return rc;
314         }
315 
316         if (termini.contains(tid))
317         {
318             // the terminus has been discovered before
319             co_return PLDM_SUCCESS;
320         }
321     }
322     /* Discovery the mapped terminus */
323     uint64_t supportedTypes = 0;
324     rc = co_await getPLDMTypes(tid, supportedTypes);
325     if (rc)
326     {
327         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
328                    "TID", tid, "ERROR", rc);
329         co_return PLDM_ERROR;
330     }
331 
332     try
333     {
334         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
335     }
336     catch (const sdbusplus::exception_t& e)
337     {
338         lg2::error("Failed to create terminus manager for terminus {TID}",
339                    "TID", tid);
340         co_return PLDM_ERROR;
341     }
342 
343     uint8_t type = PLDM_BASE;
344     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
345     std::vector<uint8_t> pldmCmds(size);
346     while ((type < PLDM_MAX_TYPES))
347     {
348         if (!termini[tid]->doesSupportType(type))
349         {
350             type++;
351             continue;
352         }
353         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
354         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
355         if (rc)
356         {
357             lg2::error(
358                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
359                 "TID", tid, "ERROR", rc);
360         }
361 
362         for (size_t i = 0; i < cmds.size(); i++)
363         {
364             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
365             if (idx >= pldmCmds.size())
366             {
367                 lg2::error(
368                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
369                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
370                 continue;
371             }
372             pldmCmds[idx] = cmds[i].byte;
373         }
374         type++;
375     }
376     termini[tid]->setSupportedCommands(pldmCmds);
377 
378     co_return PLDM_SUCCESS;
379 }
380 
381 exec::task<int> TerminusManager::sendRecvPldmMsgOverMctp(
382     mctp_eid_t eid, Request& request, const pldm_msg** responseMsg,
383     size_t* responseLen)
384 {
385     int rc = 0;
386     try
387     {
388         std::tie(rc, *responseMsg, *responseLen) =
389             co_await handler.sendRecvMsg(eid, std::move(request));
390     }
391     catch (const sdbusplus::exception_t& e)
392     {
393         lg2::error(
394             "Send and Receive PLDM message over MCTP throw error - {ERROR}.",
395             "ERROR", e);
396         co_return PLDM_ERROR;
397     }
398     catch (const int& e)
399     {
400         lg2::error(
401             "Send and Receive PLDM message over MCTP throw int error - {ERROR}.",
402             "ERROR", e);
403         co_return PLDM_ERROR;
404     }
405 
406     co_return rc;
407 }
408 
409 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
410 {
411     auto instanceId = instanceIdDb.next(eid);
412     Request request(sizeof(pldm_msg_hdr));
413     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
414     auto rc = encode_get_tid_req(instanceId, requestMsg);
415     if (rc)
416     {
417         instanceIdDb.free(eid, instanceId);
418         lg2::error(
419             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
420             "EID", eid, "RC", rc);
421         co_return rc;
422     }
423 
424     const pldm_msg* responseMsg = nullptr;
425     size_t responseLen = 0;
426     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
427                                           &responseLen);
428     if (rc)
429     {
430         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
431                    "EID", eid, "RC", rc);
432         co_return rc;
433     }
434 
435     uint8_t completionCode = 0;
436     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
437     if (rc)
438     {
439         lg2::error(
440             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
441             "EID", eid, "RC", rc);
442         co_return rc;
443     }
444 
445     if (completionCode != PLDM_SUCCESS)
446     {
447         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
448                    "EID", eid, "CC", completionCode);
449         co_return rc;
450     }
451 
452     co_return completionCode;
453 }
454 
455 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
456 {
457     auto instanceId = instanceIdDb.next(eid);
458     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
459     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
460     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
461     if (rc)
462     {
463         instanceIdDb.free(eid, instanceId);
464         lg2::error(
465             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
466             "EID", eid, "RC", rc);
467         co_return rc;
468     }
469 
470     const pldm_msg* responseMsg = nullptr;
471     size_t responseLen = 0;
472     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
473                                           &responseLen);
474     if (rc)
475     {
476         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
477                    "EID", eid, "RC", rc);
478         co_return rc;
479     }
480 
481     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
482     {
483         lg2::error(
484             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
485             "EID", eid, "RC", rc);
486         co_return PLDM_ERROR_INVALID_LENGTH;
487     }
488 
489     co_return responseMsg->payload[0];
490 }
491 
492 exec::task<int>
493     TerminusManager::getPLDMTypes(pldm_tid_t tid, uint64_t& supportedTypes)
494 {
495     Request request(sizeof(pldm_msg_hdr));
496     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
497     auto rc = encode_get_types_req(0, requestMsg);
498     if (rc)
499     {
500         lg2::error(
501             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
502             "TID", tid, "RC", rc);
503         co_return rc;
504     }
505 
506     const pldm_msg* responseMsg = nullptr;
507     size_t responseLen = 0;
508 
509     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
510     if (rc)
511     {
512         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
513                    "TID", tid, "RC", rc);
514         co_return rc;
515     }
516 
517     uint8_t completionCode = 0;
518     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
519     rc =
520         decode_get_types_resp(responseMsg, responseLen, &completionCode, types);
521     if (rc)
522     {
523         lg2::error(
524             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
525             "TID", tid, "RC", rc);
526         co_return rc;
527     }
528 
529     if (completionCode != PLDM_SUCCESS)
530     {
531         lg2::error(
532             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
533             "TID", tid, "CC", completionCode);
534         co_return rc;
535     }
536     co_return completionCode;
537 }
538 
539 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
540                                                  bitfield8_t* supportedCmds)
541 {
542     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
543     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
544     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
545 
546     auto rc = encode_get_commands_req(0, type, version, requestMsg);
547     if (rc)
548     {
549         lg2::error(
550             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
551             "TID", tid, "RC", rc);
552         co_return rc;
553     }
554 
555     const pldm_msg* responseMsg = nullptr;
556     size_t responseLen = 0;
557 
558     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
559     if (rc)
560     {
561         lg2::error(
562             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
563             "TID", tid, "RC", rc);
564         co_return rc;
565     }
566 
567     /* Process response */
568     uint8_t completionCode = 0;
569     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
570                                   supportedCmds);
571     if (rc)
572     {
573         lg2::error(
574             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
575             "TID", tid, "RC", rc);
576         co_return rc;
577     }
578 
579     if (completionCode != PLDM_SUCCESS)
580     {
581         lg2::error(
582             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
583             "TID", tid, "CC", completionCode);
584         co_return rc;
585     }
586 
587     co_return completionCode;
588 }
589 
590 exec::task<int> TerminusManager::sendRecvPldmMsg(
591     pldm_tid_t tid, Request& request, const pldm_msg** responseMsg,
592     size_t* responseLen)
593 {
594     /**
595      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
596      * tidPool[i] always exist
597      */
598     if (!tidPool[tid])
599     {
600         co_return PLDM_ERROR_NOT_READY;
601     }
602 
603     if (!transportLayerTable.contains(tid))
604     {
605         co_return PLDM_ERROR_NOT_READY;
606     }
607 
608     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
609     {
610         co_return PLDM_ERROR_NOT_READY;
611     }
612 
613     auto mctpInfo = toMctpInfo(tid);
614     if (!mctpInfo.has_value())
615     {
616         co_return PLDM_ERROR_NOT_READY;
617     }
618 
619     auto eid = std::get<0>(mctpInfo.value());
620     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
621     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
622     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
623                                                responseLen);
624 
625     co_return rc;
626 }
627 
628 } // namespace platform_mc
629 } // namespace pldm
630