xref: /openbmc/bmcweb/http/routing.hpp (revision 6c58a03e)
1 #pragma once
2 
3 #include "async_resp.hpp"
4 #include "dbus_privileges.hpp"
5 #include "dbus_utility.hpp"
6 #include "error_messages.hpp"
7 #include "http_request.hpp"
8 #include "http_response.hpp"
9 #include "logging.hpp"
10 #include "privileges.hpp"
11 #include "routing/baserule.hpp"
12 #include "routing/dynamicrule.hpp"
13 #include "routing/sserule.hpp"
14 #include "routing/taggedrule.hpp"
15 #include "routing/websocketrule.hpp"
16 #include "sessions.hpp"
17 #include "utility.hpp"
18 #include "utils/dbus_utils.hpp"
19 #include "verb.hpp"
20 #include "websocket.hpp"
21 
22 #include <boost/container/flat_map.hpp>
23 #include <boost/container/small_vector.hpp>
24 
25 #include <algorithm>
26 #include <cerrno>
27 #include <cstdint>
28 #include <cstdlib>
29 #include <limits>
30 #include <memory>
31 #include <optional>
32 #include <string_view>
33 #include <tuple>
34 #include <utility>
35 #include <vector>
36 
37 namespace crow
38 {
39 
40 class Trie
41 {
42   public:
43     struct Node
44     {
45         unsigned ruleIndex = 0U;
46 
47         size_t stringParamChild = 0U;
48         size_t pathParamChild = 0U;
49 
50         using ChildMap = boost::container::flat_map<
51             std::string, unsigned, std::less<>,
52             boost::container::small_vector<std::pair<std::string, unsigned>,
53                                            1>>;
54         ChildMap children;
55 
56         bool isSimpleNode() const
57         {
58             return ruleIndex == 0 && stringParamChild == 0 &&
59                    pathParamChild == 0;
60         }
61     };
62 
63     Trie() : nodes(1) {}
64 
65   private:
66     void optimizeNode(Node& node)
67     {
68         if (node.stringParamChild != 0U)
69         {
70             optimizeNode(nodes[node.stringParamChild]);
71         }
72         if (node.pathParamChild != 0U)
73         {
74             optimizeNode(nodes[node.pathParamChild]);
75         }
76 
77         if (node.children.empty())
78         {
79             return;
80         }
81         while (true)
82         {
83             bool didMerge = false;
84             Node::ChildMap merged;
85             for (const Node::ChildMap::value_type& kv : node.children)
86             {
87                 Node& child = nodes[kv.second];
88                 if (child.isSimpleNode())
89                 {
90                     for (const Node::ChildMap::value_type& childKv :
91                          child.children)
92                     {
93                         merged[kv.first + childKv.first] = childKv.second;
94                         didMerge = true;
95                     }
96                 }
97                 else
98                 {
99                     merged[kv.first] = kv.second;
100                 }
101             }
102             node.children = std::move(merged);
103             if (!didMerge)
104             {
105                 break;
106             }
107         }
108 
109         for (const Node::ChildMap::value_type& kv : node.children)
110         {
111             optimizeNode(nodes[kv.second]);
112         }
113     }
114 
115     void optimize()
116     {
117         optimizeNode(head());
118     }
119 
120   public:
121     void validate()
122     {
123         optimize();
124     }
125 
126     void findRouteIndexesHelper(std::string_view reqUrl,
127                                 std::vector<unsigned>& routeIndexes,
128                                 const Node& node) const
129     {
130         for (const Node::ChildMap::value_type& kv : node.children)
131         {
132             const std::string& fragment = kv.first;
133             const Node& child = nodes[kv.second];
134             if (reqUrl.empty())
135             {
136                 if (child.ruleIndex != 0 && fragment != "/")
137                 {
138                     routeIndexes.push_back(child.ruleIndex);
139                 }
140                 findRouteIndexesHelper(reqUrl, routeIndexes, child);
141             }
142             else
143             {
144                 if (reqUrl.starts_with(fragment))
145                 {
146                     findRouteIndexesHelper(reqUrl.substr(fragment.size()),
147                                            routeIndexes, child);
148                 }
149             }
150         }
151     }
152 
153     void findRouteIndexes(const std::string& reqUrl,
154                           std::vector<unsigned>& routeIndexes) const
155     {
156         findRouteIndexesHelper(reqUrl, routeIndexes, head());
157     }
158 
159     struct FindResult
160     {
161         unsigned ruleIndex;
162         std::vector<std::string> params;
163     };
164 
165   private:
166     FindResult findHelper(const std::string_view reqUrl, const Node& node,
167                           std::vector<std::string>& params) const
168     {
169         if (reqUrl.empty())
170         {
171             return {node.ruleIndex, params};
172         }
173 
174         if (node.stringParamChild != 0U)
175         {
176             size_t epos = 0;
177             for (; epos < reqUrl.size(); epos++)
178             {
179                 if (reqUrl[epos] == '/')
180                 {
181                     break;
182                 }
183             }
184 
185             if (epos != 0)
186             {
187                 params.emplace_back(reqUrl.substr(0, epos));
188                 FindResult ret = findHelper(
189                     reqUrl.substr(epos), nodes[node.stringParamChild], params);
190                 if (ret.ruleIndex != 0U)
191                 {
192                     return {ret.ruleIndex, std::move(ret.params)};
193                 }
194                 params.pop_back();
195             }
196         }
197 
198         if (node.pathParamChild != 0U)
199         {
200             params.emplace_back(reqUrl);
201             FindResult ret = findHelper("", nodes[node.pathParamChild], params);
202             if (ret.ruleIndex != 0U)
203             {
204                 return {ret.ruleIndex, std::move(ret.params)};
205             }
206             params.pop_back();
207         }
208 
209         for (const Node::ChildMap::value_type& kv : node.children)
210         {
211             const std::string& fragment = kv.first;
212             const Node& child = nodes[kv.second];
213 
214             if (reqUrl.starts_with(fragment))
215             {
216                 FindResult ret =
217                     findHelper(reqUrl.substr(fragment.size()), child, params);
218                 if (ret.ruleIndex != 0U)
219                 {
220                     return {ret.ruleIndex, std::move(ret.params)};
221                 }
222             }
223         }
224 
225         return {0U, std::vector<std::string>()};
226     }
227 
228   public:
229     FindResult find(const std::string_view reqUrl) const
230     {
231         std::vector<std::string> start;
232         return findHelper(reqUrl, head(), start);
233     }
234 
235     void add(std::string_view urlIn, unsigned ruleIndex)
236     {
237         size_t idx = 0;
238 
239         std::string_view url = urlIn;
240 
241         while (!url.empty())
242         {
243             char c = url[0];
244             if (c == '<')
245             {
246                 bool found = false;
247                 for (const std::string_view str1 :
248                      {"<str>", "<string>", "<path>"})
249                 {
250                     if (!url.starts_with(str1))
251                     {
252                         continue;
253                     }
254                     found = true;
255                     Node& node = nodes[idx];
256                     size_t* param = &node.stringParamChild;
257                     if (str1 == "<path>")
258                     {
259                         param = &node.pathParamChild;
260                     }
261                     if (*param == 0U)
262                     {
263                         *param = newNode();
264                     }
265                     idx = *param;
266 
267                     url.remove_prefix(str1.size());
268                     break;
269                 }
270                 if (found)
271                 {
272                     continue;
273                 }
274 
275                 BMCWEB_LOG_CRITICAL("Can't find tag for {}", urlIn);
276                 return;
277             }
278             std::string piece(&c, 1);
279             if (!nodes[idx].children.contains(piece))
280             {
281                 unsigned newNodeIdx = newNode();
282                 nodes[idx].children.emplace(piece, newNodeIdx);
283             }
284             idx = nodes[idx].children[piece];
285             url.remove_prefix(1);
286         }
287         Node& node = nodes[idx];
288         if (node.ruleIndex != 0U)
289         {
290             BMCWEB_LOG_CRITICAL("handler already exists for \"{}\"", urlIn);
291             throw std::runtime_error(
292                 std::format("handler already exists for \"{}\"", urlIn));
293         }
294         node.ruleIndex = ruleIndex;
295     }
296 
297   private:
298     void debugNodePrint(Node& n, size_t level)
299     {
300         std::string spaces(level, ' ');
301         if (n.stringParamChild != 0U)
302         {
303             BMCWEB_LOG_DEBUG("{}<str>", spaces);
304             debugNodePrint(nodes[n.stringParamChild], level + 5);
305         }
306         if (n.pathParamChild != 0U)
307         {
308             BMCWEB_LOG_DEBUG("{} <path>", spaces);
309             debugNodePrint(nodes[n.pathParamChild], level + 6);
310         }
311         for (const Node::ChildMap::value_type& kv : n.children)
312         {
313             BMCWEB_LOG_DEBUG("{}{}", spaces, kv.first);
314             debugNodePrint(nodes[kv.second], level + kv.first.size());
315         }
316     }
317 
318   public:
319     void debugPrint()
320     {
321         debugNodePrint(head(), 0U);
322     }
323 
324   private:
325     const Node& head() const
326     {
327         return nodes.front();
328     }
329 
330     Node& head()
331     {
332         return nodes.front();
333     }
334 
335     unsigned newNode()
336     {
337         nodes.resize(nodes.size() + 1);
338         return static_cast<unsigned>(nodes.size() - 1);
339     }
340 
341     std::vector<Node> nodes;
342 };
343 
344 class Router
345 {
346   public:
347     Router() = default;
348 
349     DynamicRule& newRuleDynamic(const std::string& rule)
350     {
351         std::unique_ptr<DynamicRule> ruleObject =
352             std::make_unique<DynamicRule>(rule);
353         DynamicRule* ptr = ruleObject.get();
354         allRules.emplace_back(std::move(ruleObject));
355 
356         return *ptr;
357     }
358 
359     template <uint64_t NumArgs>
360     auto& newRuleTagged(const std::string& rule)
361     {
362         if constexpr (NumArgs == 0)
363         {
364             using RuleT = TaggedRule<>;
365             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
366             RuleT* ptr = ruleObject.get();
367             allRules.emplace_back(std::move(ruleObject));
368             return *ptr;
369         }
370         else if constexpr (NumArgs == 1)
371         {
372             using RuleT = TaggedRule<std::string>;
373             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
374             RuleT* ptr = ruleObject.get();
375             allRules.emplace_back(std::move(ruleObject));
376             return *ptr;
377         }
378         else if constexpr (NumArgs == 2)
379         {
380             using RuleT = TaggedRule<std::string, std::string>;
381             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
382             RuleT* ptr = ruleObject.get();
383             allRules.emplace_back(std::move(ruleObject));
384             return *ptr;
385         }
386         else if constexpr (NumArgs == 3)
387         {
388             using RuleT = TaggedRule<std::string, std::string, std::string>;
389             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
390             RuleT* ptr = ruleObject.get();
391             allRules.emplace_back(std::move(ruleObject));
392             return *ptr;
393         }
394         else if constexpr (NumArgs == 4)
395         {
396             using RuleT =
397                 TaggedRule<std::string, std::string, std::string, std::string>;
398             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
399             RuleT* ptr = ruleObject.get();
400             allRules.emplace_back(std::move(ruleObject));
401             return *ptr;
402         }
403         else
404         {
405             using RuleT = TaggedRule<std::string, std::string, std::string,
406                                      std::string, std::string>;
407             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
408             RuleT* ptr = ruleObject.get();
409             allRules.emplace_back(std::move(ruleObject));
410             return *ptr;
411         }
412         static_assert(NumArgs <= 5, "Max number of args supported is 5");
413     }
414 
415     struct PerMethod
416     {
417         std::vector<BaseRule*> rules;
418         Trie trie;
419         // rule index 0 has special meaning; preallocate it to avoid
420         // duplication.
421         PerMethod() : rules(1) {}
422 
423         void internalAdd(std::string_view rule, BaseRule* ruleObject)
424         {
425             rules.emplace_back(ruleObject);
426             trie.add(rule, static_cast<unsigned>(rules.size() - 1U));
427             // directory case:
428             //   request to `/about' url matches `/about/' rule
429             if (rule.size() > 2 && rule.back() == '/')
430             {
431                 trie.add(rule.substr(0, rule.size() - 1),
432                          static_cast<unsigned>(rules.size() - 1));
433             }
434         }
435     };
436 
437     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
438     {
439         if (ruleObject == nullptr)
440         {
441             return;
442         }
443         for (size_t method = 0; method <= maxVerbIndex; method++)
444         {
445             size_t methodBit = 1 << method;
446             if ((ruleObject->methodsBitfield & methodBit) > 0U)
447             {
448                 perMethods[method].internalAdd(rule, ruleObject);
449             }
450         }
451 
452         if (ruleObject->isNotFound)
453         {
454             notFoundRoutes.internalAdd(rule, ruleObject);
455         }
456 
457         if (ruleObject->isMethodNotAllowed)
458         {
459             methodNotAllowedRoutes.internalAdd(rule, ruleObject);
460         }
461 
462         if (ruleObject->isUpgrade)
463         {
464             upgradeRoutes.internalAdd(rule, ruleObject);
465         }
466     }
467 
468     void validate()
469     {
470         for (std::unique_ptr<BaseRule>& rule : allRules)
471         {
472             if (rule)
473             {
474                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
475                 if (upgraded)
476                 {
477                     rule = std::move(upgraded);
478                 }
479                 rule->validate();
480                 internalAddRuleObject(rule->rule, rule.get());
481             }
482         }
483         for (PerMethod& perMethod : perMethods)
484         {
485             perMethod.trie.validate();
486         }
487     }
488 
489     struct FindRoute
490     {
491         BaseRule* rule = nullptr;
492         std::vector<std::string> params;
493     };
494 
495     struct FindRouteResponse
496     {
497         std::string allowHeader;
498         FindRoute route;
499     };
500 
501     static FindRoute findRouteByPerMethod(std::string_view url,
502                                           const PerMethod& perMethod)
503     {
504         FindRoute route;
505 
506         Trie::FindResult found = perMethod.trie.find(url);
507         if (found.ruleIndex >= perMethod.rules.size())
508         {
509             throw std::runtime_error("Trie internal structure corrupted!");
510         }
511         // Found a 404 route, switch that in
512         if (found.ruleIndex != 0U)
513         {
514             route.rule = perMethod.rules[found.ruleIndex];
515             route.params = std::move(found.params);
516         }
517         return route;
518     }
519 
520     FindRouteResponse findRoute(const Request& req) const
521     {
522         FindRouteResponse findRoute;
523 
524         // Check to see if this url exists at any verb
525         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
526              perMethodIndex++)
527         {
528             // Make sure it's safe to deference the array at that index
529             static_assert(
530                 maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);
531             FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
532                                                    perMethods[perMethodIndex]);
533             if (route.rule == nullptr)
534             {
535                 continue;
536             }
537             if (!findRoute.allowHeader.empty())
538             {
539                 findRoute.allowHeader += ", ";
540             }
541             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
542             findRoute.allowHeader += httpVerbToString(thisVerb);
543         }
544 
545         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
546         if (!verb)
547         {
548             return findRoute;
549         }
550         size_t reqMethodIndex = static_cast<size_t>(*verb);
551         if (reqMethodIndex >= perMethods.size())
552         {
553             return findRoute;
554         }
555 
556         FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
557                                                perMethods[reqMethodIndex]);
558         if (route.rule != nullptr)
559         {
560             findRoute.route = route;
561         }
562 
563         return findRoute;
564     }
565 
566     template <typename Adaptor>
567     void handleUpgrade(const std::shared_ptr<Request>& req,
568                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
569                        Adaptor&& adaptor)
570     {
571         PerMethod& perMethod = upgradeRoutes;
572         Trie& trie = perMethod.trie;
573         std::vector<BaseRule*>& rules = perMethod.rules;
574 
575         Trie::FindResult found = trie.find(req->url().encoded_path());
576         unsigned ruleIndex = found.ruleIndex;
577         if (ruleIndex == 0U)
578         {
579             BMCWEB_LOG_DEBUG("Cannot match rules {}",
580                              req->url().encoded_path());
581             asyncResp->res.result(boost::beast::http::status::not_found);
582             return;
583         }
584 
585         if (ruleIndex >= rules.size())
586         {
587             throw std::runtime_error("Trie internal structure corrupted!");
588         }
589 
590         BaseRule& rule = *rules[ruleIndex];
591 
592         BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}'", rule.rule);
593 
594         // TODO(ed) This should be able to use std::bind_front, but it doesn't
595         // appear to work with the std::move on adaptor.
596         validatePrivilege(
597             req, asyncResp, rule,
598             [req, &rule, asyncResp,
599              adaptor = std::forward<Adaptor>(adaptor)]() mutable {
600                 rule.handleUpgrade(*req, asyncResp, std::move(adaptor));
601             });
602     }
603 
604     void handle(const std::shared_ptr<Request>& req,
605                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
606     {
607         FindRouteResponse foundRoute = findRoute(*req);
608 
609         if (foundRoute.route.rule == nullptr)
610         {
611             // Couldn't find a normal route with any verb, try looking for a 404
612             // route
613             if (foundRoute.allowHeader.empty())
614             {
615                 foundRoute.route = findRouteByPerMethod(
616                     req->url().encoded_path(), notFoundRoutes);
617             }
618             else
619             {
620                 // See if we have a method not allowed (405) handler
621                 foundRoute.route = findRouteByPerMethod(
622                     req->url().encoded_path(), methodNotAllowedRoutes);
623             }
624         }
625 
626         // Fill in the allow header if it's valid
627         if (!foundRoute.allowHeader.empty())
628         {
629             asyncResp->res.addHeader(boost::beast::http::field::allow,
630                                      foundRoute.allowHeader);
631         }
632 
633         // If we couldn't find a real route or a 404 route, return a generic
634         // response
635         if (foundRoute.route.rule == nullptr)
636         {
637             if (foundRoute.allowHeader.empty())
638             {
639                 asyncResp->res.result(boost::beast::http::status::not_found);
640             }
641             else
642             {
643                 asyncResp->res.result(
644                     boost::beast::http::status::method_not_allowed);
645             }
646             return;
647         }
648 
649         BaseRule& rule = *foundRoute.route.rule;
650         std::vector<std::string> params = std::move(foundRoute.route.params);
651 
652         BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
653                          req->methodString(), rule.getMethods());
654 
655         if (req->session == nullptr)
656         {
657             rule.handle(*req, asyncResp, params);
658             return;
659         }
660         validatePrivilege(
661             req, asyncResp, rule,
662             [req, asyncResp, &rule, params = std::move(params)]() {
663                 rule.handle(*req, asyncResp, params);
664             });
665     }
666 
667     void debugPrint()
668     {
669         for (size_t i = 0; i < perMethods.size(); i++)
670         {
671             BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast<HttpVerb>(i)));
672             perMethods[i].trie.debugPrint();
673         }
674     }
675 
676     std::vector<const std::string*> getRoutes(const std::string& parent)
677     {
678         std::vector<const std::string*> ret;
679 
680         for (const PerMethod& pm : perMethods)
681         {
682             std::vector<unsigned> x;
683             pm.trie.findRouteIndexes(parent, x);
684             for (unsigned index : x)
685             {
686                 ret.push_back(&pm.rules[index]->rule);
687             }
688         }
689         return ret;
690     }
691 
692   private:
693     std::array<PerMethod, static_cast<size_t>(HttpVerb::Max)> perMethods;
694 
695     PerMethod notFoundRoutes;
696     PerMethod upgradeRoutes;
697     PerMethod methodNotAllowedRoutes;
698 
699     std::vector<std::unique_ptr<BaseRule>> allRules;
700 };
701 } // namespace crow
702