1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
toMctpInfo(const pldm_tid_t & tid)14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
toTid(const MctpInfo & mctpInfo) const36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                (std::get<3>(v.second) == std::get<3>(mctpInfo));
47     });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
storeTerminusInfo(const MctpInfo & mctpInfo,pldm_tid_t tid)56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
mapTid(const MctpInfo & mctpInfo)80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                (std::get<3>(v.second) == std::get<3>(mctpInfo));
91     });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
unmapTid(const pldm_tid_t & tid)107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
discoverMctpTerminus(const MctpInfos & mctpInfos)128 void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
129 {
130     queuedMctpInfos.emplace(mctpInfos);
131     if (discoverMctpTerminusTaskHandle.has_value())
132     {
133         auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
134         if (!rcOpt.has_value())
135         {
136             return;
137         }
138         stdexec::sync_wait(scope.on_empty());
139         discoverMctpTerminusTaskHandle.reset();
140     }
141     auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
142     scope.spawn(discoverMctpTerminusTask() |
143                     stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
144                 exec::default_task_context<void>());
145 }
146 
findTeminusPtr(const MctpInfo & mctpInfo)147 auto TerminusManager::findTeminusPtr(const MctpInfo& mctpInfo)
148 {
149     auto foundIter = std::find_if(termini.begin(), termini.end(),
150                                   [&](const auto& terminusPair) {
151         auto terminusMctpInfo = toMctpInfo(terminusPair.first);
152         return (
153             terminusMctpInfo &&
154             (std::get<0>(terminusMctpInfo.value()) == std::get<0>(mctpInfo)) &&
155             (std::get<3>(terminusMctpInfo.value()) == std::get<3>(mctpInfo)));
156     });
157 
158     return foundIter;
159 }
160 
discoverMctpTerminusTask()161 exec::task<int> TerminusManager::discoverMctpTerminusTask()
162 {
163     while (!queuedMctpInfos.empty())
164     {
165         if (manager)
166         {
167             co_await manager->beforeDiscoverTerminus();
168         }
169 
170         const MctpInfos& mctpInfos = queuedMctpInfos.front();
171         for (const auto& mctpInfo : mctpInfos)
172         {
173             auto it = findTeminusPtr(mctpInfo);
174             if (it == termini.end())
175             {
176                 co_await initMctpTerminus(mctpInfo);
177             }
178         }
179 
180         if (manager)
181         {
182             co_await manager->afterDiscoverTerminus();
183         }
184 
185         queuedMctpInfos.pop();
186     }
187 
188     co_return PLDM_SUCCESS;
189 }
190 
removeMctpTerminus(const MctpInfos & mctpInfos)191 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
192 {
193     // remove terminus
194     for (const auto& mctpInfo : mctpInfos)
195     {
196         auto it = findTeminusPtr(mctpInfo);
197         if (it == termini.end())
198         {
199             continue;
200         }
201 
202         unmapTid(it->first);
203         termini.erase(it);
204     }
205 }
206 
initMctpTerminus(const MctpInfo & mctpInfo)207 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
208 {
209     mctp_eid_t eid = std::get<0>(mctpInfo);
210     pldm_tid_t tid = 0;
211     bool isMapped = false;
212     auto rc = co_await getTidOverMctp(eid, &tid);
213     if (rc != PLDM_SUCCESS)
214     {
215         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
216         co_return PLDM_ERROR;
217     }
218 
219     if (tid == PLDM_TID_RESERVED)
220     {
221         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
222         co_return PLDM_ERROR;
223     }
224 
225     /* Terminus already has TID */
226     if (tid != PLDM_TID_UNASSIGNED)
227     {
228         /* TID is used by one discovered terminus */
229         auto it = termini.find(tid);
230         if (it != termini.end())
231         {
232             auto terminusMctpInfo = toMctpInfo(it->first);
233             /* The discovered terminus has the same MCTP Info */
234             if (terminusMctpInfo &&
235                 (std::get<0>(terminusMctpInfo.value()) ==
236                  std::get<0>(mctpInfo)) &&
237                 (std::get<3>(terminusMctpInfo.value()) ==
238                  std::get<3>(mctpInfo)))
239             {
240                 co_return PLDM_SUCCESS;
241             }
242             else
243             {
244                 /* ToDo:
245                  * Maybe the terminus supports multiple medium interfaces
246                  * Or the TID is used by other terminus.
247                  * Check the UUID to confirm.
248                  */
249                 isMapped = false;
250             }
251         }
252         /* Use the terminus TID for mapping */
253         else
254         {
255             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
256             if (!mappedTid)
257             {
258                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
259                            "TID", tid);
260                 co_return PLDM_ERROR;
261             }
262             isMapped = true;
263         }
264     }
265 
266     if (!isMapped)
267     {
268         // Assigning a tid. If it has been mapped, mapTid()
269         // returns the tid assigned before.
270         auto mappedTid = mapTid(mctpInfo);
271         if (!mappedTid)
272         {
273             lg2::error("Failed to store Terminus Info for terminus {TID}.",
274                        "TID", tid);
275             co_return PLDM_ERROR;
276         }
277 
278         tid = mappedTid.value();
279         rc = co_await setTidOverMctp(eid, tid);
280         if (rc != PLDM_SUCCESS)
281         {
282             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
283                        rc);
284             unmapTid(tid);
285             co_return rc;
286         }
287 
288         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
289         {
290             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
291                        tid);
292             unmapTid(tid);
293             co_return rc;
294         }
295 
296         if (termini.contains(tid))
297         {
298             // the terminus has been discovered before
299             co_return PLDM_SUCCESS;
300         }
301     }
302     /* Discovery the mapped terminus */
303     uint64_t supportedTypes = 0;
304     rc = co_await getPLDMTypes(tid, supportedTypes);
305     if (rc)
306     {
307         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
308                    "TID", tid, "ERROR", rc);
309         co_return PLDM_ERROR;
310     }
311 
312     try
313     {
314         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
315     }
316     catch (const sdbusplus::exception_t& e)
317     {
318         lg2::error("Failed to create terminus manager for terminus {TID}",
319                    "TID", tid);
320         co_return PLDM_ERROR;
321     }
322 
323     uint8_t type = PLDM_BASE;
324     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
325     std::vector<uint8_t> pldmCmds(size);
326     while ((type < PLDM_MAX_TYPES))
327     {
328         if (!termini[tid]->doesSupportType(type))
329         {
330             type++;
331             continue;
332         }
333         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
334         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
335         if (rc)
336         {
337             lg2::error(
338                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
339                 "TID", tid, "ERROR", rc);
340         }
341 
342         for (size_t i = 0; i < cmds.size(); i++)
343         {
344             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
345             if (idx >= pldmCmds.size())
346             {
347                 lg2::error(
348                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
349                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
350                 continue;
351             }
352             pldmCmds[idx] = cmds[i].byte;
353         }
354         type++;
355     }
356     termini[tid]->setSupportedCommands(pldmCmds);
357 
358     co_return PLDM_SUCCESS;
359 }
360 
361 exec::task<int>
sendRecvPldmMsgOverMctp(mctp_eid_t eid,Request & request,const pldm_msg ** responseMsg,size_t * responseLen)362     TerminusManager::sendRecvPldmMsgOverMctp(mctp_eid_t eid, Request& request,
363                                              const pldm_msg** responseMsg,
364                                              size_t* responseLen)
365 {
366     try
367     {
368         std::tie(*responseMsg, *responseLen) =
369             co_await handler.sendRecvMsg(eid, std::move(request));
370         co_return PLDM_SUCCESS;
371     }
372     catch (const sdbusplus::exception_t& e)
373     {
374         lg2::error(
375             "Send and Receive PLDM message over MCTP failed with error - {ERROR}.",
376             "ERROR", e);
377         co_return PLDM_ERROR;
378     }
379 }
380 
getTidOverMctp(mctp_eid_t eid,pldm_tid_t * tid)381 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
382 {
383     auto instanceId = instanceIdDb.next(eid);
384     Request request(sizeof(pldm_msg_hdr));
385     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
386     auto rc = encode_get_tid_req(instanceId, requestMsg);
387     if (rc)
388     {
389         instanceIdDb.free(eid, instanceId);
390         lg2::error(
391             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
392             "EID", eid, "RC", rc);
393         co_return rc;
394     }
395 
396     const pldm_msg* responseMsg = nullptr;
397     size_t responseLen = 0;
398     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
399                                           &responseLen);
400     if (rc)
401     {
402         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
403                    "EID", eid, "RC", rc);
404         co_return rc;
405     }
406 
407     uint8_t completionCode = 0;
408     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
409     if (rc)
410     {
411         lg2::error(
412             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
413             "EID", eid, "RC", rc);
414         co_return rc;
415     }
416 
417     if (completionCode != PLDM_SUCCESS)
418     {
419         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
420                    "EID", eid, "CC", completionCode);
421         co_return rc;
422     }
423 
424     co_return completionCode;
425 }
426 
setTidOverMctp(mctp_eid_t eid,pldm_tid_t tid)427 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
428 {
429     auto instanceId = instanceIdDb.next(eid);
430     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
431     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
432     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
433     if (rc)
434     {
435         instanceIdDb.free(eid, instanceId);
436         lg2::error(
437             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
438             "EID", eid, "RC", rc);
439         co_return rc;
440     }
441 
442     const pldm_msg* responseMsg = nullptr;
443     size_t responseLen = 0;
444     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
445                                           &responseLen);
446     if (rc)
447     {
448         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
449                    "EID", eid, "RC", rc);
450         co_return rc;
451     }
452 
453     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
454     {
455         lg2::error(
456             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
457             "EID", eid, "RC", rc);
458         co_return PLDM_ERROR_INVALID_LENGTH;
459     }
460 
461     co_return responseMsg->payload[0];
462 }
463 
getPLDMTypes(pldm_tid_t tid,uint64_t & supportedTypes)464 exec::task<int> TerminusManager::getPLDMTypes(pldm_tid_t tid,
465                                               uint64_t& supportedTypes)
466 {
467     Request request(sizeof(pldm_msg_hdr));
468     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
469     auto rc = encode_get_types_req(0, requestMsg);
470     if (rc)
471     {
472         lg2::error(
473             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
474             "TID", tid, "RC", rc);
475         co_return rc;
476     }
477 
478     const pldm_msg* responseMsg = nullptr;
479     size_t responseLen = 0;
480 
481     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
482     if (rc)
483     {
484         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
485                    "TID", tid, "RC", rc);
486         co_return rc;
487     }
488 
489     uint8_t completionCode = 0;
490     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
491     rc = decode_get_types_resp(responseMsg, responseLen, &completionCode,
492                                types);
493     if (rc)
494     {
495         lg2::error(
496             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
497             "TID", tid, "RC", rc);
498         co_return rc;
499     }
500 
501     if (completionCode != PLDM_SUCCESS)
502     {
503         lg2::error(
504             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
505             "TID", tid, "CC", completionCode);
506         co_return rc;
507     }
508     co_return completionCode;
509 }
510 
getPLDMCommands(pldm_tid_t tid,uint8_t type,bitfield8_t * supportedCmds)511 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
512                                                  bitfield8_t* supportedCmds)
513 {
514     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
515     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
516     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
517 
518     auto rc = encode_get_commands_req(0, type, version, requestMsg);
519     if (rc)
520     {
521         lg2::error(
522             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
523             "TID", tid, "RC", rc);
524         co_return rc;
525     }
526 
527     const pldm_msg* responseMsg = nullptr;
528     size_t responseLen = 0;
529 
530     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
531     if (rc)
532     {
533         lg2::error(
534             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
535             "TID", tid, "RC", rc);
536         co_return rc;
537     }
538 
539     /* Process response */
540     uint8_t completionCode = 0;
541     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
542                                   supportedCmds);
543     if (rc)
544     {
545         lg2::error(
546             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
547             "TID", tid, "RC", rc);
548         co_return rc;
549     }
550 
551     if (completionCode != PLDM_SUCCESS)
552     {
553         lg2::error(
554             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
555             "TID", tid, "CC", completionCode);
556         co_return rc;
557     }
558 
559     co_return completionCode;
560 }
561 
sendRecvPldmMsg(pldm_tid_t tid,Request & request,const pldm_msg ** responseMsg,size_t * responseLen)562 exec::task<int> TerminusManager::sendRecvPldmMsg(pldm_tid_t tid,
563                                                  Request& request,
564                                                  const pldm_msg** responseMsg,
565                                                  size_t* responseLen)
566 {
567     /**
568      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
569      * tidPool[i] always exist
570      */
571     if (!tidPool[tid])
572     {
573         co_return PLDM_ERROR_NOT_READY;
574     }
575 
576     if (!transportLayerTable.contains(tid))
577     {
578         co_return PLDM_ERROR_NOT_READY;
579     }
580 
581     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
582     {
583         co_return PLDM_ERROR_NOT_READY;
584     }
585 
586     auto mctpInfo = toMctpInfo(tid);
587     if (!mctpInfo.has_value())
588     {
589         co_return PLDM_ERROR_NOT_READY;
590     }
591 
592     auto eid = std::get<0>(mctpInfo.value());
593     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
594     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
595     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
596                                                responseLen);
597 
598     if (responseMsg == nullptr || !responseLen)
599     {
600         co_return PLDM_ERROR_INVALID_DATA;
601     }
602 
603     co_return rc;
604 }
605 
606 } // namespace platform_mc
607 } // namespace pldm
608