1 #include "terminus_manager.hpp"
2 
3 #include "manager.hpp"
4 
5 #include <phosphor-logging/lg2.hpp>
6 
7 PHOSPHOR_LOG2_USING;
8 
9 namespace pldm
10 {
11 namespace platform_mc
12 {
13 
14 std::optional<MctpInfo> TerminusManager::toMctpInfo(const pldm_tid_t& tid)
15 {
16     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
17     {
18         return std::nullopt;
19     }
20 
21     if ((!this->transportLayerTable.contains(tid)) ||
22         (this->transportLayerTable[tid] != SupportedTransportLayer::MCTP))
23     {
24         return std::nullopt;
25     }
26 
27     auto mctpInfoIt = mctpInfoTable.find(tid);
28     if (mctpInfoIt == mctpInfoTable.end())
29     {
30         return std::nullopt;
31     }
32 
33     return mctpInfoIt->second;
34 }
35 
36 std::optional<pldm_tid_t> TerminusManager::toTid(const MctpInfo& mctpInfo) const
37 {
38     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
39     {
40         return std::nullopt;
41     }
42 
43     auto mctpInfoTableIt = std::find_if(
44         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
45         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
46                (std::get<3>(v.second) == std::get<3>(mctpInfo));
47     });
48     if (mctpInfoTableIt == mctpInfoTable.end())
49     {
50         return std::nullopt;
51     }
52     return mctpInfoTableIt->first;
53 }
54 
55 std::optional<pldm_tid_t>
56     TerminusManager::storeTerminusInfo(const MctpInfo& mctpInfo, pldm_tid_t tid)
57 {
58     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
59     {
60         return std::nullopt;
61     }
62 
63     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
64     {
65         return std::nullopt;
66     }
67 
68     if (tidPool[tid])
69     {
70         return std::nullopt;
71     }
72 
73     tidPool[tid] = true;
74     transportLayerTable[tid] = SupportedTransportLayer::MCTP;
75     mctpInfoTable[tid] = mctpInfo;
76 
77     return tid;
78 }
79 
80 std::optional<pldm_tid_t> TerminusManager::mapTid(const MctpInfo& mctpInfo)
81 {
82     if (!pldm::utils::isValidEID(std::get<0>(mctpInfo)))
83     {
84         return std::nullopt;
85     }
86 
87     auto mctpInfoTableIt = std::find_if(
88         mctpInfoTable.begin(), mctpInfoTable.end(), [&mctpInfo](auto& v) {
89         return (std::get<0>(v.second) == std::get<0>(mctpInfo)) &&
90                (std::get<3>(v.second) == std::get<3>(mctpInfo));
91     });
92     if (mctpInfoTableIt != mctpInfoTable.end())
93     {
94         return mctpInfoTableIt->first;
95     }
96 
97     auto tidPoolIt = std::find(tidPool.begin(), tidPool.end(), false);
98     if (tidPoolIt == tidPool.end())
99     {
100         return std::nullopt;
101     }
102 
103     pldm_tid_t tid = std::distance(tidPool.begin(), tidPoolIt);
104     return storeTerminusInfo(mctpInfo, tid);
105 }
106 
107 bool TerminusManager::unmapTid(const pldm_tid_t& tid)
108 {
109     if (tid == PLDM_TID_UNASSIGNED || tid == PLDM_TID_RESERVED)
110     {
111         return false;
112     }
113     tidPool[tid] = false;
114 
115     if (transportLayerTable.contains(tid))
116     {
117         transportLayerTable.erase(tid);
118     }
119 
120     if (mctpInfoTable.contains(tid))
121     {
122         mctpInfoTable.erase(tid);
123     }
124 
125     return true;
126 }
127 
128 void TerminusManager::discoverMctpTerminus(const MctpInfos& mctpInfos)
129 {
130     queuedMctpInfos.emplace(mctpInfos);
131     if (discoverMctpTerminusTaskHandle.has_value())
132     {
133         auto& [scope, rcOpt] = *discoverMctpTerminusTaskHandle;
134         if (!rcOpt.has_value())
135         {
136             return;
137         }
138         stdexec::sync_wait(scope.on_empty());
139         discoverMctpTerminusTaskHandle.reset();
140     }
141     auto& [scope, rcOpt] = discoverMctpTerminusTaskHandle.emplace();
142     scope.spawn(discoverMctpTerminusTask() |
143                     stdexec::then([&](int rc) { rcOpt.emplace(rc); }),
144                 exec::default_task_context<void>());
145 }
146 
147 auto TerminusManager::findTerminusPtr(const MctpInfo& mctpInfo)
148 {
149     auto foundIter = std::find_if(termini.begin(), termini.end(),
150                                   [&](const auto& terminusPair) {
151         auto terminusMctpInfo = toMctpInfo(terminusPair.first);
152         return (
153             terminusMctpInfo &&
154             (std::get<0>(terminusMctpInfo.value()) == std::get<0>(mctpInfo)) &&
155             (std::get<3>(terminusMctpInfo.value()) == std::get<3>(mctpInfo)));
156     });
157 
158     return foundIter;
159 }
160 
161 exec::task<int> TerminusManager::discoverMctpTerminusTask()
162 {
163     while (!queuedMctpInfos.empty())
164     {
165         if (manager)
166         {
167             co_await manager->beforeDiscoverTerminus();
168         }
169 
170         const MctpInfos& mctpInfos = queuedMctpInfos.front();
171         for (const auto& mctpInfo : mctpInfos)
172         {
173             auto it = findTerminusPtr(mctpInfo);
174             if (it == termini.end())
175             {
176                 co_await initMctpTerminus(mctpInfo);
177             }
178         }
179 
180         if (manager)
181         {
182             co_await manager->afterDiscoverTerminus();
183         }
184 
185         queuedMctpInfos.pop();
186     }
187 
188     co_return PLDM_SUCCESS;
189 }
190 
191 void TerminusManager::removeMctpTerminus(const MctpInfos& mctpInfos)
192 {
193     // remove terminus
194     for (const auto& mctpInfo : mctpInfos)
195     {
196         auto it = findTerminusPtr(mctpInfo);
197         if (it == termini.end())
198         {
199             continue;
200         }
201 
202         unmapTid(it->first);
203         termini.erase(it);
204     }
205 }
206 
207 exec::task<int> TerminusManager::initMctpTerminus(const MctpInfo& mctpInfo)
208 {
209     mctp_eid_t eid = std::get<0>(mctpInfo);
210     pldm_tid_t tid = 0;
211     bool isMapped = false;
212     auto rc = co_await getTidOverMctp(eid, &tid);
213     if (rc != PLDM_SUCCESS)
214     {
215         lg2::error("Failed to Get Terminus ID, error {ERROR}.", "ERROR", rc);
216         co_return PLDM_ERROR;
217     }
218 
219     if (tid == PLDM_TID_RESERVED)
220     {
221         lg2::error("Terminus responses the reserved {TID}.", "TID", tid);
222         co_return PLDM_ERROR;
223     }
224 
225     /* Terminus already has TID */
226     if (tid != PLDM_TID_UNASSIGNED)
227     {
228         /* TID is used by one discovered terminus */
229         auto it = termini.find(tid);
230         if (it != termini.end())
231         {
232             auto terminusMctpInfo = toMctpInfo(it->first);
233             /* The discovered terminus has the same MCTP Info */
234             if (terminusMctpInfo &&
235                 (std::get<0>(terminusMctpInfo.value()) ==
236                  std::get<0>(mctpInfo)) &&
237                 (std::get<3>(terminusMctpInfo.value()) ==
238                  std::get<3>(mctpInfo)))
239             {
240                 co_return PLDM_SUCCESS;
241             }
242             else
243             {
244                 /* ToDo:
245                  * Maybe the terminus supports multiple medium interfaces
246                  * Or the TID is used by other terminus.
247                  * Check the UUID to confirm.
248                  */
249                 isMapped = false;
250             }
251         }
252         /* Use the terminus TID for mapping */
253         else
254         {
255             auto mappedTid = storeTerminusInfo(mctpInfo, tid);
256             if (!mappedTid)
257             {
258                 lg2::error("Failed to store Terminus Info for terminus {TID}.",
259                            "TID", tid);
260                 co_return PLDM_ERROR;
261             }
262             isMapped = true;
263         }
264     }
265 
266     if (!isMapped)
267     {
268         // Assigning a tid. If it has been mapped, mapTid()
269         // returns the tid assigned before.
270         auto mappedTid = mapTid(mctpInfo);
271         if (!mappedTid)
272         {
273             lg2::error("Failed to store Terminus Info for terminus {TID}.",
274                        "TID", tid);
275             co_return PLDM_ERROR;
276         }
277 
278         tid = mappedTid.value();
279         rc = co_await setTidOverMctp(eid, tid);
280         if (rc != PLDM_SUCCESS)
281         {
282             lg2::error("Failed to Set terminus TID, error{ERROR}.", "ERROR",
283                        rc);
284             unmapTid(tid);
285             co_return rc;
286         }
287 
288         if (rc != PLDM_SUCCESS && rc != PLDM_ERROR_UNSUPPORTED_PLDM_CMD)
289         {
290             lg2::error("Terminus {TID} does not support SetTID command.", "TID",
291                        tid);
292             unmapTid(tid);
293             co_return rc;
294         }
295 
296         if (termini.contains(tid))
297         {
298             // the terminus has been discovered before
299             co_return PLDM_SUCCESS;
300         }
301     }
302     /* Discovery the mapped terminus */
303     uint64_t supportedTypes = 0;
304     rc = co_await getPLDMTypes(tid, supportedTypes);
305     if (rc)
306     {
307         lg2::error("Failed to Get PLDM Types for terminus {TID}, error {ERROR}",
308                    "TID", tid, "ERROR", rc);
309         co_return PLDM_ERROR;
310     }
311 
312     try
313     {
314         termini[tid] = std::make_shared<Terminus>(tid, supportedTypes);
315     }
316     catch (const sdbusplus::exception_t& e)
317     {
318         lg2::error("Failed to create terminus manager for terminus {TID}",
319                    "TID", tid);
320         co_return PLDM_ERROR;
321     }
322 
323     uint8_t type = PLDM_BASE;
324     auto size = PLDM_MAX_TYPES * (PLDM_MAX_CMDS_PER_TYPE / 8);
325     std::vector<uint8_t> pldmCmds(size);
326     while ((type < PLDM_MAX_TYPES))
327     {
328         if (!termini[tid]->doesSupportType(type))
329         {
330             type++;
331             continue;
332         }
333         std::vector<bitfield8_t> cmds(PLDM_MAX_CMDS_PER_TYPE / 8);
334         auto rc = co_await getPLDMCommands(tid, type, cmds.data());
335         if (rc)
336         {
337             lg2::error(
338                 "Failed to Get PLDM Commands for terminus {TID}, error {ERROR}",
339                 "TID", tid, "ERROR", rc);
340         }
341 
342         for (size_t i = 0; i < cmds.size(); i++)
343         {
344             auto idx = type * (PLDM_MAX_CMDS_PER_TYPE / 8) + i;
345             if (idx >= pldmCmds.size())
346             {
347                 lg2::error(
348                     "Calculated index {IDX} out of bounds for pldmCmds, type {TYPE}, command index {CMD_IDX}",
349                     "IDX", idx, "TYPE", type, "CMD_IDX", i);
350                 continue;
351             }
352             pldmCmds[idx] = cmds[i].byte;
353         }
354         type++;
355     }
356     termini[tid]->setSupportedCommands(pldmCmds);
357 
358     co_return PLDM_SUCCESS;
359 }
360 
361 exec::task<int>
362     TerminusManager::sendRecvPldmMsgOverMctp(mctp_eid_t eid, Request& request,
363                                              const pldm_msg** responseMsg,
364                                              size_t* responseLen)
365 {
366     try
367     {
368         std::tie(*responseMsg, *responseLen) =
369             co_await handler.sendRecvMsg(eid, std::move(request));
370         co_return PLDM_SUCCESS;
371     }
372     catch (const sdbusplus::exception_t& e)
373     {
374         lg2::error(
375             "Send and Receive PLDM message over MCTP failed with error - {ERROR}.",
376             "ERROR", e);
377         co_return PLDM_ERROR;
378     }
379     catch (const int& rc)
380     {
381         lg2::error("sendRecvPldmMsgOverMctp failed. rc={RC}", "RC", rc);
382         co_return PLDM_ERROR;
383     }
384 }
385 
386 exec::task<int> TerminusManager::getTidOverMctp(mctp_eid_t eid, pldm_tid_t* tid)
387 {
388     auto instanceId = instanceIdDb.next(eid);
389     Request request(sizeof(pldm_msg_hdr));
390     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
391     auto rc = encode_get_tid_req(instanceId, requestMsg);
392     if (rc)
393     {
394         instanceIdDb.free(eid, instanceId);
395         lg2::error(
396             "Failed to encode request GetTID for endpoint ID {EID}, error {RC} ",
397             "EID", eid, "RC", rc);
398         co_return rc;
399     }
400 
401     const pldm_msg* responseMsg = nullptr;
402     size_t responseLen = 0;
403     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
404                                           &responseLen);
405     if (rc)
406     {
407         lg2::error("Failed to send GetTID for Endpoint {EID}, error {RC}",
408                    "EID", eid, "RC", rc);
409         co_return rc;
410     }
411 
412     uint8_t completionCode = 0;
413     rc = decode_get_tid_resp(responseMsg, responseLen, &completionCode, tid);
414     if (rc)
415     {
416         lg2::error(
417             "Failed to decode response GetTID for Endpoint ID {EID}, error {RC} ",
418             "EID", eid, "RC", rc);
419         co_return rc;
420     }
421 
422     if (completionCode != PLDM_SUCCESS)
423     {
424         lg2::error("Error : GetTID for Endpoint ID {EID}, complete code {CC}.",
425                    "EID", eid, "CC", completionCode);
426         co_return rc;
427     }
428 
429     co_return completionCode;
430 }
431 
432 exec::task<int> TerminusManager::setTidOverMctp(mctp_eid_t eid, pldm_tid_t tid)
433 {
434     auto instanceId = instanceIdDb.next(eid);
435     Request request(sizeof(pldm_msg_hdr) + sizeof(pldm_set_tid_req));
436     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
437     auto rc = encode_set_tid_req(instanceId, tid, requestMsg);
438     if (rc)
439     {
440         instanceIdDb.free(eid, instanceId);
441         lg2::error(
442             "Failed to encode request SetTID for endpoint ID {EID}, error {RC} ",
443             "EID", eid, "RC", rc);
444         co_return rc;
445     }
446 
447     const pldm_msg* responseMsg = nullptr;
448     size_t responseLen = 0;
449     rc = co_await sendRecvPldmMsgOverMctp(eid, request, &responseMsg,
450                                           &responseLen);
451     if (rc)
452     {
453         lg2::error("Failed to send SetTID for Endpoint {EID}, error {RC}",
454                    "EID", eid, "RC", rc);
455         co_return rc;
456     }
457 
458     if (responseMsg == NULL || responseLen != PLDM_SET_TID_RESP_BYTES)
459     {
460         lg2::error(
461             "Failed to decode response SetTID for Endpoint ID {EID}, error {RC} ",
462             "EID", eid, "RC", rc);
463         co_return PLDM_ERROR_INVALID_LENGTH;
464     }
465 
466     co_return responseMsg->payload[0];
467 }
468 
469 exec::task<int> TerminusManager::getPLDMTypes(pldm_tid_t tid,
470                                               uint64_t& supportedTypes)
471 {
472     Request request(sizeof(pldm_msg_hdr));
473     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
474     auto rc = encode_get_types_req(0, requestMsg);
475     if (rc)
476     {
477         lg2::error(
478             "Failed to encode request getPLDMTypes for terminus ID {TID}, error {RC} ",
479             "TID", tid, "RC", rc);
480         co_return rc;
481     }
482 
483     const pldm_msg* responseMsg = nullptr;
484     size_t responseLen = 0;
485 
486     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
487     if (rc)
488     {
489         lg2::error("Failed to send GetPLDMTypes for terminus {TID}, error {RC}",
490                    "TID", tid, "RC", rc);
491         co_return rc;
492     }
493 
494     uint8_t completionCode = 0;
495     bitfield8_t* types = reinterpret_cast<bitfield8_t*>(&supportedTypes);
496     rc = decode_get_types_resp(responseMsg, responseLen, &completionCode,
497                                types);
498     if (rc)
499     {
500         lg2::error(
501             "Failed to decode response GetPLDMTypes for terminus ID {TID}, error {RC} ",
502             "TID", tid, "RC", rc);
503         co_return rc;
504     }
505 
506     if (completionCode != PLDM_SUCCESS)
507     {
508         lg2::error(
509             "Error : GetPLDMTypes for terminus ID {TID}, complete code {CC}.",
510             "TID", tid, "CC", completionCode);
511         co_return rc;
512     }
513     co_return completionCode;
514 }
515 
516 exec::task<int> TerminusManager::getPLDMCommands(pldm_tid_t tid, uint8_t type,
517                                                  bitfield8_t* supportedCmds)
518 {
519     Request request(sizeof(pldm_msg_hdr) + PLDM_GET_COMMANDS_REQ_BYTES);
520     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
521     ver32_t version{0xFF, 0xFF, 0xFF, 0xFF};
522 
523     auto rc = encode_get_commands_req(0, type, version, requestMsg);
524     if (rc)
525     {
526         lg2::error(
527             "Failed to encode request GetPLDMCommands for terminus ID {TID}, error {RC} ",
528             "TID", tid, "RC", rc);
529         co_return rc;
530     }
531 
532     const pldm_msg* responseMsg = nullptr;
533     size_t responseLen = 0;
534 
535     rc = co_await sendRecvPldmMsg(tid, request, &responseMsg, &responseLen);
536     if (rc)
537     {
538         lg2::error(
539             "Failed to send GetPLDMCommands message for terminus {TID}, error {RC}",
540             "TID", tid, "RC", rc);
541         co_return rc;
542     }
543 
544     /* Process response */
545     uint8_t completionCode = 0;
546     rc = decode_get_commands_resp(responseMsg, responseLen, &completionCode,
547                                   supportedCmds);
548     if (rc)
549     {
550         lg2::error(
551             "Failed to decode response GetPLDMCommands for terminus ID {TID}, error {RC} ",
552             "TID", tid, "RC", rc);
553         co_return rc;
554     }
555 
556     if (completionCode != PLDM_SUCCESS)
557     {
558         lg2::error(
559             "Error : GetPLDMCommands for terminus ID {TID}, complete code {CC}.",
560             "TID", tid, "CC", completionCode);
561         co_return rc;
562     }
563 
564     co_return completionCode;
565 }
566 
567 exec::task<int> TerminusManager::sendRecvPldmMsg(pldm_tid_t tid,
568                                                  Request& request,
569                                                  const pldm_msg** responseMsg,
570                                                  size_t* responseLen)
571 {
572     /**
573      * Size of tidPool is `std::numeric_limits<pldm_tid_t>::max() + 1`
574      * tidPool[i] always exist
575      */
576     if (!tidPool[tid])
577     {
578         co_return PLDM_ERROR_NOT_READY;
579     }
580 
581     if (!transportLayerTable.contains(tid))
582     {
583         co_return PLDM_ERROR_NOT_READY;
584     }
585 
586     if (transportLayerTable[tid] != SupportedTransportLayer::MCTP)
587     {
588         co_return PLDM_ERROR_NOT_READY;
589     }
590 
591     auto mctpInfo = toMctpInfo(tid);
592     if (!mctpInfo.has_value())
593     {
594         co_return PLDM_ERROR_NOT_READY;
595     }
596 
597     auto eid = std::get<0>(mctpInfo.value());
598     auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
599     requestMsg->hdr.instance_id = instanceIdDb.next(eid);
600     auto rc = co_await sendRecvPldmMsgOverMctp(eid, request, responseMsg,
601                                                responseLen);
602 
603     if (responseMsg == nullptr || !responseLen)
604     {
605         co_return PLDM_ERROR_INVALID_DATA;
606     }
607 
608     co_return rc;
609 }
610 
611 } // namespace platform_mc
612 } // namespace pldm
613