1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                (std::get<3>(v.second) == std::get<3>(mctpInfo));
47     });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                (std::get<3>(v.second) == std::get<3>(mctpInfo));
91     });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
/**
 * @brief Queue a batch of MCTP endpoints for discovery and make sure a
 *        discovery task is running to process it.
 *
 * The batch is always appended to queuedMctpInfos. If a previously spawned
 * discovery task has not produced a result yet (its rcOpt is empty), nothing
 * else is done: discoverMctpTerminusTask() loops until queuedMctpInfos is
 * empty and so consumes this batch too. Otherwise a finished task's scope is
 * waited on and the handle reset, and a new discoverMctpTerminusTask() is
 * spawned whose result code is stored in the handle's rcOpt.
 *
 * @param[in] mctpInfos - MCTP endpoints to discover
 */
void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
{
    queuedMctpInfos.emplace(mctpInfos);
    if (discoverMctpTerminusTaskHandle.has_value())
    {
        auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
        // Task still in flight: it will drain the batch queued above.
        if (!rcOpt.has_value())
        {
            return;
        }
        // Previous task has finished; wait for its scope to be empty before
        // destroying the handle that owns it.
        stdexec::sync_wait(scope.on_empty());
        discoverMctpTerminusTaskHandle.reset();
    }
    auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
    // The continuation captures rcOpt by reference. rcOpt lives inside
    // discoverMctpTerminusTaskHandle, which this function only resets after
    // the scope is empty.
    scope.spawn(discoverMctpTerminusTask() |
                    stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
                exec::default_task_context<void>());
}
146 
147 auto TerminusManager::findTerminusPtr(const MctpInfo& mctpInfo)
148 {
149     auto foundIter = std::find_if(termini.begin(), termini.end(),
150                                   [&](const auto& terminusPair) {
151         auto terminusMctpInfo = toMctpInfo(terminusPair.first);
152         return (
153             terminusMctpInfo &&
154             (std::get<0>(terminusMctpInfo.value()) == std::get<0>(mctpInfo)) &&
155             (std::get<3>(terminusMctpInfo.value()) == std::get<3>(mctpInfo)));
156     });
157 
158     return foundIter;
159 }
160 
161 exec::task<int> TerminusManager::discoverMctpTerminusTask()
162 {
163     while (!queuedMctpInfos.empty())
164     {
165         if (manager)
166         {
167             co_await manager->beforeDiscoverTerminus();
168         }
169 
170         const MctpInfos& mctpInfos = queuedMctpInfos.front();
171         for (const auto& mctpInfo : mctpInfos)
172         {
173             auto it = findTerminusPtr(mctpInfo);
174             if (it == termini.end())
175             {
176                 co_await initMctpTerminus(mctpInfo);
177             }
178         }
179 
180         if (manager)
181         {
182             co_await manager->afterDiscoverTerminus();
183         }
184 
185         queuedMctpInfos.pop();
186     }
187 
188     co_return PLDM_SUCCESS;
189 }
190 
191 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
192 {
193     // remove terminus
194     for (const auto& mctpInfo : mctpInfos)
195     {
196         auto it = findTerminusPtr(mctpInfo);
197         if (it == termini.end())
198         {
199             continue;
200         }
201 
202         unmapTid(it->first);
203         termini.erase(it);
204     }
205 }
206 
207 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
208 {
209     mctp_eid_t eid = std::get<0>(mctpInfo);
210     pldm_tid_t tid = 0;
211     bool isMapped = false;
212     auto rc = co_await getTidOverMctp(eid, &tid);
213     if (rc != PLDM_SUCCESS)
214     {
215         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
216         co_return PLDM_ERROR;
217     }
218 
219     if (tid == PLDM_TID_RESERVED)
220     {
221         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
222         co_return PLDM_ERROR;
223     }
224 
225     /* Terminus already has TID */
226     if (tid != PLDM_TID_UNASSIGNED)
227     {
228         /* TID is used by one discovered terminus */
229         auto it = termini.find(tid);
230         if (it != termini.end())
231         {
232             auto terminusMctpInfo = toMctpInfo(it->first);
233             /* The discovered terminus has the same MCTP Info */
234             if (terminusMctpInfo &&
235                 (std::get<0>(terminusMctpInfo.value()) ==
236                  std::get<0>(mctpInfo)) &&
237                 (std::get<3>(terminusMctpInfo.value()) ==
238                  std::get<3>(mctpInfo)))
239             {
240                 co_return PLDM_SUCCESS;
241             }
242             else
243             {
244                 /* ToDo:
245                  * Maybe the terminus supports multiple medium interfaces
246                  * Or the TID is used by other terminus.
247                  * Check the UUID to confirm.
248                  */
249                 isMapped = false;
250             }
251         }
252         /* Use the terminus TID for mapping */
253         else
254         {
255             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
256             if (!mappedTid)
257             {
258                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
259                            "TID", tid);
260                 co_return PLDM_ERROR;
261             }
262             isMapped = true;
263         }
264     }
265 
266     if (!isMapped)
267     {
268         // Assigning a tid. If it has been mapped, mapTid()
269         // returns the tid assigned before.
270         auto mappedTid = mapTid(mctpInfo);
271         if (!mappedTid)
272         {
273             lg2::error("Failed to store Terminus Info for terminus {TID}.",
274                        "TID", tid);
275             co_return PLDM_ERROR;
276         }
277 
278         tid = mappedTid.value();
279         rc = co_await setTidOverMctp(eid, tid);
280         if (rc != PLDM_SUCCESS)
281         {
282             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
283                        rc);
284             unmapTid(tid);
285             co_return rc;
286         }
287 
288         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
289         {
290             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
291                        tid);
292             unmapTid(tid);
293             co_return rc;
294         }
295 
296         if (termini.contains(tid))
297         {
298             // the terminus has been discovered before
299             co_return PLDM_SUCCESS;
300         }
301     }
302     /* Discovery the mapped terminus */
303     uint64_t supportedTypes = 0;
304     rc = co_await getPLDMTypes(tid, supportedTypes);
305     if (rc)
306     {
307         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
308                    "TID", tid, "ERROR", rc);
309         co_return PLDM_ERROR;
310     }
311 
312     try
313     {
314         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
315     }
316     catch (const sdbusplus::exception_t& e)
317     {
318         lg2::error("Failed to create terminus manager for terminus {TID}",
319                    "TID", tid);
320         co_return PLDM_ERROR;
321     }
322 
323     uint8_t type = PLDM_BASE;
324     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
325     std::vector<uint8_t> pldmCmds(size);
326     while ((type < PLDM_MAX_TYPES))
327     {
328         if (!termini[tid]->doesSupportType(type))
329         {
330             type++;
331             continue;
332         }
333         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
334         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
335         if (rc)
336         {
337             lg2::error(
338                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
339                 "TID", tid, "ERROR", rc);
340         }
341 
342         for (size_t i = 0; i < cmds.size(); i++)
343         {
344             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
345             if (idx >= pldmCmds.size())
346             {
347                 lg2::error(
348                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
349                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
350                 continue;
351             }
352             pldmCmds[idx] = cmds[i].byte;
353         }
354         type++;
355     }
356     termini[tid]->setSupportedCommands(pldmCmds);
357 
358     co_return PLDM_SUCCESS;
359 }
360 
361 exec::task<int>
362     TerminusManager::sendRecvPldmMsgOverMctp(mctp_eid_t eid, Request& request,
363                                              const pldm_msg** responseMsg,
364                                              size_t* responseLen)
365 {
366     int rc = 0;
367     try
368     {
369         std::tie(rc, *responseMsg, *responseLen) =
370             co_await handler.sendRecvMsg(eid, std::move(request));
371     }
372     catch (const sdbusplus::exception_t& e)
373     {
374         lg2::error(
375             "Send and Receive PLDM message over MCTP throw error - {ERROR}.",
376             "ERROR", e);
377         co_return PLDM_ERROR;
378     }
379     catch (const int& e)
380     {
381         lg2::error(
382             "Send and Receive PLDM message over MCTP throw int error - {ERROR}.",
383             "ERROR", e);
384         co_return PLDM_ERROR;
385     }
386 
387     co_return rc;
388 }
389 
390 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
391 {
392     auto instanceId = instanceIdDb.next(eid);
393     Request request(sizeof(pldm_msg_hdr));
394     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
395     auto rc = encode_get_tid_req(instanceId, requestMsg);
396     if (rc)
397     {
398         instanceIdDb.free(eid, instanceId);
399         lg2::error(
400             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
401             "EID", eid, "RC", rc);
402         co_return rc;
403     }
404 
405     const pldm_msg* responseMsg = nullptr;
406     size_t responseLen = 0;
407     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
408                                           &responseLen);
409     if (rc)
410     {
411         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
412                    "EID", eid, "RC", rc);
413         co_return rc;
414     }
415 
416     uint8_t completionCode = 0;
417     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
418     if (rc)
419     {
420         lg2::error(
421             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
422             "EID", eid, "RC", rc);
423         co_return rc;
424     }
425 
426     if (completionCode != PLDM_SUCCESS)
427     {
428         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
429                    "EID", eid, "CC", completionCode);
430         co_return rc;
431     }
432 
433     co_return completionCode;
434 }
435 
436 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
437 {
438     auto instanceId = instanceIdDb.next(eid);
439     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
440     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
441     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
442     if (rc)
443     {
444         instanceIdDb.free(eid, instanceId);
445         lg2::error(
446             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
447             "EID", eid, "RC", rc);
448         co_return rc;
449     }
450 
451     const pldm_msg* responseMsg = nullptr;
452     size_t responseLen = 0;
453     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
454                                           &responseLen);
455     if (rc)
456     {
457         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
458                    "EID", eid, "RC", rc);
459         co_return rc;
460     }
461 
462     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
463     {
464         lg2::error(
465             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
466             "EID", eid, "RC", rc);
467         co_return PLDM_ERROR_INVALID_LENGTH;
468     }
469 
470     co_return responseMsg->payload[0];
471 }
472 
473 exec::task<int> TerminusManager::getPLDMTypes(pldm_tid_t tid,
474                                               uint64_t& supportedTypes)
475 {
476     Request request(sizeof(pldm_msg_hdr));
477     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
478     auto rc = encode_get_types_req(0, requestMsg);
479     if (rc)
480     {
481         lg2::error(
482             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
483             "TID", tid, "RC", rc);
484         co_return rc;
485     }
486 
487     const pldm_msg* responseMsg = nullptr;
488     size_t responseLen = 0;
489 
490     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
491     if (rc)
492     {
493         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
494                    "TID", tid, "RC", rc);
495         co_return rc;
496     }
497 
498     uint8_t completionCode = 0;
499     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
500     rc = decode_get_types_resp(responseMsg, responseLen, &completionCode,
501                                types);
502     if (rc)
503     {
504         lg2::error(
505             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
506             "TID", tid, "RC", rc);
507         co_return rc;
508     }
509 
510     if (completionCode != PLDM_SUCCESS)
511     {
512         lg2::error(
513             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
514             "TID", tid, "CC", completionCode);
515         co_return rc;
516     }
517     co_return completionCode;
518 }
519 
520 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
521                                                  bitfield8_t* supportedCmds)
522 {
523     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
524     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
525     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
526 
527     auto rc = encode_get_commands_req(0, type, version, requestMsg);
528     if (rc)
529     {
530         lg2::error(
531             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
532             "TID", tid, "RC", rc);
533         co_return rc;
534     }
535 
536     const pldm_msg* responseMsg = nullptr;
537     size_t responseLen = 0;
538 
539     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
540     if (rc)
541     {
542         lg2::error(
543             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
544             "TID", tid, "RC", rc);
545         co_return rc;
546     }
547 
548     /* Process response */
549     uint8_t completionCode = 0;
550     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
551                                   supportedCmds);
552     if (rc)
553     {
554         lg2::error(
555             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
556             "TID", tid, "RC", rc);
557         co_return rc;
558     }
559 
560     if (completionCode != PLDM_SUCCESS)
561     {
562         lg2::error(
563             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
564             "TID", tid, "CC", completionCode);
565         co_return rc;
566     }
567 
568     co_return completionCode;
569 }
570 
571 exec::task<int> TerminusManager::sendRecvPldmMsg(pldm_tid_t tid,
572                                                  Request& request,
573                                                  const pldm_msg** responseMsg,
574                                                  size_t* responseLen)
575 {
576     /**
577      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
578      * tidPool[i] always exist
579      */
580     if (!tidPool[tid])
581     {
582         co_return PLDM_ERROR_NOT_READY;
583     }
584 
585     if (!transportLayerTable.contains(tid))
586     {
587         co_return PLDM_ERROR_NOT_READY;
588     }
589 
590     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
591     {
592         co_return PLDM_ERROR_NOT_READY;
593     }
594 
595     auto mctpInfo = toMctpInfo(tid);
596     if (!mctpInfo.has_value())
597     {
598         co_return PLDM_ERROR_NOT_READY;
599     }
600 
601     auto eid = std::get<0>(mctpInfo.value());
602     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
603     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
604     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
605                                                responseLen);
606 
607     co_return rc;
608 }
609 
610 } // namespace platform_mc
611 } // namespace pldm
612