xref: /openbmc/bmcweb/http/routing.hpp (revision d78572018fc2022091ff8b8eb5a7fef2172ba3d6)
1 // SPDX-License-Identifier: Apache-2.0
2 // SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 #pragma once
4 
5 #include "async_resp.hpp"
6 #include "dbus_privileges.hpp"
7 #include "http_request.hpp"
8 #include "http_response.hpp"
9 #include "logging.hpp"
10 #include "routing/baserule.hpp"
11 #include "routing/dynamicrule.hpp"
12 #include "routing/taggedrule.hpp"
13 #include "verb.hpp"
14 
15 #include <boost/beast/http/field.hpp>
16 #include <boost/beast/http/status.hpp>
17 #include <boost/container/flat_map.hpp>
18 #include <boost/container/small_vector.hpp>
19 
20 #include <algorithm>
21 #include <array>
22 #include <cerrno>
23 #include <cstdint>
24 #include <cstdlib>
25 #include <format>
26 #include <functional>
27 #include <memory>
28 #include <optional>
29 #include <stdexcept>
30 #include <string>
31 #include <string_view>
32 #include <tuple>
33 #include <utility>
34 #include <vector>
35 
36 namespace crow
37 {
38 
class Trie
{
  public:
    // One node of the routing trie. Nodes are stored in Trie::nodes and refer
    // to each other by index. Index 0 is the root, so a stored child index of
    // 0 means "no such child".
    struct Node
    {
        // Index of the rule that terminates at this node; 0 means no rule
        // ends here (index 0 of the caller's rule table is reserved).
        unsigned ruleIndex = 0U;

        // Child reached through a "<str>" / "<string>" parameter, which
        // matches a non-empty run of characters up to (not including) the
        // next '/'. 0 if absent.
        size_t stringParamChild = 0U;
        // Child reached through a "<path>" parameter, which consumes the
        // whole remaining URL. 0 if absent.
        size_t pathParamChild = 0U;

        // Literal children, keyed by the URL fragment leading to them.
        using ChildMap = boost::container::flat_map<
            std::string, unsigned, std::less<>,
            boost::container::small_vector<std::pair<std::string, unsigned>,
                                           1>>;
        ChildMap children;

        // A node is "simple" when it carries no rule and no parameter
        // children; optimize() folds such nodes into their parent.
        bool isSimpleNode() const
        {
            return ruleIndex == 0 && stringParamChild == 0 &&
                   pathParamChild == 0;
        }
    };

    // The trie always contains the root node at index 0.
    Trie() : nodes(1) {}

  private:
    // Recursively compresses the subtree rooted at `node`. Each simple
    // literal child is replaced by that child's own children, with the
    // fragments concatenated. This repeats until no merge happens, then
    // recurses into every remaining child.
    // `nodes` is never resized here, so Node references remain valid.
    void optimizeNode(Node& node)
    {
        if (node.stringParamChild != 0U)
        {
            optimizeNode(nodes[node.stringParamChild]);
        }
        if (node.pathParamChild != 0U)
        {
            optimizeNode(nodes[node.pathParamChild]);
        }

        if (node.children.empty())
        {
            return;
        }
        while (true)
        {
            bool didMerge = false;
            Node::ChildMap merged;
            for (const Node::ChildMap::value_type& kv : node.children)
            {
                Node& child = nodes[kv.second];
                if (child.isSimpleNode())
                {
                    // Skip over the simple child: its children become ours.
                    for (const Node::ChildMap::value_type& childKv :
                         child.children)
                    {
                        merged[kv.first + childKv.first] = childKv.second;
                        didMerge = true;
                    }
                }
                else
                {
                    merged[kv.first] = kv.second;
                }
            }
            node.children = std::move(merged);
            if (!didMerge)
            {
                break;
            }
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            optimizeNode(nodes[kv.second]);
        }
    }

    // Compresses the whole trie starting from the root.
    void optimize()
    {
        optimizeNode(head());
    }

  public:
    // Finalizes the trie after all routes are added; currently this only
    // runs optimize().
    void validate()
    {
        optimize();
    }

    // Appends to `routeIndexes` the rule indexes of literal descendants of
    // the node reached by consuming `reqUrl` from `node`. A rule whose last
    // fragment is exactly "/" is skipped. Parameter children are not
    // followed.
    void findRouteIndexesHelper(std::string_view reqUrl,
                                std::vector<unsigned>& routeIndexes,
                                const Node& node) const
    {
        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];
            if (reqUrl.empty())
            {
                if (child.ruleIndex != 0 && fragment != "/")
                {
                    routeIndexes.push_back(child.ruleIndex);
                }
                findRouteIndexesHelper(reqUrl, routeIndexes, child);
            }
            else
            {
                if (reqUrl.starts_with(fragment))
                {
                    findRouteIndexesHelper(reqUrl.substr(fragment.size()),
                                           routeIndexes, child);
                }
            }
        }
    }

    // Collects the rule indexes of all literal routes below `reqUrl`.
    void findRouteIndexes(const std::string& reqUrl,
                          std::vector<unsigned>& routeIndexes) const
    {
        findRouteIndexesHelper(reqUrl, routeIndexes, head());
    }

    // Result of a lookup: ruleIndex is 0 when nothing matched; otherwise
    // params holds the captured parameter values in URL order.
    struct FindResult
    {
        unsigned ruleIndex;
        std::vector<std::string> params;
    };

  private:
    // Depth-first match of `reqUrl` starting at `node`. Precedence order:
    //   1. the string-parameter child;
    //   2. the path-parameter child;
    //   3. literal children.
    // `params` is shared scratch space; entries pushed on a failed branch
    // are popped again before trying the next alternative.
    FindResult findHelper(const std::string_view reqUrl, const Node& node,
                          std::vector<std::string>& params) const
    {
        if (reqUrl.empty())
        {
            return {node.ruleIndex, params};
        }

        if (node.stringParamChild != 0U)
        {
            // A string parameter spans up to the next '/' (or the end).
            size_t epos = reqUrl.find('/');
            if (epos == std::string_view::npos)
            {
                epos = reqUrl.size();
            }

            // Empty segments never match a string parameter.
            if (epos != 0)
            {
                params.emplace_back(reqUrl.substr(0, epos));
                FindResult ret = findHelper(
                    reqUrl.substr(epos), nodes[node.stringParamChild], params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
                params.pop_back();
            }
        }

        if (node.pathParamChild != 0U)
        {
            // A path parameter swallows everything that is left.
            params.emplace_back(reqUrl);
            FindResult ret = findHelper("", nodes[node.pathParamChild], params);
            if (ret.ruleIndex != 0U)
            {
                return {ret.ruleIndex, std::move(ret.params)};
            }
            params.pop_back();
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];

            if (reqUrl.starts_with(fragment))
            {
                FindResult ret =
                    findHelper(reqUrl.substr(fragment.size()), child, params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
            }
        }

        return {0U, std::vector<std::string>()};
    }

  public:
    // Looks up `reqUrl`; see FindResult for the meaning of the result.
    FindResult find(const std::string_view reqUrl) const
    {
        std::vector<std::string> start;
        return findHelper(reqUrl, head(), start);
    }

    // Inserts route pattern `urlIn` so that it resolves to `ruleIndex`.
    // "<str>", "<string>" and "<path>" introduce parameters; every other
    // character is matched literally. An unknown "<...>" tag is logged and
    // the route is not registered.
    // Throws std::runtime_error if a rule is already registered for this
    // exact pattern.
    void add(std::string_view urlIn, unsigned ruleIndex)
    {
        size_t idx = 0;

        std::string_view url = urlIn;

        while (!url.empty())
        {
            char c = url[0];
            if (c == '<')
            {
                bool found = false;
                for (const std::string_view str1 :
                     {"<str>", "<string>", "<path>"})
                {
                    if (!url.starts_with(str1))
                    {
                        continue;
                    }
                    found = true;
                    size_t Node::*param = &Node::stringParamChild;
                    if (str1 == "<path>")
                    {
                        param = &Node::pathParamChild;
                    }
                    if (nodes[idx].*param == 0U)
                    {
                        // newNode() grows `nodes` and may reallocate it,
                        // invalidating any Node reference or pointer taken
                        // earlier. Allocate first, then re-index nodes[idx].
                        const size_t child = newNode();
                        nodes[idx].*param = child;
                    }
                    idx = nodes[idx].*param;

                    url.remove_prefix(str1.size());
                    break;
                }
                if (found)
                {
                    continue;
                }

                BMCWEB_LOG_CRITICAL("Can't find tag for {}", urlIn);
                return;
            }
            std::string piece(1, c);
            unsigned childIdx = 0U;
            Node::ChildMap::const_iterator existing =
                nodes[idx].children.find(piece);
            if (existing != nodes[idx].children.end())
            {
                childIdx = existing->second;
            }
            else
            {
                // Same reallocation hazard as above: obtain the index before
                // touching nodes[idx] again.
                childIdx = newNode();
                nodes[idx].children.emplace(std::move(piece), childIdx);
            }
            idx = childIdx;
            url.remove_prefix(1);
        }
        Node& node = nodes[idx];
        if (node.ruleIndex != 0U)
        {
            BMCWEB_LOG_CRITICAL("handler already exists for \"{}\"", urlIn);
            throw std::runtime_error(
                std::format("handler already exists for \"{}\"", urlIn));
        }
        node.ruleIndex = ruleIndex;
    }

  private:
    // Logs the subtree under `n` at debug level, indenting each line by
    // `level` spaces.
    void debugNodePrint(const Node& n, size_t level) const
    {
        std::string spaces(level, ' ');
        if (n.stringParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{}<str>", spaces);
            debugNodePrint(nodes[n.stringParamChild], level + 5);
        }
        if (n.pathParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{} <path>", spaces);
            debugNodePrint(nodes[n.pathParamChild], level + 6);
        }
        for (const Node::ChildMap::value_type& kv : n.children)
        {
            BMCWEB_LOG_DEBUG("{}{}", spaces, kv.first);
            debugNodePrint(nodes[kv.second], level + kv.first.size());
        }
    }

  public:
    // Logs the entire trie at debug level.
    void debugPrint()
    {
        debugNodePrint(head(), 0U);
    }

  private:
    const Node& head() const
    {
        return nodes.front();
    }

    Node& head()
    {
        return nodes.front();
    }

    // Appends a default-constructed node and returns its index.
    // This can reallocate `nodes`, invalidating all outstanding Node
    // references, pointers and iterators.
    unsigned newNode()
    {
        nodes.emplace_back();
        return static_cast<unsigned>(nodes.size() - 1);
    }

    std::vector<Node> nodes;
};
342 
343 class Router
344 {
345   public:
346     Router() = default;
347 
newRuleDynamic(const std::string & rule)348     DynamicRule& newRuleDynamic(const std::string& rule)
349     {
350         std::unique_ptr<DynamicRule> ruleObject =
351             std::make_unique<DynamicRule>(rule);
352         DynamicRule* ptr = ruleObject.get();
353         allRules.emplace_back(std::move(ruleObject));
354 
355         return *ptr;
356     }
357 
358     template <uint64_t NumArgs>
newRuleTagged(const std::string & rule)359     auto& newRuleTagged(const std::string& rule)
360     {
361         if constexpr (NumArgs == 0)
362         {
363             using RuleT = TaggedRule<>;
364             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
365             RuleT* ptr = ruleObject.get();
366             allRules.emplace_back(std::move(ruleObject));
367             return *ptr;
368         }
369         else if constexpr (NumArgs == 1)
370         {
371             using RuleT = TaggedRule<std::string>;
372             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
373             RuleT* ptr = ruleObject.get();
374             allRules.emplace_back(std::move(ruleObject));
375             return *ptr;
376         }
377         else if constexpr (NumArgs == 2)
378         {
379             using RuleT = TaggedRule<std::string, std::string>;
380             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
381             RuleT* ptr = ruleObject.get();
382             allRules.emplace_back(std::move(ruleObject));
383             return *ptr;
384         }
385         else if constexpr (NumArgs == 3)
386         {
387             using RuleT = TaggedRule<std::string, std::string, std::string>;
388             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
389             RuleT* ptr = ruleObject.get();
390             allRules.emplace_back(std::move(ruleObject));
391             return *ptr;
392         }
393         else if constexpr (NumArgs == 4)
394         {
395             using RuleT =
396                 TaggedRule<std::string, std::string, std::string, std::string>;
397             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
398             RuleT* ptr = ruleObject.get();
399             allRules.emplace_back(std::move(ruleObject));
400             return *ptr;
401         }
402         else
403         {
404             using RuleT = TaggedRule<std::string, std::string, std::string,
405                                      std::string, std::string>;
406             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
407             RuleT* ptr = ruleObject.get();
408             allRules.emplace_back(std::move(ruleObject));
409             return *ptr;
410         }
411         static_assert(NumArgs <= 5, "Max number of args supported is 5");
412     }
413 
414     struct PerMethod
415     {
416         std::vector<BaseRule*> rules;
417         Trie trie;
418         // rule index 0 has special meaning; preallocate it to avoid
419         // duplication.
PerMethodcrow::Router::PerMethod420         PerMethod() : rules(1) {}
421 
internalAddcrow::Router::PerMethod422         void internalAdd(std::string_view rule, BaseRule* ruleObject)
423         {
424             rules.emplace_back(ruleObject);
425             trie.add(rule, static_cast<unsigned>(rules.size() - 1U));
426             // directory case:
427             //   request to `/about' url matches `/about/' rule
428             if (rule.size() > 2 && rule.back() == '/')
429             {
430                 trie.add(rule.substr(0, rule.size() - 1),
431                          static_cast<unsigned>(rules.size() - 1));
432             }
433         }
434     };
435 
internalAddRuleObject(const std::string & rule,BaseRule * ruleObject)436     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
437     {
438         if (ruleObject == nullptr)
439         {
440             return;
441         }
442         for (size_t method = 0; method <= maxVerbIndex; method++)
443         {
444             size_t methodBit = 1 << method;
445             if ((ruleObject->methodsBitfield & methodBit) > 0U)
446             {
447                 perMethods[method].internalAdd(rule, ruleObject);
448             }
449         }
450 
451         if (ruleObject->isNotFound)
452         {
453             notFoundRoutes.internalAdd(rule, ruleObject);
454         }
455 
456         if (ruleObject->isMethodNotAllowed)
457         {
458             methodNotAllowedRoutes.internalAdd(rule, ruleObject);
459         }
460 
461         if (ruleObject->isUpgrade)
462         {
463             upgradeRoutes.internalAdd(rule, ruleObject);
464         }
465     }
466 
validate()467     void validate()
468     {
469         for (std::unique_ptr<BaseRule>& rule : allRules)
470         {
471             if (rule)
472             {
473                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
474                 if (upgraded)
475                 {
476                     rule = std::move(upgraded);
477                 }
478                 rule->validate();
479                 internalAddRuleObject(rule->rule, rule.get());
480             }
481         }
482         for (PerMethod& perMethod : perMethods)
483         {
484             perMethod.trie.validate();
485         }
486     }
487 
488     struct FindRoute
489     {
490         BaseRule* rule = nullptr;
491         std::vector<std::string> params;
492     };
493 
494     struct FindRouteResponse
495     {
496         std::string allowHeader;
497         FindRoute route;
498     };
499 
findRouteByPerMethod(std::string_view url,const PerMethod & perMethod)500     static FindRoute findRouteByPerMethod(std::string_view url,
501                                           const PerMethod& perMethod)
502     {
503         FindRoute route;
504 
505         Trie::FindResult found = perMethod.trie.find(url);
506         if (found.ruleIndex >= perMethod.rules.size())
507         {
508             throw std::runtime_error("Trie internal structure corrupted!");
509         }
510         // Found a 404 route, switch that in
511         if (found.ruleIndex != 0U)
512         {
513             route.rule = perMethod.rules[found.ruleIndex];
514             route.params = std::move(found.params);
515         }
516         return route;
517     }
518 
findRoute(const Request & req) const519     FindRouteResponse findRoute(const Request& req) const
520     {
521         FindRouteResponse findRoute;
522 
523         // Check to see if this url exists at any verb
524         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
525              perMethodIndex++)
526         {
527             // Make sure it's safe to deference the array at that index
528             static_assert(
529                 maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);
530             FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
531                                                    perMethods[perMethodIndex]);
532             if (route.rule == nullptr)
533             {
534                 continue;
535             }
536             if (!findRoute.allowHeader.empty())
537             {
538                 findRoute.allowHeader += ", ";
539             }
540             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
541             findRoute.allowHeader += httpVerbToString(thisVerb);
542         }
543 
544         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
545         if (!verb)
546         {
547             return findRoute;
548         }
549         size_t reqMethodIndex = static_cast<size_t>(*verb);
550         if (reqMethodIndex >= perMethods.size())
551         {
552             return findRoute;
553         }
554 
555         FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
556                                                perMethods[reqMethodIndex]);
557         if (route.rule != nullptr)
558         {
559             findRoute.route = route;
560         }
561 
562         return findRoute;
563     }
564 
565     template <typename Adaptor>
handleUpgrade(const std::shared_ptr<Request> & req,const std::shared_ptr<bmcweb::AsyncResp> & asyncResp,Adaptor && adaptor)566     void handleUpgrade(const std::shared_ptr<Request>& req,
567                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
568                        Adaptor&& adaptor)
569     {
570         PerMethod& perMethod = upgradeRoutes;
571         Trie& trie = perMethod.trie;
572         std::vector<BaseRule*>& rules = perMethod.rules;
573 
574         Trie::FindResult found = trie.find(req->url().encoded_path());
575         unsigned ruleIndex = found.ruleIndex;
576         if (ruleIndex == 0U)
577         {
578             BMCWEB_LOG_DEBUG("Cannot match rules {}",
579                              req->url().encoded_path());
580             asyncResp->res.result(boost::beast::http::status::not_found);
581             return;
582         }
583 
584         if (ruleIndex >= rules.size())
585         {
586             throw std::runtime_error("Trie internal structure corrupted!");
587         }
588 
589         BaseRule& rule = *rules[ruleIndex];
590 
591         BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}'", rule.rule);
592 
593         // TODO(ed) This should be able to use std::bind_front, but it doesn't
594         // appear to work with the std::move on adaptor.
595         validatePrivilege(
596             req, asyncResp, rule,
597             [req, &rule, asyncResp,
598              adaptor = std::forward<Adaptor>(adaptor)]() mutable {
599                 rule.handleUpgrade(*req, asyncResp, std::move(adaptor));
600             });
601     }
602 
handle(const std::shared_ptr<Request> & req,const std::shared_ptr<bmcweb::AsyncResp> & asyncResp)603     void handle(const std::shared_ptr<Request>& req,
604                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
605     {
606         FindRouteResponse foundRoute = findRoute(*req);
607 
608         if (foundRoute.route.rule == nullptr)
609         {
610             // Couldn't find a normal route with any verb, try looking for a 404
611             // route
612             if (foundRoute.allowHeader.empty())
613             {
614                 foundRoute.route = findRouteByPerMethod(
615                     req->url().encoded_path(), notFoundRoutes);
616             }
617             else
618             {
619                 // See if we have a method not allowed (405) handler
620                 foundRoute.route = findRouteByPerMethod(
621                     req->url().encoded_path(), methodNotAllowedRoutes);
622             }
623         }
624 
625         // Fill in the allow header if it's valid
626         if (!foundRoute.allowHeader.empty())
627         {
628             asyncResp->res.addHeader(boost::beast::http::field::allow,
629                                      foundRoute.allowHeader);
630         }
631 
632         // If we couldn't find a real route or a 404 route, return a generic
633         // response
634         if (foundRoute.route.rule == nullptr)
635         {
636             if (foundRoute.allowHeader.empty())
637             {
638                 asyncResp->res.result(boost::beast::http::status::not_found);
639             }
640             else
641             {
642                 asyncResp->res.result(
643                     boost::beast::http::status::method_not_allowed);
644             }
645             return;
646         }
647 
648         BaseRule& rule = *foundRoute.route.rule;
649         std::vector<std::string> params = std::move(foundRoute.route.params);
650 
651         BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
652                          req->methodString(), rule.getMethods());
653 
654         if (req->session == nullptr)
655         {
656             rule.handle(*req, asyncResp, params);
657             return;
658         }
659         validatePrivilege(
660             req, asyncResp, rule,
661             [req, asyncResp, &rule, params = std::move(params)]() {
662                 rule.handle(*req, asyncResp, params);
663             });
664     }
665 
debugPrint()666     void debugPrint()
667     {
668         for (size_t i = 0; i < perMethods.size(); i++)
669         {
670             BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast<HttpVerb>(i)));
671             perMethods[i].trie.debugPrint();
672         }
673     }
674 
getRoutes(const std::string & parent)675     std::vector<const std::string*> getRoutes(const std::string& parent)
676     {
677         std::vector<const std::string*> ret;
678 
679         for (const PerMethod& pm : perMethods)
680         {
681             std::vector<unsigned> x;
682             pm.trie.findRouteIndexes(parent, x);
683             for (unsigned index : x)
684             {
685                 ret.push_back(&pm.rules[index]->rule);
686             }
687         }
688         return ret;
689     }
690 
691   private:
692     std::array<PerMethod, static_cast<size_t>(HttpVerb::Max)> perMethods;
693 
694     PerMethod notFoundRoutes;
695     PerMethod upgradeRoutes;
696     PerMethod methodNotAllowedRoutes;
697 
698     std::vector<std::unique_ptr<BaseRule>> allRules;
699 };
700 } // namespace crow
701