xref: /openbmc/bmcweb/http/routing.hpp (revision 720c9898)
1 #pragma once
2 
3 #include "async_resp.hpp"
4 #include "dbus_privileges.hpp"
5 #include "dbus_utility.hpp"
6 #include "error_messages.hpp"
7 #include "http_request.hpp"
8 #include "http_response.hpp"
9 #include "logging.hpp"
10 #include "privileges.hpp"
11 #include "routing/baserule.hpp"
12 #include "routing/dynamicrule.hpp"
13 #include "routing/sserule.hpp"
14 #include "routing/taggedrule.hpp"
15 #include "routing/websocketrule.hpp"
16 #include "sessions.hpp"
17 #include "utility.hpp"
18 #include "utils/dbus_utils.hpp"
19 #include "verb.hpp"
20 #include "websocket.hpp"
21 
22 #include <boost/container/flat_map.hpp>
23 #include <boost/container/small_vector.hpp>
24 
25 #include <algorithm>
26 #include <cerrno>
27 #include <cstdint>
28 #include <cstdlib>
29 #include <limits>
30 #include <memory>
31 #include <optional>
32 #include <tuple>
33 #include <utility>
34 #include <vector>
35 
36 namespace crow
37 {
38 
/**
 * Character-level routing trie.
 *
 * add() inserts a rule string one character per node. The tags "<str>",
 * "<string>" and "<path>" become dedicated parameter children instead of
 * literal characters. optimize() (via validate()) later collapses chains of
 * simple nodes into multi-character fragments. Rule index 0 is reserved to
 * mean "no rule", and node index 0 (the root) doubles as "no child".
 */
class Trie
{
  public:
    struct Node
    {
        // Index of the rule that terminates at this node; 0 means none.
        unsigned ruleIndex = 0U;

        // Index into Trie::nodes of the child reached through a
        // "<str>"/"<string>" tag (matches up to the next '/'); 0 means none.
        size_t stringParamChild = 0U;
        // Index into Trie::nodes of the child reached through a "<path>" tag
        // (matches the whole remaining URL); 0 means none.
        size_t pathParamChild = 0U;

        // Literal URL fragment -> child node index.
        using ChildMap = boost::container::flat_map<
            std::string, unsigned, std::less<>,
            boost::container::small_vector<std::pair<std::string, unsigned>,
                                           1>>;
        ChildMap children;

        // True when the node has no rule and no parameter children, so
        // optimizeNode() may fold it into its parent's fragment.
        bool isSimpleNode() const
        {
            return ruleIndex == 0 && stringParamChild == 0 &&
                   pathParamChild == 0;
        }
    };

    // nodes[0] is the root. newNode() only ever appends, so no child can
    // have index 0, which lets 0 act as the "no child" sentinel.
    Trie() : nodes(1) {}

  private:
    // Recursively merges simple children into their parent so that literal
    // edges carry multi-character fragments instead of single characters.
    void optimizeNode(Node& node)
    {
        if (node.stringParamChild != 0U)
        {
            optimizeNode(nodes[node.stringParamChild]);
        }
        if (node.pathParamChild != 0U)
        {
            optimizeNode(nodes[node.pathParamChild]);
        }

        if (node.children.empty())
        {
            return;
        }
        // Each pass splices the grandchildren of every simple child directly
        // under this node (concatenating fragments); repeat until a pass
        // performs no merge.
        while (true)
        {
            bool didMerge = false;
            Node::ChildMap merged;
            for (const Node::ChildMap::value_type& kv : node.children)
            {
                const Node& child = nodes[kv.second];
                if (child.isSimpleNode())
                {
                    for (const Node::ChildMap::value_type& childKv :
                         child.children)
                    {
                        merged[kv.first + childKv.first] = childKv.second;
                        didMerge = true;
                    }
                }
                else
                {
                    merged[kv.first] = kv.second;
                }
            }
            node.children = std::move(merged);
            if (!didMerge)
            {
                break;
            }
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            optimizeNode(nodes[kv.second]);
        }
    }

    void optimize()
    {
        optimizeNode(head());
    }

  public:
    // Finalizes the trie once all rules have been added.
    void validate()
    {
        optimize();
    }

    // Walks literal children matching reqUrl. Once reqUrl is fully consumed,
    // records the rule index of every descendant reached via literal
    // fragments. Parameter children are never visited. Nodes entered through
    // a fragment that is exactly "/" are not recorded, presumably to skip the
    // "/foo" vs "/foo/" directory-case duplicate that shares a rule index.
    void findRouteIndexesHelper(std::string_view reqUrl,
                                std::vector<unsigned>& routeIndexes,
                                const Node& node) const
    {
        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];
            if (reqUrl.empty())
            {
                if (child.ruleIndex != 0 && fragment != "/")
                {
                    routeIndexes.push_back(child.ruleIndex);
                }
                findRouteIndexesHelper(reqUrl, routeIndexes, child);
            }
            else
            {
                if (reqUrl.starts_with(fragment))
                {
                    findRouteIndexesHelper(reqUrl.substr(fragment.size()),
                                           routeIndexes, child);
                }
            }
        }
    }

    // Appends to routeIndexes the rule indexes found beneath reqUrl.
    void findRouteIndexes(const std::string& reqUrl,
                          std::vector<unsigned>& routeIndexes) const
    {
        findRouteIndexesHelper(reqUrl, routeIndexes, head());
    }

    struct FindResult
    {
        // Matched rule index; 0 means no match.
        unsigned ruleIndex;
        // Captured parameter values, in URL order.
        std::vector<std::string> params;
    };

  private:
    // Depth-first match of reqUrl starting at node. Alternatives are tried
    // in this order: string parameter, path parameter, literal children. The
    // first alternative yielding a non-zero rule index wins. params holds
    // the captures on the current path and is restored on backtrack.
    FindResult findHelper(const std::string_view reqUrl, const Node& node,
                          std::vector<std::string>& params) const
    {
        if (reqUrl.empty())
        {
            return {node.ruleIndex, params};
        }

        if (node.stringParamChild != 0U)
        {
            // A string parameter consumes up to (not including) the next '/'
            // and must capture at least one character.
            size_t epos = reqUrl.find('/');
            if (epos == std::string_view::npos)
            {
                epos = reqUrl.size();
            }

            if (epos != 0)
            {
                params.emplace_back(reqUrl.substr(0, epos));
                FindResult ret = findHelper(
                    reqUrl.substr(epos), nodes[node.stringParamChild], params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
                params.pop_back();
            }
        }

        if (node.pathParamChild != 0U)
        {
            // A path parameter swallows the entire remainder of the URL.
            params.emplace_back(reqUrl);
            FindResult ret = findHelper("", nodes[node.pathParamChild], params);
            if (ret.ruleIndex != 0U)
            {
                return {ret.ruleIndex, std::move(ret.params)};
            }
            params.pop_back();
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];

            if (reqUrl.starts_with(fragment))
            {
                FindResult ret = findHelper(reqUrl.substr(fragment.size()),
                                            child, params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
            }
        }

        return {0U, std::vector<std::string>()};
    }

  public:
    // Looks up reqUrl; ruleIndex is 0 when nothing matches.
    FindResult find(const std::string_view reqUrl) const
    {
        std::vector<std::string> start;
        return findHelper(reqUrl, head(), start);
    }

    /**
     * Inserts rule pattern url, terminating at ruleIndex.
     *
     * An unknown "<...>" tag is logged and the insertion is abandoned.
     *
     * @throws std::runtime_error if a rule is already registered for url.
     */
    void add(std::string_view url, unsigned ruleIndex)
    {
        size_t idx = 0;
        // Consume a copy so that `url` stays intact for the error message.
        std::string_view remaining = url;

        while (!remaining.empty())
        {
            char c = remaining[0];
            if (c == '<')
            {
                bool found = false;
                for (const std::string_view str1 :
                     {"<str>", "<string>", "<path>"})
                {
                    if (!remaining.starts_with(str1))
                    {
                        continue;
                    }
                    found = true;
                    const bool isPath = str1 == "<path>";
                    size_t childIdx = isPath ? nodes[idx].pathParamChild
                                             : nodes[idx].stringParamChild;
                    if (childIdx == 0U)
                    {
                        // newNode() may reallocate `nodes`, so no reference
                        // into it may be held across the call; re-index
                        // afterwards.
                        childIdx = newNode();
                        Node& node = nodes[idx];
                        if (isPath)
                        {
                            node.pathParamChild = childIdx;
                        }
                        else
                        {
                            node.stringParamChild = childIdx;
                        }
                    }
                    idx = childIdx;

                    remaining.remove_prefix(str1.size());
                    break;
                }
                if (found)
                {
                    continue;
                }

                BMCWEB_LOG_CRITICAL("Cant find tag for {}", remaining);
                return;
            }
            std::string piece(&c, 1);
            unsigned childIdx = 0U;
            Node::ChildMap::const_iterator it =
                nodes[idx].children.find(piece);
            if (it != nodes[idx].children.end())
            {
                childIdx = it->second;
            }
            else
            {
                // Allocate before touching nodes[idx] again (reallocation).
                childIdx = newNode();
                nodes[idx].children.emplace(std::move(piece), childIdx);
            }
            idx = childIdx;
            remaining.remove_prefix(1);
        }
        if (nodes[idx].ruleIndex != 0U)
        {
            throw std::runtime_error(
                std::format("handler already exists for {}", url));
        }
        nodes[idx].ruleIndex = ruleIndex;
    }

  private:
    // Logs the subtree under n, indenting by the cumulative fragment width.
    void debugNodePrint(Node& n, size_t level)
    {
        std::string spaces(level, ' ');
        if (n.stringParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{}<str>", spaces);
            debugNodePrint(nodes[n.stringParamChild], level + 5);
        }
        if (n.pathParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{} <path>", spaces);
            debugNodePrint(nodes[n.pathParamChild], level + 6);
        }
        for (const Node::ChildMap::value_type& kv : n.children)
        {
            BMCWEB_LOG_DEBUG("{}{}", spaces, kv.first);
            debugNodePrint(nodes[kv.second], level + kv.first.size());
        }
    }

  public:
    void debugPrint()
    {
        debugNodePrint(head(), 0U);
    }

  private:
    const Node& head() const
    {
        return nodes.front();
    }

    Node& head()
    {
        return nodes.front();
    }

    // Appends a default node and returns its index. This may reallocate
    // `nodes`, invalidating any Node reference/pointer held by the caller.
    unsigned newNode()
    {
        nodes.resize(nodes.size() + 1);
        return static_cast<unsigned>(nodes.size() - 1);
    }

    std::vector<Node> nodes;
};
338 
339 class Router
340 {
341   public:
342     Router() = default;
343 
344     DynamicRule& newRuleDynamic(const std::string& rule)
345     {
346         std::unique_ptr<DynamicRule> ruleObject =
347             std::make_unique<DynamicRule>(rule);
348         DynamicRule* ptr = ruleObject.get();
349         allRules.emplace_back(std::move(ruleObject));
350 
351         return *ptr;
352     }
353 
354     template <uint64_t NumArgs>
355     auto& newRuleTagged(const std::string& rule)
356     {
357         if constexpr (NumArgs == 0)
358         {
359             using RuleT = TaggedRule<>;
360             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
361             RuleT* ptr = ruleObject.get();
362             allRules.emplace_back(std::move(ruleObject));
363             return *ptr;
364         }
365         else if constexpr (NumArgs == 1)
366         {
367             using RuleT = TaggedRule<std::string>;
368             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
369             RuleT* ptr = ruleObject.get();
370             allRules.emplace_back(std::move(ruleObject));
371             return *ptr;
372         }
373         else if constexpr (NumArgs == 2)
374         {
375             using RuleT = TaggedRule<std::string, std::string>;
376             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
377             RuleT* ptr = ruleObject.get();
378             allRules.emplace_back(std::move(ruleObject));
379             return *ptr;
380         }
381         else if constexpr (NumArgs == 3)
382         {
383             using RuleT = TaggedRule<std::string, std::string, std::string>;
384             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
385             RuleT* ptr = ruleObject.get();
386             allRules.emplace_back(std::move(ruleObject));
387             return *ptr;
388         }
389         else if constexpr (NumArgs == 4)
390         {
391             using RuleT =
392                 TaggedRule<std::string, std::string, std::string, std::string>;
393             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
394             RuleT* ptr = ruleObject.get();
395             allRules.emplace_back(std::move(ruleObject));
396             return *ptr;
397         }
398         else
399         {
400             using RuleT = TaggedRule<std::string, std::string, std::string,
401                                      std::string, std::string>;
402             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
403             RuleT* ptr = ruleObject.get();
404             allRules.emplace_back(std::move(ruleObject));
405             return *ptr;
406         }
407         static_assert(NumArgs <= 5, "Max number of args supported is 5");
408     }
409 
410     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
411     {
412         if (ruleObject == nullptr)
413         {
414             return;
415         }
416         for (size_t method = 0, methodBit = 1; method <= methodNotAllowedIndex;
417              method++, methodBit <<= 1)
418         {
419             if ((ruleObject->methodsBitfield & methodBit) > 0U)
420             {
421                 perMethods[method].rules.emplace_back(ruleObject);
422                 perMethods[method].trie.add(
423                     rule, static_cast<unsigned>(
424                               perMethods[method].rules.size() - 1U));
425                 // directory case:
426                 //   request to `/about' url matches `/about/' rule
427                 if (rule.size() > 2 && rule.back() == '/')
428                 {
429                     perMethods[method].trie.add(
430                         rule.substr(0, rule.size() - 1),
431                         static_cast<unsigned>(perMethods[method].rules.size() -
432                                               1));
433                 }
434             }
435         }
436     }
437 
438     void validate()
439     {
440         for (std::unique_ptr<BaseRule>& rule : allRules)
441         {
442             if (rule)
443             {
444                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
445                 if (upgraded)
446                 {
447                     rule = std::move(upgraded);
448                 }
449                 rule->validate();
450                 internalAddRuleObject(rule->rule, rule.get());
451             }
452         }
453         for (PerMethod& perMethod : perMethods)
454         {
455             perMethod.trie.validate();
456         }
457     }
458 
459     struct FindRoute
460     {
461         BaseRule* rule = nullptr;
462         std::vector<std::string> params;
463     };
464 
465     struct FindRouteResponse
466     {
467         std::string allowHeader;
468         FindRoute route;
469     };
470 
471     FindRoute findRouteByIndex(std::string_view url, size_t index) const
472     {
473         FindRoute route;
474         if (index >= perMethods.size())
475         {
476             BMCWEB_LOG_CRITICAL("Bad index???");
477             return route;
478         }
479         const PerMethod& perMethod = perMethods[index];
480         Trie::FindResult found = perMethod.trie.find(url);
481         if (found.ruleIndex >= perMethod.rules.size())
482         {
483             throw std::runtime_error("Trie internal structure corrupted!");
484         }
485         // Found a 404 route, switch that in
486         if (found.ruleIndex != 0U)
487         {
488             route.rule = perMethod.rules[found.ruleIndex];
489             route.params = std::move(found.params);
490         }
491         return route;
492     }
493 
494     FindRouteResponse findRoute(const Request& req) const
495     {
496         FindRouteResponse findRoute;
497 
498         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
499         if (!verb)
500         {
501             return findRoute;
502         }
503         size_t reqMethodIndex = static_cast<size_t>(*verb);
504         // Check to see if this url exists at any verb
505         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
506              perMethodIndex++)
507         {
508             // Make sure it's safe to deference the array at that index
509             static_assert(maxVerbIndex <
510                           std::tuple_size_v<decltype(perMethods)>);
511             FindRoute route = findRouteByIndex(req.url().encoded_path(),
512                                                perMethodIndex);
513             if (route.rule == nullptr)
514             {
515                 continue;
516             }
517             if (!findRoute.allowHeader.empty())
518             {
519                 findRoute.allowHeader += ", ";
520             }
521             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
522             findRoute.allowHeader += httpVerbToString(thisVerb);
523             if (perMethodIndex == reqMethodIndex)
524             {
525                 findRoute.route = route;
526             }
527         }
528         return findRoute;
529     }
530 
531     template <typename Adaptor>
532     void handleUpgrade(const std::shared_ptr<Request>& req,
533                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
534                        Adaptor&& adaptor)
535     {
536         std::optional<HttpVerb> verb = httpVerbFromBoost(req->method());
537         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
538         {
539             asyncResp->res.result(boost::beast::http::status::not_found);
540             return;
541         }
542         PerMethod& perMethod = perMethods[static_cast<size_t>(*verb)];
543         Trie& trie = perMethod.trie;
544         std::vector<BaseRule*>& rules = perMethod.rules;
545 
546         Trie::FindResult found = trie.find(req->url().encoded_path());
547         unsigned ruleIndex = found.ruleIndex;
548         if (ruleIndex == 0U)
549         {
550             BMCWEB_LOG_DEBUG("Cannot match rules {}",
551                              req->url().encoded_path());
552             asyncResp->res.result(boost::beast::http::status::not_found);
553             return;
554         }
555 
556         if (ruleIndex >= rules.size())
557         {
558             throw std::runtime_error("Trie internal structure corrupted!");
559         }
560 
561         BaseRule& rule = *rules[ruleIndex];
562         size_t methods = rule.getMethods();
563         if ((methods & (1U << static_cast<size_t>(*verb))) == 0)
564         {
565             BMCWEB_LOG_DEBUG(
566                 "Rule found but method mismatch: {} with {}({}) / {}",
567                 req->url().encoded_path(), req->methodString(),
568                 static_cast<uint32_t>(*verb), methods);
569             asyncResp->res.result(boost::beast::http::status::not_found);
570             return;
571         }
572 
573         BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}' {} / {}", rule.rule,
574                          static_cast<uint32_t>(*verb), methods);
575 
576         // TODO(ed) This should be able to use std::bind_front, but it doesn't
577         // appear to work with the std::move on adaptor.
578         validatePrivilege(req, asyncResp, rule,
579                           [req, &rule, asyncResp,
580                            adaptor = std::forward<Adaptor>(adaptor)]() mutable {
581             rule.handleUpgrade(*req, asyncResp, std::move(adaptor));
582         });
583     }
584 
585     void handle(const std::shared_ptr<Request>& req,
586                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
587     {
588         std::optional<HttpVerb> verb = httpVerbFromBoost(req->method());
589         if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
590         {
591             asyncResp->res.result(boost::beast::http::status::not_found);
592             return;
593         }
594 
595         FindRouteResponse foundRoute = findRoute(*req);
596 
597         if (foundRoute.route.rule == nullptr)
598         {
599             // Couldn't find a normal route with any verb, try looking for a 404
600             // route
601             if (foundRoute.allowHeader.empty())
602             {
603                 foundRoute.route = findRouteByIndex(req->url().encoded_path(),
604                                                     notFoundIndex);
605             }
606             else
607             {
608                 // See if we have a method not allowed (405) handler
609                 foundRoute.route = findRouteByIndex(req->url().encoded_path(),
610                                                     methodNotAllowedIndex);
611             }
612         }
613 
614         // Fill in the allow header if it's valid
615         if (!foundRoute.allowHeader.empty())
616         {
617             asyncResp->res.addHeader(boost::beast::http::field::allow,
618                                      foundRoute.allowHeader);
619         }
620 
621         // If we couldn't find a real route or a 404 route, return a generic
622         // response
623         if (foundRoute.route.rule == nullptr)
624         {
625             if (foundRoute.allowHeader.empty())
626             {
627                 asyncResp->res.result(boost::beast::http::status::not_found);
628             }
629             else
630             {
631                 asyncResp->res.result(
632                     boost::beast::http::status::method_not_allowed);
633             }
634             return;
635         }
636 
637         BaseRule& rule = *foundRoute.route.rule;
638         std::vector<std::string> params = std::move(foundRoute.route.params);
639 
640         BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
641                          static_cast<uint32_t>(*verb), rule.getMethods());
642 
643         if (req->session == nullptr)
644         {
645             rule.handle(*req, asyncResp, params);
646             return;
647         }
648         validatePrivilege(
649             req, asyncResp, rule,
650             [req, asyncResp, &rule, params = std::move(params)]() {
651             rule.handle(*req, asyncResp, params);
652         });
653     }
654 
655     void debugPrint()
656     {
657         for (size_t i = 0; i < perMethods.size(); i++)
658         {
659             BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast<HttpVerb>(i)));
660             perMethods[i].trie.debugPrint();
661         }
662     }
663 
664     std::vector<const std::string*> getRoutes(const std::string& parent)
665     {
666         std::vector<const std::string*> ret;
667 
668         for (const PerMethod& pm : perMethods)
669         {
670             std::vector<unsigned> x;
671             pm.trie.findRouteIndexes(parent, x);
672             for (unsigned index : x)
673             {
674                 ret.push_back(&pm.rules[index]->rule);
675             }
676         }
677         return ret;
678     }
679 
680   private:
681     struct PerMethod
682     {
683         std::vector<BaseRule*> rules;
684         Trie trie;
685         // rule index 0 has special meaning; preallocate it to avoid
686         // duplication.
687         PerMethod() : rules(1) {}
688     };
689 
690     std::array<PerMethod, methodNotAllowedIndex + 1> perMethods;
691     std::vector<std::unique_ptr<BaseRule>> allRules;
692 };
693 } // namespace crow
694