xref: /openbmc/bmcweb/http/routing.hpp (revision 1b8b02a4)
1 #pragma once
2 
3 #include "async_resp.hpp"
4 #include "common.hpp"
5 #include "dbus_privileges.hpp"
6 #include "dbus_utility.hpp"
7 #include "error_messages.hpp"
8 #include "http_request.hpp"
9 #include "http_response.hpp"
10 #include "logging.hpp"
11 #include "privileges.hpp"
12 #include "routing/baserule.hpp"
13 #include "routing/dynamicrule.hpp"
14 #include "routing/sserule.hpp"
15 #include "routing/taggedrule.hpp"
16 #include "routing/websocketrule.hpp"
17 #include "sessions.hpp"
18 #include "utility.hpp"
19 #include "utils/dbus_utils.hpp"
20 #include "verb.hpp"
21 #include "websocket.hpp"
22 
23 #include <boost/beast/ssl/ssl_stream.hpp>
24 #include <boost/container/flat_map.hpp>
25 #include <boost/url/format.hpp>
26 #include <sdbusplus/unpack_properties.hpp>
27 
28 #include <cerrno>
29 #include <cstdint>
30 #include <cstdlib>
31 #include <limits>
32 #include <memory>
33 #include <optional>
34 #include <tuple>
35 #include <utility>
36 #include <vector>
37 
38 namespace crow
39 {
40 
41 class Trie
42 {
43   public:
44     struct Node
45     {
46         unsigned ruleIndex{};
47         std::array<size_t, static_cast<size_t>(ParamType::MAX)>
48             paramChildrens{};
49         using ChildMap = boost::container::flat_map<
50             std::string, unsigned, std::less<>,
51             std::vector<std::pair<std::string, unsigned>>>;
52         ChildMap children;
53 
54         bool isSimpleNode() const
55         {
56             return ruleIndex == 0 &&
57                    std::all_of(std::begin(paramChildrens),
58                                std::end(paramChildrens),
59                                [](size_t x) { return x == 0U; });
60         }
61     };
62 
63     Trie() : nodes(1) {}
64 
65   private:
66     void optimizeNode(Node* node)
67     {
68         for (size_t x : node->paramChildrens)
69         {
70             if (x == 0U)
71             {
72                 continue;
73             }
74             Node* child = &nodes[x];
75             optimizeNode(child);
76         }
77         if (node->children.empty())
78         {
79             return;
80         }
81         bool mergeWithChild = true;
82         for (const Node::ChildMap::value_type& kv : node->children)
83         {
84             Node* child = &nodes[kv.second];
85             if (!child->isSimpleNode())
86             {
87                 mergeWithChild = false;
88                 break;
89             }
90         }
91         if (mergeWithChild)
92         {
93             Node::ChildMap merged;
94             for (const Node::ChildMap::value_type& kv : node->children)
95             {
96                 Node* child = &nodes[kv.second];
97                 for (const Node::ChildMap::value_type& childKv :
98                      child->children)
99                 {
100                     merged[kv.first + childKv.first] = childKv.second;
101                 }
102             }
103             node->children = std::move(merged);
104             optimizeNode(node);
105         }
106         else
107         {
108             for (const Node::ChildMap::value_type& kv : node->children)
109             {
110                 Node* child = &nodes[kv.second];
111                 optimizeNode(child);
112             }
113         }
114     }
115 
116     void optimize()
117     {
118         optimizeNode(head());
119     }
120 
121   public:
122     void validate()
123     {
124         optimize();
125     }
126 
127     void findRouteIndexes(const std::string& reqUrl,
128                           std::vector<unsigned>& routeIndexes,
129                           const Node* node = nullptr, unsigned pos = 0) const
130     {
131         if (node == nullptr)
132         {
133             node = head();
134         }
135         for (const Node::ChildMap::value_type& kv : node->children)
136         {
137             const std::string& fragment = kv.first;
138             const Node* child = &nodes[kv.second];
139             if (pos >= reqUrl.size())
140             {
141                 if (child->ruleIndex != 0 && fragment != "/")
142                 {
143                     routeIndexes.push_back(child->ruleIndex);
144                 }
145                 findRouteIndexes(reqUrl, routeIndexes, child,
146                                  static_cast<unsigned>(pos + fragment.size()));
147             }
148             else
149             {
150                 if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
151                 {
152                     findRouteIndexes(
153                         reqUrl, routeIndexes, child,
154                         static_cast<unsigned>(pos + fragment.size()));
155                 }
156             }
157         }
158     }
159 
160     std::pair<unsigned, std::vector<std::string>>
161         find(const std::string_view reqUrl, const Node* node = nullptr,
162              size_t pos = 0, std::vector<std::string>* params = nullptr) const
163     {
164         std::vector<std::string> empty;
165         if (params == nullptr)
166         {
167             params = &empty;
168         }
169 
170         unsigned found{};
171         std::vector<std::string> matchParams;
172 
173         if (node == nullptr)
174         {
175             node = head();
176         }
177         if (pos == reqUrl.size())
178         {
179             return {node->ruleIndex, *params};
180         }
181 
182         auto updateFound =
183             [&found,
184              &matchParams](std::pair<unsigned, std::vector<std::string>>& ret) {
185             if (ret.first != 0U && (found == 0U || found > ret.first))
186             {
187                 found = ret.first;
188                 matchParams = std::move(ret.second);
189             }
190         };
191 
192         if (node->paramChildrens[static_cast<size_t>(ParamType::STRING)] != 0U)
193         {
194             size_t epos = pos;
195             for (; epos < reqUrl.size(); epos++)
196             {
197                 if (reqUrl[epos] == '/')
198                 {
199                     break;
200                 }
201             }
202 
203             if (epos != pos)
204             {
205                 params->emplace_back(reqUrl.substr(pos, epos - pos));
206                 std::pair<unsigned, std::vector<std::string>> ret =
207                     find(reqUrl,
208                          &nodes[node->paramChildrens[static_cast<size_t>(
209                              ParamType::STRING)]],
210                          epos, params);
211                 updateFound(ret);
212                 params->pop_back();
213             }
214         }
215 
216         if (node->paramChildrens[static_cast<size_t>(ParamType::PATH)] != 0U)
217         {
218             size_t epos = reqUrl.size();
219 
220             if (epos != pos)
221             {
222                 params->emplace_back(reqUrl.substr(pos, epos - pos));
223                 std::pair<unsigned, std::vector<std::string>> ret =
224                     find(reqUrl,
225                          &nodes[node->paramChildrens[static_cast<size_t>(
226                              ParamType::PATH)]],
227                          epos, params);
228                 updateFound(ret);
229                 params->pop_back();
230             }
231         }
232 
233         for (const Node::ChildMap::value_type& kv : node->children)
234         {
235             const std::string& fragment = kv.first;
236             const Node* child = &nodes[kv.second];
237 
238             if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
239             {
240                 std::pair<unsigned, std::vector<std::string>> ret =
241                     find(reqUrl, child, pos + fragment.size(), params);
242                 updateFound(ret);
243             }
244         }
245 
246         return {found, matchParams};
247     }
248 
249     void add(const std::string& url, unsigned ruleIndex)
250     {
251         size_t idx = 0;
252 
253         for (unsigned i = 0; i < url.size(); i++)
254         {
255             char c = url[i];
256             if (c == '<')
257             {
258                 constexpr static std::array<
259                     std::pair<ParamType, std::string_view>, 3>
260                     paramTraits = {{
261                         {ParamType::STRING, "<str>"},
262                         {ParamType::STRING, "<string>"},
263                         {ParamType::PATH, "<path>"},
264                     }};
265 
266                 for (const std::pair<ParamType, std::string_view>& x :
267                      paramTraits)
268                 {
269                     if (url.compare(i, x.second.size(), x.second) == 0)
270                     {
271                         size_t index = static_cast<size_t>(x.first);
272                         if (nodes[idx].paramChildrens[index] == 0U)
273                         {
274                             unsigned newNodeIdx = newNode();
275                             nodes[idx].paramChildrens[index] = newNodeIdx;
276                         }
277                         idx = nodes[idx].paramChildrens[index];
278                         i += static_cast<unsigned>(x.second.size());
279                         break;
280                     }
281                 }
282 
283                 i--;
284             }
285             else
286             {
287                 std::string piece(&c, 1);
288                 if (nodes[idx].children.count(piece) == 0U)
289                 {
290                     unsigned newNodeIdx = newNode();
291                     nodes[idx].children.emplace(piece, newNodeIdx);
292                 }
293                 idx = nodes[idx].children[piece];
294             }
295         }
296         if (nodes[idx].ruleIndex != 0U)
297         {
298             throw std::runtime_error("handler already exists for " + url);
299         }
300         nodes[idx].ruleIndex = ruleIndex;
301     }
302 
303   private:
304     void debugNodePrint(Node* n, size_t level)
305     {
306         for (size_t i = 0; i < static_cast<size_t>(ParamType::MAX); i++)
307         {
308             if (n->paramChildrens[i] != 0U)
309             {
310                 BMCWEB_LOG_DEBUG(
311                     "{}({}{}",
312                     std::string(2U * level,
313                                 ' ') /*, n->paramChildrens[i], ") "*/);
314                 switch (static_cast<ParamType>(i))
315                 {
316                     case ParamType::STRING:
317                         BMCWEB_LOG_DEBUG("<str>");
318                         break;
319                     case ParamType::PATH:
320                         BMCWEB_LOG_DEBUG("<path>");
321                         break;
322                     case ParamType::MAX:
323                         BMCWEB_LOG_DEBUG("<ERROR>");
324                         break;
325                 }
326 
327                 debugNodePrint(&nodes[n->paramChildrens[i]], level + 1);
328             }
329         }
330         for (const Node::ChildMap::value_type& kv : n->children)
331         {
332             BMCWEB_LOG_DEBUG("{}({}{}{}",
333                              std::string(2U * level, ' ') /*, kv.second, ") "*/,
334                              kv.first);
335             debugNodePrint(&nodes[kv.second], level + 1);
336         }
337     }
338 
339   public:
340     void debugPrint()
341     {
342         debugNodePrint(head(), 0U);
343     }
344 
345   private:
346     const Node* head() const
347     {
348         return &nodes.front();
349     }
350 
351     Node* head()
352     {
353         return &nodes.front();
354     }
355 
356     unsigned newNode()
357     {
358         nodes.resize(nodes.size() + 1);
359         return static_cast<unsigned>(nodes.size() - 1);
360     }
361 
362     std::vector<Node> nodes;
363 };
364 
365 class Router
366 {
367   public:
368     Router() = default;
369 
370     DynamicRule& newRuleDynamic(const std::string& rule)
371     {
372         std::unique_ptr<DynamicRule> ruleObject =
373             std::make_unique<DynamicRule>(rule);
374         DynamicRule* ptr = ruleObject.get();
375         allRules.emplace_back(std::move(ruleObject));
376 
377         return *ptr;
378     }
379 
380     template <uint64_t N>
381     auto& newRuleTagged(const std::string& rule)
382     {
383         constexpr size_t numArgs = utility::numArgsFromTag(N);
384         if constexpr (numArgs == 0)
385         {
386             using RuleT = TaggedRule<>;
387             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
388             RuleT* ptr = ruleObject.get();
389             allRules.emplace_back(std::move(ruleObject));
390             return *ptr;
391         }
392         else if constexpr (numArgs == 1)
393         {
394             using RuleT = TaggedRule<std::string>;
395             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
396             RuleT* ptr = ruleObject.get();
397             allRules.emplace_back(std::move(ruleObject));
398             return *ptr;
399         }
400         else if constexpr (numArgs == 2)
401         {
402             using RuleT = TaggedRule<std::string, std::string>;
403             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
404             RuleT* ptr = ruleObject.get();
405             allRules.emplace_back(std::move(ruleObject));
406             return *ptr;
407         }
408         else if constexpr (numArgs == 3)
409         {
410             using RuleT = TaggedRule<std::string, std::string, std::string>;
411             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
412             RuleT* ptr = ruleObject.get();
413             allRules.emplace_back(std::move(ruleObject));
414             return *ptr;
415         }
416         else if constexpr (numArgs == 4)
417         {
418             using RuleT =
419                 TaggedRule<std::string, std::string, std::string, std::string>;
420             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
421             RuleT* ptr = ruleObject.get();
422             allRules.emplace_back(std::move(ruleObject));
423             return *ptr;
424         }
425         else
426         {
427             using RuleT = TaggedRule<std::string, std::string, std::string,
428                                      std::string, std::string>;
429             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
430             RuleT* ptr = ruleObject.get();
431             allRules.emplace_back(std::move(ruleObject));
432             return *ptr;
433         }
434         static_assert(numArgs < 5, "Max number of args supported is 5");
435     }
436 
437     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
438     {
439         if (ruleObject == nullptr)
440         {
441             return;
442         }
443         for (size_t method = 0, methodBit = 1; method <= methodNotAllowedIndex;
444              method++, methodBit <<= 1)
445         {
446             if ((ruleObject->methodsBitfield & methodBit) > 0U)
447             {
448                 perMethods[method].rules.emplace_back(ruleObject);
449                 perMethods[method].trie.add(
450                     rule, static_cast<unsigned>(
451                               perMethods[method].rules.size() - 1U));
452                 // directory case:
453                 //   request to `/about' url matches `/about/' rule
454                 if (rule.size() > 2 && rule.back() == '/')
455                 {
456                     perMethods[method].trie.add(
457                         rule.substr(0, rule.size() - 1),
458                         static_cast<unsigned>(perMethods[method].rules.size() -
459                                               1));
460                 }
461             }
462         }
463     }
464 
465     void validate()
466     {
467         for (std::unique_ptr<BaseRule>& rule : allRules)
468         {
469             if (rule)
470             {
471                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
472                 if (upgraded)
473                 {
474                     rule = std::move(upgraded);
475                 }
476                 rule->validate();
477                 internalAddRuleObject(rule->rule, rule.get());
478             }
479         }
480         for (PerMethod& perMethod : perMethods)
481         {
482             perMethod.trie.validate();
483         }
484     }
485 
486     struct FindRoute
487     {
488         BaseRule* rule = nullptr;
489         std::vector<std::string> params;
490     };
491 
492     struct FindRouteResponse
493     {
494         std::string allowHeader;
495         FindRoute route;
496     };
497 
498     FindRoute findRouteByIndex(std::string_view url, size_t index) const
499     {
500         FindRoute route;
501         if (index >= perMethods.size())
502         {
503             BMCWEB_LOG_CRITICAL("Bad index???");
504             return route;
505         }
506         const PerMethod& perMethod = perMethods[index];
507         std::pair<unsigned, std::vector<std::string>> found =
508             perMethod.trie.find(url);
509         if (found.first >= perMethod.rules.size())
510         {
511             throw std::runtime_error("Trie internal structure corrupted!");
512         }
513         // Found a 404 route, switch that in
514         if (found.first != 0U)
515         {
516             route.rule = perMethod.rules[found.first];
517             route.params = std::move(found.second);
518         }
519         return route;
520     }
521 
522     FindRouteResponse findRoute(Request& req) const
523     {
524         FindRouteResponse findRoute;
525 
526         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
527         if (!verb)
528         {
529             return findRoute;
530         }
531         size_t reqMethodIndex = static_cast<size_t>(*verb);
532         // Check to see if this url exists at any verb
533         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
534              perMethodIndex++)
535         {
536             // Make sure it's safe to deference the array at that index
537             static_assert(maxVerbIndex <
538                           std::tuple_size_v<decltype(perMethods)>);
539             FindRoute route = findRouteByIndex(req.url().encoded_path(),
540                                                perMethodIndex);
541             if (route.rule == nullptr)
542             {
543                 continue;
544             }
545             if (!findRoute.allowHeader.empty())
546             {
547                 findRoute.allowHeader += ", ";
548             }
549             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
550             findRoute.allowHeader += httpVerbToString(thisVerb);
551             if (perMethodIndex == reqMethodIndex)
552             {
553                 findRoute.route = route;
554             }
555         }
556         return findRoute;
557     }
558 
559     template <typename Adaptor>
560     void handleUpgrade(Request& req,
561                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
562                        Adaptor&& adaptor)
563     {
564         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
565         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
566         {
567             asyncResp->res.result(boost::beast::http::status::not_found);
568             return;
569         }
570         PerMethod& perMethod = perMethods[static_cast<size_t>(*verb)];
571         Trie& trie = perMethod.trie;
572         std::vector<BaseRule*>& rules = perMethod.rules;
573 
574         const std::pair<unsigned, std::vector<std::string>>& found =
575             trie.find(req.url().encoded_path());
576         unsigned ruleIndex = found.first;
577         if (ruleIndex == 0U)
578         {
579             BMCWEB_LOG_DEBUG("Cannot match rules {}", req.url().encoded_path());
580             asyncResp->res.result(boost::beast::http::status::not_found);
581             return;
582         }
583 
584         if (ruleIndex >= rules.size())
585         {
586             throw std::runtime_error("Trie internal structure corrupted!");
587         }
588 
589         BaseRule& rule = *rules[ruleIndex];
590         size_t methods = rule.getMethods();
591         if ((methods & (1U << static_cast<size_t>(*verb))) == 0)
592         {
593             BMCWEB_LOG_DEBUG(
594                 "Rule found but method mismatch: {} with {}({}) / {}",
595                 req.url().encoded_path(), req.methodString(),
596                 static_cast<uint32_t>(*verb), methods);
597             asyncResp->res.result(boost::beast::http::status::not_found);
598             return;
599         }
600 
601         BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}' {} / {}", rule.rule,
602                          static_cast<uint32_t>(*verb), methods);
603 
604         // TODO(ed) This should be able to use std::bind_front, but it doesn't
605         // appear to work with the std::move on adaptor.
606         validatePrivilege(
607             req, asyncResp, rule,
608             [&rule, asyncResp, adaptor(std::forward<Adaptor>(adaptor))](
609                 Request& thisReq) mutable {
610             rule.handleUpgrade(thisReq, asyncResp, std::move(adaptor));
611             });
612     }
613 
614     void handle(Request& req,
615                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
616     {
617         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
618         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
619         {
620             asyncResp->res.result(boost::beast::http::status::not_found);
621             return;
622         }
623 
624         FindRouteResponse foundRoute = findRoute(req);
625 
626         if (foundRoute.route.rule == nullptr)
627         {
628             // Couldn't find a normal route with any verb, try looking for a 404
629             // route
630             if (foundRoute.allowHeader.empty())
631             {
632                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
633                                                     notFoundIndex);
634             }
635             else
636             {
637                 // See if we have a method not allowed (405) handler
638                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
639                                                     methodNotAllowedIndex);
640             }
641         }
642 
643         // Fill in the allow header if it's valid
644         if (!foundRoute.allowHeader.empty())
645         {
646             asyncResp->res.addHeader(boost::beast::http::field::allow,
647                                      foundRoute.allowHeader);
648         }
649 
650         // If we couldn't find a real route or a 404 route, return a generic
651         // response
652         if (foundRoute.route.rule == nullptr)
653         {
654             if (foundRoute.allowHeader.empty())
655             {
656                 asyncResp->res.result(boost::beast::http::status::not_found);
657             }
658             else
659             {
660                 asyncResp->res.result(
661                     boost::beast::http::status::method_not_allowed);
662             }
663             return;
664         }
665 
666         BaseRule& rule = *foundRoute.route.rule;
667         std::vector<std::string> params = std::move(foundRoute.route.params);
668 
669         BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
670                          static_cast<uint32_t>(*verb), rule.getMethods());
671 
672         if (req.session == nullptr)
673         {
674             rule.handle(req, asyncResp, params);
675             return;
676         }
677         validatePrivilege(req, asyncResp, rule,
678                           [&rule, asyncResp, params](Request& thisReq) mutable {
679             rule.handle(thisReq, asyncResp, params);
680         });
681     }
682 
683     void debugPrint()
684     {
685         for (size_t i = 0; i < perMethods.size(); i++)
686         {
687             BMCWEB_LOG_DEBUG("{}",
688                              boost::beast::http::to_string(
689                                  static_cast<boost::beast::http::verb>(i)));
690             perMethods[i].trie.debugPrint();
691         }
692     }
693 
694     std::vector<const std::string*> getRoutes(const std::string& parent)
695     {
696         std::vector<const std::string*> ret;
697 
698         for (const PerMethod& pm : perMethods)
699         {
700             std::vector<unsigned> x;
701             pm.trie.findRouteIndexes(parent, x);
702             for (unsigned index : x)
703             {
704                 ret.push_back(&pm.rules[index]->rule);
705             }
706         }
707         return ret;
708     }
709 
710   private:
711     struct PerMethod
712     {
713         std::vector<BaseRule*> rules;
714         Trie trie;
715         // rule index 0 has special meaning; preallocate it to avoid
716         // duplication.
717         PerMethod() : rules(1) {}
718     };
719 
720     std::array<PerMethod, methodNotAllowedIndex + 1> perMethods;
721     std::vector<std::unique_ptr<BaseRule>> allRules;
722 };
723 } // namespace crow
724