xref: /openbmc/bmcweb/http/routing.hpp (revision cfe3bc0a)
1 #pragma once
2 
3 #include "async_resp.hpp"
4 #include "common.hpp"
5 #include "dbus_privileges.hpp"
6 #include "dbus_utility.hpp"
7 #include "error_messages.hpp"
8 #include "http_request.hpp"
9 #include "http_response.hpp"
10 #include "logging.hpp"
11 #include "privileges.hpp"
12 #include "routing/baserule.hpp"
13 #include "routing/dynamicrule.hpp"
14 #include "routing/sserule.hpp"
15 #include "routing/taggedrule.hpp"
16 #include "routing/websocketrule.hpp"
17 #include "sessions.hpp"
18 #include "utility.hpp"
19 #include "utils/dbus_utils.hpp"
20 #include "verb.hpp"
21 #include "websocket.hpp"
22 
23 #include <boost/beast/ssl/ssl_stream.hpp>
24 #include <boost/container/flat_map.hpp>
25 #include <boost/url/format.hpp>
26 #include <sdbusplus/unpack_properties.hpp>
27 
28 #include <cerrno>
29 #include <cstdint>
30 #include <cstdlib>
31 #include <limits>
32 #include <memory>
33 #include <optional>
34 #include <tuple>
35 #include <utility>
36 #include <vector>
37 
38 namespace crow
39 {
40 
41 class Trie
42 {
43   public:
44     struct Node
45     {
46         unsigned ruleIndex{};
47         std::array<size_t, static_cast<size_t>(ParamType::MAX)>
48             paramChildrens{};
49         using ChildMap = boost::container::flat_map<
50             std::string, unsigned, std::less<>,
51             std::vector<std::pair<std::string, unsigned>>>;
52         ChildMap children;
53 
54         bool isSimpleNode() const
55         {
56             return ruleIndex == 0 &&
57                    std::all_of(std::begin(paramChildrens),
58                                std::end(paramChildrens),
59                                [](size_t x) { return x == 0U; });
60         }
61     };
62 
63     Trie() : nodes(1) {}
64 
65   private:
66     void optimizeNode(Node* node)
67     {
68         for (size_t x : node->paramChildrens)
69         {
70             if (x == 0U)
71             {
72                 continue;
73             }
74             Node* child = &nodes[x];
75             optimizeNode(child);
76         }
77         if (node->children.empty())
78         {
79             return;
80         }
81         bool mergeWithChild = true;
82         for (const Node::ChildMap::value_type& kv : node->children)
83         {
84             Node* child = &nodes[kv.second];
85             if (!child->isSimpleNode())
86             {
87                 mergeWithChild = false;
88                 break;
89             }
90         }
91         if (mergeWithChild)
92         {
93             Node::ChildMap merged;
94             for (const Node::ChildMap::value_type& kv : node->children)
95             {
96                 Node* child = &nodes[kv.second];
97                 for (const Node::ChildMap::value_type& childKv :
98                      child->children)
99                 {
100                     merged[kv.first + childKv.first] = childKv.second;
101                 }
102             }
103             node->children = std::move(merged);
104             optimizeNode(node);
105         }
106         else
107         {
108             for (const Node::ChildMap::value_type& kv : node->children)
109             {
110                 Node* child = &nodes[kv.second];
111                 optimizeNode(child);
112             }
113         }
114     }
115 
116     void optimize()
117     {
118         optimizeNode(head());
119     }
120 
121   public:
122     void validate()
123     {
124         optimize();
125     }
126 
127     void findRouteIndexes(const std::string& reqUrl,
128                           std::vector<unsigned>& routeIndexes,
129                           const Node* node = nullptr, unsigned pos = 0) const
130     {
131         if (node == nullptr)
132         {
133             node = head();
134         }
135         for (const Node::ChildMap::value_type& kv : node->children)
136         {
137             const std::string& fragment = kv.first;
138             const Node* child = &nodes[kv.second];
139             if (pos >= reqUrl.size())
140             {
141                 if (child->ruleIndex != 0 && fragment != "/")
142                 {
143                     routeIndexes.push_back(child->ruleIndex);
144                 }
145                 findRouteIndexes(reqUrl, routeIndexes, child,
146                                  static_cast<unsigned>(pos + fragment.size()));
147             }
148             else
149             {
150                 if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
151                 {
152                     findRouteIndexes(
153                         reqUrl, routeIndexes, child,
154                         static_cast<unsigned>(pos + fragment.size()));
155                 }
156             }
157         }
158     }
159 
160     std::pair<unsigned, std::vector<std::string>>
161         find(const std::string_view reqUrl, const Node* node = nullptr,
162              size_t pos = 0, std::vector<std::string>* params = nullptr) const
163     {
164         std::vector<std::string> empty;
165         if (params == nullptr)
166         {
167             params = &empty;
168         }
169 
170         unsigned found{};
171         std::vector<std::string> matchParams;
172 
173         if (node == nullptr)
174         {
175             node = head();
176         }
177         if (pos == reqUrl.size())
178         {
179             return {node->ruleIndex, *params};
180         }
181 
182         auto updateFound =
183             [&found,
184              &matchParams](std::pair<unsigned, std::vector<std::string>>& ret) {
185             if (ret.first != 0U && (found == 0U || found > ret.first))
186             {
187                 found = ret.first;
188                 matchParams = std::move(ret.second);
189             }
190         };
191 
192         if (node->paramChildrens[static_cast<size_t>(ParamType::STRING)] != 0U)
193         {
194             size_t epos = pos;
195             for (; epos < reqUrl.size(); epos++)
196             {
197                 if (reqUrl[epos] == '/')
198                 {
199                     break;
200                 }
201             }
202 
203             if (epos != pos)
204             {
205                 params->emplace_back(reqUrl.substr(pos, epos - pos));
206                 std::pair<unsigned, std::vector<std::string>> ret =
207                     find(reqUrl,
208                          &nodes[node->paramChildrens[static_cast<size_t>(
209                              ParamType::STRING)]],
210                          epos, params);
211                 updateFound(ret);
212                 params->pop_back();
213             }
214         }
215 
216         if (node->paramChildrens[static_cast<size_t>(ParamType::PATH)] != 0U)
217         {
218             size_t epos = reqUrl.size();
219 
220             if (epos != pos)
221             {
222                 params->emplace_back(reqUrl.substr(pos, epos - pos));
223                 std::pair<unsigned, std::vector<std::string>> ret =
224                     find(reqUrl,
225                          &nodes[node->paramChildrens[static_cast<size_t>(
226                              ParamType::PATH)]],
227                          epos, params);
228                 updateFound(ret);
229                 params->pop_back();
230             }
231         }
232 
233         for (const Node::ChildMap::value_type& kv : node->children)
234         {
235             const std::string& fragment = kv.first;
236             const Node* child = &nodes[kv.second];
237 
238             if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
239             {
240                 std::pair<unsigned, std::vector<std::string>> ret =
241                     find(reqUrl, child, pos + fragment.size(), params);
242                 updateFound(ret);
243             }
244         }
245 
246         return {found, matchParams};
247     }
248 
249     void add(const std::string& url, unsigned ruleIndex)
250     {
251         size_t idx = 0;
252 
253         for (unsigned i = 0; i < url.size(); i++)
254         {
255             char c = url[i];
256             if (c == '<')
257             {
258                 constexpr static std::array<
259                     std::pair<ParamType, std::string_view>, 3>
260                     paramTraits = {{
261                         {ParamType::STRING, "<str>"},
262                         {ParamType::STRING, "<string>"},
263                         {ParamType::PATH, "<path>"},
264                     }};
265 
266                 for (const std::pair<ParamType, std::string_view>& x :
267                      paramTraits)
268                 {
269                     if (url.compare(i, x.second.size(), x.second) == 0)
270                     {
271                         size_t index = static_cast<size_t>(x.first);
272                         if (nodes[idx].paramChildrens[index] == 0U)
273                         {
274                             unsigned newNodeIdx = newNode();
275                             nodes[idx].paramChildrens[index] = newNodeIdx;
276                         }
277                         idx = nodes[idx].paramChildrens[index];
278                         i += static_cast<unsigned>(x.second.size());
279                         break;
280                     }
281                 }
282 
283                 i--;
284             }
285             else
286             {
287                 std::string piece(&c, 1);
288                 if (nodes[idx].children.count(piece) == 0U)
289                 {
290                     unsigned newNodeIdx = newNode();
291                     nodes[idx].children.emplace(piece, newNodeIdx);
292                 }
293                 idx = nodes[idx].children[piece];
294             }
295         }
296         if (nodes[idx].ruleIndex != 0U)
297         {
298             throw std::runtime_error("handler already exists for " + url);
299         }
300         nodes[idx].ruleIndex = ruleIndex;
301     }
302 
303   private:
304     void debugNodePrint(Node* n, size_t level)
305     {
306         for (size_t i = 0; i < static_cast<size_t>(ParamType::MAX); i++)
307         {
308             if (n->paramChildrens[i] != 0U)
309             {
310                 BMCWEB_LOG_DEBUG << std::string(
311                     2U * level, ' ') /*<< "("<<n->paramChildrens[i]<<") "*/;
312                 switch (static_cast<ParamType>(i))
313                 {
314                     case ParamType::STRING:
315                         BMCWEB_LOG_DEBUG << "<str>";
316                         break;
317                     case ParamType::PATH:
318                         BMCWEB_LOG_DEBUG << "<path>";
319                         break;
320                     case ParamType::MAX:
321                         BMCWEB_LOG_DEBUG << "<ERROR>";
322                         break;
323                 }
324 
325                 debugNodePrint(&nodes[n->paramChildrens[i]], level + 1);
326             }
327         }
328         for (const Node::ChildMap::value_type& kv : n->children)
329         {
330             BMCWEB_LOG_DEBUG
331                 << std::string(2U * level, ' ') /*<< "(" << kv.second << ") "*/
332                 << kv.first;
333             debugNodePrint(&nodes[kv.second], level + 1);
334         }
335     }
336 
337   public:
338     void debugPrint()
339     {
340         debugNodePrint(head(), 0U);
341     }
342 
343   private:
344     const Node* head() const
345     {
346         return &nodes.front();
347     }
348 
349     Node* head()
350     {
351         return &nodes.front();
352     }
353 
354     unsigned newNode()
355     {
356         nodes.resize(nodes.size() + 1);
357         return static_cast<unsigned>(nodes.size() - 1);
358     }
359 
360     std::vector<Node> nodes;
361 };
362 
363 class Router
364 {
365   public:
366     Router() = default;
367 
368     DynamicRule& newRuleDynamic(const std::string& rule)
369     {
370         std::unique_ptr<DynamicRule> ruleObject =
371             std::make_unique<DynamicRule>(rule);
372         DynamicRule* ptr = ruleObject.get();
373         allRules.emplace_back(std::move(ruleObject));
374 
375         return *ptr;
376     }
377 
378     template <uint64_t N>
379     auto& newRuleTagged(const std::string& rule)
380     {
381         constexpr size_t numArgs = utility::numArgsFromTag(N);
382         if constexpr (numArgs == 0)
383         {
384             using RuleT = TaggedRule<>;
385             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
386             RuleT* ptr = ruleObject.get();
387             allRules.emplace_back(std::move(ruleObject));
388             return *ptr;
389         }
390         else if constexpr (numArgs == 1)
391         {
392             using RuleT = TaggedRule<std::string>;
393             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
394             RuleT* ptr = ruleObject.get();
395             allRules.emplace_back(std::move(ruleObject));
396             return *ptr;
397         }
398         else if constexpr (numArgs == 2)
399         {
400             using RuleT = TaggedRule<std::string, std::string>;
401             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
402             RuleT* ptr = ruleObject.get();
403             allRules.emplace_back(std::move(ruleObject));
404             return *ptr;
405         }
406         else if constexpr (numArgs == 3)
407         {
408             using RuleT = TaggedRule<std::string, std::string, std::string>;
409             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
410             RuleT* ptr = ruleObject.get();
411             allRules.emplace_back(std::move(ruleObject));
412             return *ptr;
413         }
414         else if constexpr (numArgs == 4)
415         {
416             using RuleT =
417                 TaggedRule<std::string, std::string, std::string, std::string>;
418             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
419             RuleT* ptr = ruleObject.get();
420             allRules.emplace_back(std::move(ruleObject));
421             return *ptr;
422         }
423         else
424         {
425             using RuleT = TaggedRule<std::string, std::string, std::string,
426                                      std::string, std::string>;
427             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
428             RuleT* ptr = ruleObject.get();
429             allRules.emplace_back(std::move(ruleObject));
430             return *ptr;
431         }
432         static_assert(numArgs < 5, "Max number of args supported is 5");
433     }
434 
435     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
436     {
437         if (ruleObject == nullptr)
438         {
439             return;
440         }
441         for (size_t method = 0, methodBit = 1; method <= methodNotAllowedIndex;
442              method++, methodBit <<= 1)
443         {
444             if ((ruleObject->methodsBitfield & methodBit) > 0U)
445             {
446                 perMethods[method].rules.emplace_back(ruleObject);
447                 perMethods[method].trie.add(
448                     rule, static_cast<unsigned>(
449                               perMethods[method].rules.size() - 1U));
450                 // directory case:
451                 //   request to `/about' url matches `/about/' rule
452                 if (rule.size() > 2 && rule.back() == '/')
453                 {
454                     perMethods[method].trie.add(
455                         rule.substr(0, rule.size() - 1),
456                         static_cast<unsigned>(perMethods[method].rules.size() -
457                                               1));
458                 }
459             }
460         }
461     }
462 
463     void validate()
464     {
465         for (std::unique_ptr<BaseRule>& rule : allRules)
466         {
467             if (rule)
468             {
469                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
470                 if (upgraded)
471                 {
472                     rule = std::move(upgraded);
473                 }
474                 rule->validate();
475                 internalAddRuleObject(rule->rule, rule.get());
476             }
477         }
478         for (PerMethod& perMethod : perMethods)
479         {
480             perMethod.trie.validate();
481         }
482     }
483 
484     struct FindRoute
485     {
486         BaseRule* rule = nullptr;
487         std::vector<std::string> params;
488     };
489 
490     struct FindRouteResponse
491     {
492         std::string allowHeader;
493         FindRoute route;
494     };
495 
496     FindRoute findRouteByIndex(std::string_view url, size_t index) const
497     {
498         FindRoute route;
499         if (index >= perMethods.size())
500         {
501             BMCWEB_LOG_CRITICAL << "Bad index???";
502             return route;
503         }
504         const PerMethod& perMethod = perMethods[index];
505         std::pair<unsigned, std::vector<std::string>> found =
506             perMethod.trie.find(url);
507         if (found.first >= perMethod.rules.size())
508         {
509             throw std::runtime_error("Trie internal structure corrupted!");
510         }
511         // Found a 404 route, switch that in
512         if (found.first != 0U)
513         {
514             route.rule = perMethod.rules[found.first];
515             route.params = std::move(found.second);
516         }
517         return route;
518     }
519 
520     FindRouteResponse findRoute(Request& req) const
521     {
522         FindRouteResponse findRoute;
523 
524         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
525         if (!verb)
526         {
527             return findRoute;
528         }
529         size_t reqMethodIndex = static_cast<size_t>(*verb);
530         // Check to see if this url exists at any verb
531         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
532              perMethodIndex++)
533         {
534             // Make sure it's safe to deference the array at that index
535             static_assert(maxVerbIndex <
536                           std::tuple_size_v<decltype(perMethods)>);
537             FindRoute route = findRouteByIndex(req.url().encoded_path(),
538                                                perMethodIndex);
539             if (route.rule == nullptr)
540             {
541                 continue;
542             }
543             if (!findRoute.allowHeader.empty())
544             {
545                 findRoute.allowHeader += ", ";
546             }
547             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
548             findRoute.allowHeader += httpVerbToString(thisVerb);
549             if (perMethodIndex == reqMethodIndex)
550             {
551                 findRoute.route = route;
552             }
553         }
554         return findRoute;
555     }
556 
557     template <typename Adaptor>
558     void handleUpgrade(Request& req,
559                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
560                        Adaptor&& adaptor)
561     {
562         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
563         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
564         {
565             asyncResp->res.result(boost::beast::http::status::not_found);
566             return;
567         }
568         PerMethod& perMethod = perMethods[static_cast<size_t>(*verb)];
569         Trie& trie = perMethod.trie;
570         std::vector<BaseRule*>& rules = perMethod.rules;
571 
572         const std::pair<unsigned, std::vector<std::string>>& found =
573             trie.find(req.url().encoded_path());
574         unsigned ruleIndex = found.first;
575         if (ruleIndex == 0U)
576         {
577             BMCWEB_LOG_DEBUG << "Cannot match rules "
578                              << req.url().encoded_path();
579             asyncResp->res.result(boost::beast::http::status::not_found);
580             return;
581         }
582 
583         if (ruleIndex >= rules.size())
584         {
585             throw std::runtime_error("Trie internal structure corrupted!");
586         }
587 
588         BaseRule& rule = *rules[ruleIndex];
589         size_t methods = rule.getMethods();
590         if ((methods & (1U << static_cast<size_t>(*verb))) == 0)
591         {
592             BMCWEB_LOG_DEBUG
593                 << "Rule found but method mismatch: "
594                 << req.url().encoded_path() << " with " << req.methodString()
595                 << "(" << static_cast<uint32_t>(*verb) << ") / " << methods;
596             asyncResp->res.result(boost::beast::http::status::not_found);
597             return;
598         }
599 
600         BMCWEB_LOG_DEBUG << "Matched rule (upgrade) '" << rule.rule << "' "
601                          << static_cast<uint32_t>(*verb) << " / " << methods;
602 
603         // TODO(ed) This should be able to use std::bind_front, but it doesn't
604         // appear to work with the std::move on adaptor.
605         validatePrivilege(
606             req, asyncResp, rule,
607             [&rule, asyncResp, adaptor(std::forward<Adaptor>(adaptor))](
608                 Request& thisReq) mutable {
609             rule.handleUpgrade(thisReq, asyncResp, std::move(adaptor));
610             });
611     }
612 
613     void handle(Request& req,
614                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
615     {
616         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
617         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
618         {
619             asyncResp->res.result(boost::beast::http::status::not_found);
620             return;
621         }
622 
623         FindRouteResponse foundRoute = findRoute(req);
624 
625         if (foundRoute.route.rule == nullptr)
626         {
627             // Couldn't find a normal route with any verb, try looking for a 404
628             // route
629             if (foundRoute.allowHeader.empty())
630             {
631                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
632                                                     notFoundIndex);
633             }
634             else
635             {
636                 // See if we have a method not allowed (405) handler
637                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
638                                                     methodNotAllowedIndex);
639             }
640         }
641 
642         // Fill in the allow header if it's valid
643         if (!foundRoute.allowHeader.empty())
644         {
645             asyncResp->res.addHeader(boost::beast::http::field::allow,
646                                      foundRoute.allowHeader);
647         }
648 
649         // If we couldn't find a real route or a 404 route, return a generic
650         // response
651         if (foundRoute.route.rule == nullptr)
652         {
653             if (foundRoute.allowHeader.empty())
654             {
655                 asyncResp->res.result(boost::beast::http::status::not_found);
656             }
657             else
658             {
659                 asyncResp->res.result(
660                     boost::beast::http::status::method_not_allowed);
661             }
662             return;
663         }
664 
665         BaseRule& rule = *foundRoute.route.rule;
666         std::vector<std::string> params = std::move(foundRoute.route.params);
667 
668         BMCWEB_LOG_DEBUG << "Matched rule '" << rule.rule << "' "
669                          << static_cast<uint32_t>(*verb) << " / "
670                          << rule.getMethods();
671 
672         if (req.session == nullptr)
673         {
674             rule.handle(req, asyncResp, params);
675             return;
676         }
677         validatePrivilege(req, asyncResp, rule,
678                           [&rule, asyncResp, params](Request& thisReq) mutable {
679             rule.handle(thisReq, asyncResp, params);
680         });
681     }
682 
683     void debugPrint()
684     {
685         for (size_t i = 0; i < perMethods.size(); i++)
686         {
687             BMCWEB_LOG_DEBUG << boost::beast::http::to_string(
688                 static_cast<boost::beast::http::verb>(i));
689             perMethods[i].trie.debugPrint();
690         }
691     }
692 
693     std::vector<const std::string*> getRoutes(const std::string& parent)
694     {
695         std::vector<const std::string*> ret;
696 
697         for (const PerMethod& pm : perMethods)
698         {
699             std::vector<unsigned> x;
700             pm.trie.findRouteIndexes(parent, x);
701             for (unsigned index : x)
702             {
703                 ret.push_back(&pm.rules[index]->rule);
704             }
705         }
706         return ret;
707     }
708 
709   private:
710     struct PerMethod
711     {
712         std::vector<BaseRule*> rules;
713         Trie trie;
714         // rule index 0 has special meaning; preallocate it to avoid
715         // duplication.
716         PerMethod() : rules(1) {}
717     };
718 
719     std::array<PerMethod, methodNotAllowedIndex + 1> perMethods;
720     std::vector<std::unique_ptr<BaseRule>> allRules;
721 };
722 } // namespace crow
723