xref: /openbmc/bmcweb/http/routing.hpp (revision 994fd86a3f6649a820f66313765e85e762ad105a)
1 #pragma once
2 
3 #include "async_resp.hpp"
4 #include "common.hpp"
5 #include "dbus_privileges.hpp"
6 #include "dbus_utility.hpp"
7 #include "error_messages.hpp"
8 #include "http_request.hpp"
9 #include "http_response.hpp"
10 #include "logging.hpp"
11 #include "privileges.hpp"
12 #include "routing/baserule.hpp"
13 #include "routing/dynamicrule.hpp"
14 #include "routing/sserule.hpp"
15 #include "routing/taggedrule.hpp"
16 #include "routing/websocketrule.hpp"
17 #include "sessions.hpp"
18 #include "utility.hpp"
19 #include "utils/dbus_utils.hpp"
20 #include "verb.hpp"
21 #include "websocket.hpp"
22 
23 #include <boost/beast/ssl/ssl_stream.hpp>
24 #include <boost/container/flat_map.hpp>
25 #include <boost/url/format.hpp>
26 #include <sdbusplus/unpack_properties.hpp>
27 
28 #include <cerrno>
29 #include <cstdint>
30 #include <cstdlib>
31 #include <limits>
32 #include <memory>
33 #include <optional>
34 #include <tuple>
35 #include <utility>
36 #include <vector>
37 
38 namespace crow
39 {
40 
41 class Trie
42 {
43   public:
44     struct Node
45     {
46         unsigned ruleIndex{};
47         std::array<size_t, static_cast<size_t>(ParamType::MAX)>
48             paramChildrens{};
49         using ChildMap = boost::container::flat_map<
50             std::string, unsigned, std::less<>,
51             std::vector<std::pair<std::string, unsigned>>>;
52         ChildMap children;
53 
54         bool isSimpleNode() const
55         {
56             return ruleIndex == 0 &&
57                    std::all_of(std::begin(paramChildrens),
58                                std::end(paramChildrens),
59                                [](size_t x) { return x == 0U; });
60         }
61     };
62 
63     Trie() : nodes(1) {}
64 
65   private:
66     void optimizeNode(Node* node)
67     {
68         for (size_t x : node->paramChildrens)
69         {
70             if (x == 0U)
71             {
72                 continue;
73             }
74             Node* child = &nodes[x];
75             optimizeNode(child);
76         }
77         if (node->children.empty())
78         {
79             return;
80         }
81         bool mergeWithChild = true;
82         for (const Node::ChildMap::value_type& kv : node->children)
83         {
84             Node* child = &nodes[kv.second];
85             if (!child->isSimpleNode())
86             {
87                 mergeWithChild = false;
88                 break;
89             }
90         }
91         if (mergeWithChild)
92         {
93             Node::ChildMap merged;
94             for (const Node::ChildMap::value_type& kv : node->children)
95             {
96                 Node* child = &nodes[kv.second];
97                 for (const Node::ChildMap::value_type& childKv :
98                      child->children)
99                 {
100                     merged[kv.first + childKv.first] = childKv.second;
101                 }
102             }
103             node->children = std::move(merged);
104             optimizeNode(node);
105         }
106         else
107         {
108             for (const Node::ChildMap::value_type& kv : node->children)
109             {
110                 Node* child = &nodes[kv.second];
111                 optimizeNode(child);
112             }
113         }
114     }
115 
116     void optimize()
117     {
118         optimizeNode(head());
119     }
120 
121   public:
122     void validate()
123     {
124         optimize();
125     }
126 
127     void findRouteIndexes(const std::string& reqUrl,
128                           std::vector<unsigned>& routeIndexes,
129                           const Node* node = nullptr, unsigned pos = 0) const
130     {
131         if (node == nullptr)
132         {
133             node = head();
134         }
135         for (const Node::ChildMap::value_type& kv : node->children)
136         {
137             const std::string& fragment = kv.first;
138             const Node* child = &nodes[kv.second];
139             if (pos >= reqUrl.size())
140             {
141                 if (child->ruleIndex != 0 && fragment != "/")
142                 {
143                     routeIndexes.push_back(child->ruleIndex);
144                 }
145                 findRouteIndexes(reqUrl, routeIndexes, child,
146                                  static_cast<unsigned>(pos + fragment.size()));
147             }
148             else
149             {
150                 if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
151                 {
152                     findRouteIndexes(
153                         reqUrl, routeIndexes, child,
154                         static_cast<unsigned>(pos + fragment.size()));
155                 }
156             }
157         }
158     }
159 
160     std::pair<unsigned, std::vector<std::string>>
161         find(const std::string_view reqUrl, const Node* node = nullptr,
162              size_t pos = 0, std::vector<std::string>* params = nullptr) const
163     {
164         std::vector<std::string> empty;
165         if (params == nullptr)
166         {
167             params = &empty;
168         }
169 
170         unsigned found{};
171         std::vector<std::string> matchParams;
172 
173         if (node == nullptr)
174         {
175             node = head();
176         }
177         if (pos == reqUrl.size())
178         {
179             return {node->ruleIndex, *params};
180         }
181 
182         auto updateFound =
183             [&found,
184              &matchParams](std::pair<unsigned, std::vector<std::string>>& ret) {
185             if (ret.first != 0U && (found == 0U || found > ret.first))
186             {
187                 found = ret.first;
188                 matchParams = std::move(ret.second);
189             }
190         };
191 
192         if (node->paramChildrens[static_cast<size_t>(ParamType::STRING)] != 0U)
193         {
194             size_t epos = pos;
195             for (; epos < reqUrl.size(); epos++)
196             {
197                 if (reqUrl[epos] == '/')
198                 {
199                     break;
200                 }
201             }
202 
203             if (epos != pos)
204             {
205                 params->emplace_back(reqUrl.substr(pos, epos - pos));
206                 std::pair<unsigned, std::vector<std::string>> ret =
207                     find(reqUrl,
208                          &nodes[node->paramChildrens[static_cast<size_t>(
209                              ParamType::STRING)]],
210                          epos, params);
211                 updateFound(ret);
212                 params->pop_back();
213             }
214         }
215 
216         if (node->paramChildrens[static_cast<size_t>(ParamType::PATH)] != 0U)
217         {
218             size_t epos = reqUrl.size();
219 
220             if (epos != pos)
221             {
222                 params->emplace_back(reqUrl.substr(pos, epos - pos));
223                 std::pair<unsigned, std::vector<std::string>> ret =
224                     find(reqUrl,
225                          &nodes[node->paramChildrens[static_cast<size_t>(
226                              ParamType::PATH)]],
227                          epos, params);
228                 updateFound(ret);
229                 params->pop_back();
230             }
231         }
232 
233         for (const Node::ChildMap::value_type& kv : node->children)
234         {
235             const std::string& fragment = kv.first;
236             const Node* child = &nodes[kv.second];
237 
238             if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
239             {
240                 std::pair<unsigned, std::vector<std::string>> ret =
241                     find(reqUrl, child, pos + fragment.size(), params);
242                 updateFound(ret);
243             }
244         }
245 
246         return {found, matchParams};
247     }
248 
249     void add(const std::string& url, unsigned ruleIndex)
250     {
251         size_t idx = 0;
252 
253         for (unsigned i = 0; i < url.size(); i++)
254         {
255             char c = url[i];
256             if (c == '<')
257             {
258                 constexpr static std::array<
259                     std::pair<ParamType, std::string_view>, 3>
260                     paramTraits = {{
261                         {ParamType::STRING, "<str>"},
262                         {ParamType::STRING, "<string>"},
263                         {ParamType::PATH, "<path>"},
264                     }};
265 
266                 for (const std::pair<ParamType, std::string_view>& x :
267                      paramTraits)
268                 {
269                     if (url.compare(i, x.second.size(), x.second) == 0)
270                     {
271                         size_t index = static_cast<size_t>(x.first);
272                         if (nodes[idx].paramChildrens[index] == 0U)
273                         {
274                             unsigned newNodeIdx = newNode();
275                             nodes[idx].paramChildrens[index] = newNodeIdx;
276                         }
277                         idx = nodes[idx].paramChildrens[index];
278                         i += static_cast<unsigned>(x.second.size());
279                         break;
280                     }
281                 }
282 
283                 i--;
284             }
285             else
286             {
287                 std::string piece(&c, 1);
288                 if (nodes[idx].children.count(piece) == 0U)
289                 {
290                     unsigned newNodeIdx = newNode();
291                     nodes[idx].children.emplace(piece, newNodeIdx);
292                 }
293                 idx = nodes[idx].children[piece];
294             }
295         }
296         if (nodes[idx].ruleIndex != 0U)
297         {
298             throw std::runtime_error("handler already exists for " + url);
299         }
300         nodes[idx].ruleIndex = ruleIndex;
301     }
302 
303   private:
304     void debugNodePrint(Node* n, size_t level)
305     {
306         for (size_t i = 0; i < static_cast<size_t>(ParamType::MAX); i++)
307         {
308             if (n->paramChildrens[i] != 0U)
309             {
310                 BMCWEB_LOG_DEBUG << std::string(
311                     2U * level, ' ') /*<< "("<<n->paramChildrens[i]<<") "*/;
312                 switch (static_cast<ParamType>(i))
313                 {
314                     case ParamType::STRING:
315                         BMCWEB_LOG_DEBUG << "<str>";
316                         break;
317                     case ParamType::PATH:
318                         BMCWEB_LOG_DEBUG << "<path>";
319                         break;
320                     case ParamType::MAX:
321                         BMCWEB_LOG_DEBUG << "<ERROR>";
322                         break;
323                 }
324 
325                 debugNodePrint(&nodes[n->paramChildrens[i]], level + 1);
326             }
327         }
328         for (const Node::ChildMap::value_type& kv : n->children)
329         {
330             BMCWEB_LOG_DEBUG
331                 << std::string(2U * level, ' ') /*<< "(" << kv.second << ") "*/
332                 << kv.first;
333             debugNodePrint(&nodes[kv.second], level + 1);
334         }
335     }
336 
337   public:
338     void debugPrint()
339     {
340         debugNodePrint(head(), 0U);
341     }
342 
343   private:
344     const Node* head() const
345     {
346         return &nodes.front();
347     }
348 
349     Node* head()
350     {
351         return &nodes.front();
352     }
353 
354     unsigned newNode()
355     {
356         nodes.resize(nodes.size() + 1);
357         return static_cast<unsigned>(nodes.size() - 1);
358     }
359 
360     std::vector<Node> nodes;
361 };
362 
363 class Router
364 {
365   public:
366     Router() = default;
367 
368     DynamicRule& newRuleDynamic(const std::string& rule)
369     {
370         std::unique_ptr<DynamicRule> ruleObject =
371             std::make_unique<DynamicRule>(rule);
372         DynamicRule* ptr = ruleObject.get();
373         allRules.emplace_back(std::move(ruleObject));
374 
375         return *ptr;
376     }
377 
378     template <uint64_t N>
379     typename black_magic::Arguments<N>::type::template rebind<TaggedRule>&
380         newRuleTagged(const std::string& rule)
381     {
382         using RuleT = typename black_magic::Arguments<N>::type::template rebind<
383             TaggedRule>;
384         std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
385         RuleT* ptr = ruleObject.get();
386         allRules.emplace_back(std::move(ruleObject));
387 
388         return *ptr;
389     }
390 
391     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
392     {
393         if (ruleObject == nullptr)
394         {
395             return;
396         }
397         for (size_t method = 0, methodBit = 1; method <= methodNotAllowedIndex;
398              method++, methodBit <<= 1)
399         {
400             if ((ruleObject->methodsBitfield & methodBit) > 0U)
401             {
402                 perMethods[method].rules.emplace_back(ruleObject);
403                 perMethods[method].trie.add(
404                     rule, static_cast<unsigned>(
405                               perMethods[method].rules.size() - 1U));
406                 // directory case:
407                 //   request to `/about' url matches `/about/' rule
408                 if (rule.size() > 2 && rule.back() == '/')
409                 {
410                     perMethods[method].trie.add(
411                         rule.substr(0, rule.size() - 1),
412                         static_cast<unsigned>(perMethods[method].rules.size() -
413                                               1));
414                 }
415             }
416         }
417     }
418 
419     void validate()
420     {
421         for (std::unique_ptr<BaseRule>& rule : allRules)
422         {
423             if (rule)
424             {
425                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
426                 if (upgraded)
427                 {
428                     rule = std::move(upgraded);
429                 }
430                 rule->validate();
431                 internalAddRuleObject(rule->rule, rule.get());
432             }
433         }
434         for (PerMethod& perMethod : perMethods)
435         {
436             perMethod.trie.validate();
437         }
438     }
439 
440     struct FindRoute
441     {
442         BaseRule* rule = nullptr;
443         std::vector<std::string> params;
444     };
445 
446     struct FindRouteResponse
447     {
448         std::string allowHeader;
449         FindRoute route;
450     };
451 
452     FindRoute findRouteByIndex(std::string_view url, size_t index) const
453     {
454         FindRoute route;
455         if (index >= perMethods.size())
456         {
457             BMCWEB_LOG_CRITICAL << "Bad index???";
458             return route;
459         }
460         const PerMethod& perMethod = perMethods[index];
461         std::pair<unsigned, std::vector<std::string>> found =
462             perMethod.trie.find(url);
463         if (found.first >= perMethod.rules.size())
464         {
465             throw std::runtime_error("Trie internal structure corrupted!");
466         }
467         // Found a 404 route, switch that in
468         if (found.first != 0U)
469         {
470             route.rule = perMethod.rules[found.first];
471             route.params = std::move(found.second);
472         }
473         return route;
474     }
475 
476     FindRouteResponse findRoute(Request& req) const
477     {
478         FindRouteResponse findRoute;
479 
480         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
481         if (!verb)
482         {
483             return findRoute;
484         }
485         size_t reqMethodIndex = static_cast<size_t>(*verb);
486         // Check to see if this url exists at any verb
487         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
488              perMethodIndex++)
489         {
490             // Make sure it's safe to deference the array at that index
491             static_assert(maxVerbIndex <
492                           std::tuple_size_v<decltype(perMethods)>);
493             FindRoute route = findRouteByIndex(req.url().encoded_path(),
494                                                perMethodIndex);
495             if (route.rule == nullptr)
496             {
497                 continue;
498             }
499             if (!findRoute.allowHeader.empty())
500             {
501                 findRoute.allowHeader += ", ";
502             }
503             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
504             findRoute.allowHeader += httpVerbToString(thisVerb);
505             if (perMethodIndex == reqMethodIndex)
506             {
507                 findRoute.route = route;
508             }
509         }
510         return findRoute;
511     }
512 
513     template <typename Adaptor>
514     void handleUpgrade(Request& req,
515                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
516                        Adaptor&& adaptor)
517     {
518         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
519         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
520         {
521             asyncResp->res.result(boost::beast::http::status::not_found);
522             return;
523         }
524         PerMethod& perMethod = perMethods[static_cast<size_t>(*verb)];
525         Trie& trie = perMethod.trie;
526         std::vector<BaseRule*>& rules = perMethod.rules;
527 
528         const std::pair<unsigned, std::vector<std::string>>& found =
529             trie.find(req.url().encoded_path());
530         unsigned ruleIndex = found.first;
531         if (ruleIndex == 0U)
532         {
533             BMCWEB_LOG_DEBUG << "Cannot match rules "
534                              << req.url().encoded_path();
535             asyncResp->res.result(boost::beast::http::status::not_found);
536             return;
537         }
538 
539         if (ruleIndex >= rules.size())
540         {
541             throw std::runtime_error("Trie internal structure corrupted!");
542         }
543 
544         BaseRule& rule = *rules[ruleIndex];
545         size_t methods = rule.getMethods();
546         if ((methods & (1U << static_cast<size_t>(*verb))) == 0)
547         {
548             BMCWEB_LOG_DEBUG
549                 << "Rule found but method mismatch: "
550                 << req.url().encoded_path() << " with " << req.methodString()
551                 << "(" << static_cast<uint32_t>(*verb) << ") / " << methods;
552             asyncResp->res.result(boost::beast::http::status::not_found);
553             return;
554         }
555 
556         BMCWEB_LOG_DEBUG << "Matched rule (upgrade) '" << rule.rule << "' "
557                          << static_cast<uint32_t>(*verb) << " / " << methods;
558 
559         // TODO(ed) This should be able to use std::bind_front, but it doesn't
560         // appear to work with the std::move on adaptor.
561         validatePrivilege(
562             req, asyncResp, rule,
563             [&rule, asyncResp, adaptor(std::forward<Adaptor>(adaptor))](
564                 Request& thisReq) mutable {
565             rule.handleUpgrade(thisReq, asyncResp, std::move(adaptor));
566             });
567     }
568 
569     void handle(Request& req,
570                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
571     {
572         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
573         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
574         {
575             asyncResp->res.result(boost::beast::http::status::not_found);
576             return;
577         }
578 
579         FindRouteResponse foundRoute = findRoute(req);
580 
581         if (foundRoute.route.rule == nullptr)
582         {
583             // Couldn't find a normal route with any verb, try looking for a 404
584             // route
585             if (foundRoute.allowHeader.empty())
586             {
587                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
588                                                     notFoundIndex);
589             }
590             else
591             {
592                 // See if we have a method not allowed (405) handler
593                 foundRoute.route = findRouteByIndex(req.url().encoded_path(),
594                                                     methodNotAllowedIndex);
595             }
596         }
597 
598         // Fill in the allow header if it's valid
599         if (!foundRoute.allowHeader.empty())
600         {
601             asyncResp->res.addHeader(boost::beast::http::field::allow,
602                                      foundRoute.allowHeader);
603         }
604 
605         // If we couldn't find a real route or a 404 route, return a generic
606         // response
607         if (foundRoute.route.rule == nullptr)
608         {
609             if (foundRoute.allowHeader.empty())
610             {
611                 asyncResp->res.result(boost::beast::http::status::not_found);
612             }
613             else
614             {
615                 asyncResp->res.result(
616                     boost::beast::http::status::method_not_allowed);
617             }
618             return;
619         }
620 
621         BaseRule& rule = *foundRoute.route.rule;
622         std::vector<std::string> params = std::move(foundRoute.route.params);
623 
624         BMCWEB_LOG_DEBUG << "Matched rule '" << rule.rule << "' "
625                          << static_cast<uint32_t>(*verb) << " / "
626                          << rule.getMethods();
627 
628         if (req.session == nullptr)
629         {
630             rule.handle(req, asyncResp, params);
631             return;
632         }
633         validatePrivilege(req, asyncResp, rule,
634                           [&rule, asyncResp, params](Request& thisReq) mutable {
635             rule.handle(thisReq, asyncResp, params);
636         });
637     }
638 
639     void debugPrint()
640     {
641         for (size_t i = 0; i < perMethods.size(); i++)
642         {
643             BMCWEB_LOG_DEBUG << boost::beast::http::to_string(
644                 static_cast<boost::beast::http::verb>(i));
645             perMethods[i].trie.debugPrint();
646         }
647     }
648 
649     std::vector<const std::string*> getRoutes(const std::string& parent)
650     {
651         std::vector<const std::string*> ret;
652 
653         for (const PerMethod& pm : perMethods)
654         {
655             std::vector<unsigned> x;
656             pm.trie.findRouteIndexes(parent, x);
657             for (unsigned index : x)
658             {
659                 ret.push_back(&pm.rules[index]->rule);
660             }
661         }
662         return ret;
663     }
664 
665   private:
666     struct PerMethod
667     {
668         std::vector<BaseRule*> rules;
669         Trie trie;
670         // rule index 0 has special meaning; preallocate it to avoid
671         // duplication.
672         PerMethod() : rules(1) {}
673     };
674 
675     std::array<PerMethod, methodNotAllowedIndex + 1> perMethods;
676     std::vector<std::unique_ptr<BaseRule>> allRules;
677 };
678 } // namespace crow
679