xref: /openbmc/bmcweb/http/routing.hpp (revision 40e9b92ec19acffb46f83a6e55b18974da5d708e)
1 // SPDX-License-Identifier: Apache-2.0
2 // SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 #pragma once
4 
5 #include "async_resp.hpp"
6 #include "dbus_privileges.hpp"
7 #include "dbus_utility.hpp"
8 #include "error_messages.hpp"
9 #include "http_request.hpp"
10 #include "http_response.hpp"
11 #include "logging.hpp"
12 #include "privileges.hpp"
13 #include "routing/baserule.hpp"
14 #include "routing/dynamicrule.hpp"
15 #include "routing/sserule.hpp"
16 #include "routing/taggedrule.hpp"
17 #include "routing/websocketrule.hpp"
18 #include "sessions.hpp"
19 #include "utility.hpp"
20 #include "utils/dbus_utils.hpp"
21 #include "verb.hpp"
22 #include "websocket.hpp"
23 
24 #include <boost/container/flat_map.hpp>
25 #include <boost/container/small_vector.hpp>
26 
27 #include <algorithm>
28 #include <cerrno>
29 #include <cstdint>
30 #include <cstdlib>
31 #include <limits>
32 #include <memory>
33 #include <optional>
34 #include <string_view>
35 #include <tuple>
36 #include <utility>
37 #include <vector>
38 
39 namespace crow
40 {
41 
/**
 * Character trie mapping URL patterns to rule indexes.
 *
 * add() inserts a pattern one literal character at a time. Placeholders get
 * dedicated edges:
 *   - "<str>" / "<string>": a string-parameter edge; during find() it matches
 *     a non-empty run of characters up to (not including) the next '/'.
 *   - "<path>": a path-parameter edge; during find() it matches the whole
 *     remaining (non-empty) URL.
 *
 * Rule index 0 is reserved to mean "no rule". Node 0 is always the root, so a
 * child index of 0 likewise means "no child". Nodes refer to each other by
 * index into `nodes`; newNode() may reallocate that vector, so no Node
 * reference, pointer, or iterator may be held across a call to it.
 */
class Trie
{
  public:
    struct Node
    {
        // Rule registered for the pattern ending at this node; 0 means none.
        unsigned ruleIndex = 0U;

        // Child reached through "<str>"/"<string>"; 0 means none.
        size_t stringParamChild = 0U;
        // Child reached through "<path>"; 0 means none.
        size_t pathParamChild = 0U;

        // Literal edges: key is the text consumed, value is the child index.
        // Keys start as single characters and are concatenated by optimize().
        using ChildMap = boost::container::flat_map<
            std::string, unsigned, std::less<>,
            boost::container::small_vector<std::pair<std::string, unsigned>,
                                           1>>;
        ChildMap children;

        // True when the node carries no rule and no parameter children, so
        // its literal edges can be folded into its parent's edge.
        bool isSimpleNode() const
        {
            return ruleIndex == 0 && stringParamChild == 0 &&
                   pathParamChild == 0;
        }
    };

    Trie() : nodes(1) {}

  private:
    // Recursively collapses chains of simple nodes so that literal edges
    // hold multi-character fragments instead of single characters. Lookups
    // made through find() return the same results before and after.
    void optimizeNode(Node& node)
    {
        if (node.stringParamChild != 0U)
        {
            optimizeNode(nodes[node.stringParamChild]);
        }
        if (node.pathParamChild != 0U)
        {
            optimizeNode(nodes[node.pathParamChild]);
        }

        if (node.children.empty())
        {
            return;
        }
        // Repeat until a pass merges nothing, so arbitrarily long chains of
        // simple nodes are folded into a single edge.
        while (true)
        {
            bool didMerge = false;
            Node::ChildMap merged;
            for (const Node::ChildMap::value_type& kv : node.children)
            {
                Node& child = nodes[kv.second];
                if (child.isSimpleNode())
                {
                    for (const Node::ChildMap::value_type& childKv :
                         child.children)
                    {
                        merged[kv.first + childKv.first] = childKv.second;
                        didMerge = true;
                    }
                }
                else
                {
                    merged[kv.first] = kv.second;
                }
            }
            node.children = std::move(merged);
            if (!didMerge)
            {
                break;
            }
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            optimizeNode(nodes[kv.second]);
        }
    }

    void optimize()
    {
        optimizeNode(head());
    }

  public:
    // Finalizes the trie once all patterns have been add()ed. Currently this
    // only compresses literal edges via optimize().
    void validate()
    {
        optimize();
    }

    // Appends to routeIndexes the rule index of every node reachable through
    // literal edges below the position reached by consuming reqUrl. Only
    // literal edges are followed, and a child is entered only if its whole
    // fragment is a prefix of the remaining reqUrl. Rules reached through a
    // bare "/" fragment are not reported (presumably the trailing-slash alias
    // of a directory route; confirm against callers).
    void findRouteIndexesHelper(std::string_view reqUrl,
                                std::vector<unsigned>& routeIndexes,
                                const Node& node) const
    {
        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];
            if (reqUrl.empty())
            {
                if (child.ruleIndex != 0 && fragment != "/")
                {
                    routeIndexes.push_back(child.ruleIndex);
                }
                findRouteIndexesHelper(reqUrl, routeIndexes, child);
            }
            else
            {
                if (reqUrl.starts_with(fragment))
                {
                    findRouteIndexesHelper(reqUrl.substr(fragment.size()),
                                           routeIndexes, child);
                }
            }
        }
    }

    // Collects rule indexes registered underneath the prefix reqUrl.
    void findRouteIndexes(const std::string& reqUrl,
                          std::vector<unsigned>& routeIndexes) const
    {
        findRouteIndexesHelper(reqUrl, routeIndexes, head());
    }

    // Result of find(): ruleIndex is 0 when nothing matched; params holds the
    // text captured by each placeholder, in the order they were traversed.
    struct FindResult
    {
        unsigned ruleIndex;
        std::vector<std::string> params;
    };

  private:
    // Depth-first match of reqUrl starting at node. Alternatives are tried
    // in this order: string parameter, path parameter, then literal children
    // in key order. The first alternative yielding a non-zero rule wins.
    // params is used as a stack and is restored on every failed branch.
    FindResult findHelper(const std::string_view reqUrl, const Node& node,
                          std::vector<std::string>& params) const
    {
        if (reqUrl.empty())
        {
            return {node.ruleIndex, params};
        }

        if (node.stringParamChild != 0U)
        {
            // A string parameter ends at the next '/' (or end of input) and
            // must capture at least one character.
            size_t epos = reqUrl.find('/');
            if (epos == std::string_view::npos)
            {
                epos = reqUrl.size();
            }

            if (epos != 0)
            {
                params.emplace_back(reqUrl.substr(0, epos));
                FindResult ret = findHelper(
                    reqUrl.substr(epos), nodes[node.stringParamChild], params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
                params.pop_back();
            }
        }

        if (node.pathParamChild != 0U)
        {
            // A path parameter swallows the whole remainder of the URL.
            params.emplace_back(reqUrl);
            FindResult ret = findHelper("", nodes[node.pathParamChild], params);
            if (ret.ruleIndex != 0U)
            {
                return {ret.ruleIndex, std::move(ret.params)};
            }
            params.pop_back();
        }

        for (const Node::ChildMap::value_type& kv : node.children)
        {
            const std::string& fragment = kv.first;
            const Node& child = nodes[kv.second];

            if (reqUrl.starts_with(fragment))
            {
                FindResult ret =
                    findHelper(reqUrl.substr(fragment.size()), child, params);
                if (ret.ruleIndex != 0U)
                {
                    return {ret.ruleIndex, std::move(ret.params)};
                }
            }
        }

        return {0U, std::vector<std::string>()};
    }

  public:
    // Looks up reqUrl; see FindResult for the meaning of the result.
    FindResult find(const std::string_view reqUrl) const
    {
        std::vector<std::string> start;
        return findHelper(reqUrl, head(), start);
    }

    // Registers pattern urlIn under ruleIndex.
    // If a '<' does not begin "<str>", "<string>" or "<path>", this logs a
    // critical message and returns without registering the rule (nodes
    // already created for the prefix are kept). Throws std::runtime_error if
    // a rule is already registered for the same pattern.
    void add(std::string_view urlIn, unsigned ruleIndex)
    {
        size_t idx = 0;

        std::string_view url = urlIn;

        while (!url.empty())
        {
            if (url.front() == '<')
            {
                bool found = false;
                for (const std::string_view str1 :
                     {"<str>", "<string>", "<path>"})
                {
                    if (!url.starts_with(str1))
                    {
                        continue;
                    }
                    found = true;
                    const bool isPath = (str1 == "<path>");
                    // Work with indexes only: newNode() may reallocate
                    // `nodes`, so a pointer/reference into nodes[idx] taken
                    // before the call would dangle afterwards.
                    size_t child = isPath ? nodes[idx].pathParamChild
                                          : nodes[idx].stringParamChild;
                    if (child == 0U)
                    {
                        child = newNode();
                        if (isPath)
                        {
                            nodes[idx].pathParamChild = child;
                        }
                        else
                        {
                            nodes[idx].stringParamChild = child;
                        }
                    }
                    idx = child;

                    url.remove_prefix(str1.size());
                    break;
                }
                if (found)
                {
                    continue;
                }

                BMCWEB_LOG_CRITICAL("Can't find tag for {}", urlIn);
                return;
            }
            std::string piece(1, url.front());
            auto existing = nodes[idx].children.find(piece);
            if (existing != nodes[idx].children.end())
            {
                idx = existing->second;
            }
            else
            {
                // Allocate before touching the map again: newNode() may
                // reallocate `nodes`, invalidating `existing`.
                const unsigned childIdx = newNode();
                nodes[idx].children.emplace(std::move(piece), childIdx);
                idx = childIdx;
            }
            url.remove_prefix(1);
        }
        Node& node = nodes[idx];
        if (node.ruleIndex != 0U)
        {
            BMCWEB_LOG_CRITICAL("handler already exists for \"{}\"", urlIn);
            throw std::runtime_error(
                std::format("handler already exists for \"{}\"", urlIn));
        }
        node.ruleIndex = ruleIndex;
    }

  private:
    // Logs the subtree rooted at n at debug level, indenting by level.
    void debugNodePrint(Node& n, size_t level)
    {
        std::string spaces(level, ' ');
        if (n.stringParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{}<str>", spaces);
            debugNodePrint(nodes[n.stringParamChild], level + 5);
        }
        if (n.pathParamChild != 0U)
        {
            BMCWEB_LOG_DEBUG("{} <path>", spaces);
            debugNodePrint(nodes[n.pathParamChild], level + 6);
        }
        for (const Node::ChildMap::value_type& kv : n.children)
        {
            BMCWEB_LOG_DEBUG("{}{}", spaces, kv.first);
            debugNodePrint(nodes[kv.second], level + kv.first.size());
        }
    }

  public:
    void debugPrint()
    {
        debugNodePrint(head(), 0U);
    }

  private:
    const Node& head() const
    {
        return nodes.front();
    }

    Node& head()
    {
        return nodes.front();
    }

    // Appends a default node and returns its index. May reallocate `nodes`,
    // invalidating every outstanding Node reference, pointer and iterator.
    unsigned newNode()
    {
        nodes.resize(nodes.size() + 1);
        return static_cast<unsigned>(nodes.size() - 1);
    }

    std::vector<Node> nodes;
};
345 
346 class Router
347 {
348   public:
349     Router() = default;
350 
351     DynamicRule& newRuleDynamic(const std::string& rule)
352     {
353         std::unique_ptr<DynamicRule> ruleObject =
354             std::make_unique<DynamicRule>(rule);
355         DynamicRule* ptr = ruleObject.get();
356         allRules.emplace_back(std::move(ruleObject));
357 
358         return *ptr;
359     }
360 
361     template <uint64_t NumArgs>
362     auto& newRuleTagged(const std::string& rule)
363     {
364         if constexpr (NumArgs == 0)
365         {
366             using RuleT = TaggedRule<>;
367             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
368             RuleT* ptr = ruleObject.get();
369             allRules.emplace_back(std::move(ruleObject));
370             return *ptr;
371         }
372         else if constexpr (NumArgs == 1)
373         {
374             using RuleT = TaggedRule<std::string>;
375             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
376             RuleT* ptr = ruleObject.get();
377             allRules.emplace_back(std::move(ruleObject));
378             return *ptr;
379         }
380         else if constexpr (NumArgs == 2)
381         {
382             using RuleT = TaggedRule<std::string, std::string>;
383             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
384             RuleT* ptr = ruleObject.get();
385             allRules.emplace_back(std::move(ruleObject));
386             return *ptr;
387         }
388         else if constexpr (NumArgs == 3)
389         {
390             using RuleT = TaggedRule<std::string, std::string, std::string>;
391             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
392             RuleT* ptr = ruleObject.get();
393             allRules.emplace_back(std::move(ruleObject));
394             return *ptr;
395         }
396         else if constexpr (NumArgs == 4)
397         {
398             using RuleT =
399                 TaggedRule<std::string, std::string, std::string, std::string>;
400             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
401             RuleT* ptr = ruleObject.get();
402             allRules.emplace_back(std::move(ruleObject));
403             return *ptr;
404         }
405         else
406         {
407             using RuleT = TaggedRule<std::string, std::string, std::string,
408                                      std::string, std::string>;
409             std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
410             RuleT* ptr = ruleObject.get();
411             allRules.emplace_back(std::move(ruleObject));
412             return *ptr;
413         }
414         static_assert(NumArgs <= 5, "Max number of args supported is 5");
415     }
416 
417     struct PerMethod
418     {
419         std::vector<BaseRule*> rules;
420         Trie trie;
421         // rule index 0 has special meaning; preallocate it to avoid
422         // duplication.
423         PerMethod() : rules(1) {}
424 
425         void internalAdd(std::string_view rule, BaseRule* ruleObject)
426         {
427             rules.emplace_back(ruleObject);
428             trie.add(rule, static_cast<unsigned>(rules.size() - 1U));
429             // directory case:
430             //   request to `/about' url matches `/about/' rule
431             if (rule.size() > 2 && rule.back() == '/')
432             {
433                 trie.add(rule.substr(0, rule.size() - 1),
434                          static_cast<unsigned>(rules.size() - 1));
435             }
436         }
437     };
438 
439     void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
440     {
441         if (ruleObject == nullptr)
442         {
443             return;
444         }
445         for (size_t method = 0; method <= maxVerbIndex; method++)
446         {
447             size_t methodBit = 1 << method;
448             if ((ruleObject->methodsBitfield & methodBit) > 0U)
449             {
450                 perMethods[method].internalAdd(rule, ruleObject);
451             }
452         }
453 
454         if (ruleObject->isNotFound)
455         {
456             notFoundRoutes.internalAdd(rule, ruleObject);
457         }
458 
459         if (ruleObject->isMethodNotAllowed)
460         {
461             methodNotAllowedRoutes.internalAdd(rule, ruleObject);
462         }
463 
464         if (ruleObject->isUpgrade)
465         {
466             upgradeRoutes.internalAdd(rule, ruleObject);
467         }
468     }
469 
470     void validate()
471     {
472         for (std::unique_ptr<BaseRule>& rule : allRules)
473         {
474             if (rule)
475             {
476                 std::unique_ptr<BaseRule> upgraded = rule->upgrade();
477                 if (upgraded)
478                 {
479                     rule = std::move(upgraded);
480                 }
481                 rule->validate();
482                 internalAddRuleObject(rule->rule, rule.get());
483             }
484         }
485         for (PerMethod& perMethod : perMethods)
486         {
487             perMethod.trie.validate();
488         }
489     }
490 
491     struct FindRoute
492     {
493         BaseRule* rule = nullptr;
494         std::vector<std::string> params;
495     };
496 
497     struct FindRouteResponse
498     {
499         std::string allowHeader;
500         FindRoute route;
501     };
502 
503     static FindRoute findRouteByPerMethod(std::string_view url,
504                                           const PerMethod& perMethod)
505     {
506         FindRoute route;
507 
508         Trie::FindResult found = perMethod.trie.find(url);
509         if (found.ruleIndex >= perMethod.rules.size())
510         {
511             throw std::runtime_error("Trie internal structure corrupted!");
512         }
513         // Found a 404 route, switch that in
514         if (found.ruleIndex != 0U)
515         {
516             route.rule = perMethod.rules[found.ruleIndex];
517             route.params = std::move(found.params);
518         }
519         return route;
520     }
521 
522     FindRouteResponse findRoute(const Request& req) const
523     {
524         FindRouteResponse findRoute;
525 
526         // Check to see if this url exists at any verb
527         for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
528              perMethodIndex++)
529         {
530             // Make sure it's safe to deference the array at that index
531             static_assert(
532                 maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);
533             FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
534                                                    perMethods[perMethodIndex]);
535             if (route.rule == nullptr)
536             {
537                 continue;
538             }
539             if (!findRoute.allowHeader.empty())
540             {
541                 findRoute.allowHeader += ", ";
542             }
543             HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
544             findRoute.allowHeader += httpVerbToString(thisVerb);
545         }
546 
547         std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
548         if (!verb)
549         {
550             return findRoute;
551         }
552         size_t reqMethodIndex = static_cast<size_t>(*verb);
553         if (reqMethodIndex >= perMethods.size())
554         {
555             return findRoute;
556         }
557 
558         FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
559                                                perMethods[reqMethodIndex]);
560         if (route.rule != nullptr)
561         {
562             findRoute.route = route;
563         }
564 
565         return findRoute;
566     }
567 
568     template <typename Adaptor>
569     void handleUpgrade(const std::shared_ptr<Request>& req,
570                        const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
571                        Adaptor&& adaptor)
572     {
573         PerMethod& perMethod = upgradeRoutes;
574         Trie& trie = perMethod.trie;
575         std::vector<BaseRule*>& rules = perMethod.rules;
576 
577         Trie::FindResult found = trie.find(req->url().encoded_path());
578         unsigned ruleIndex = found.ruleIndex;
579         if (ruleIndex == 0U)
580         {
581             BMCWEB_LOG_DEBUG("Cannot match rules {}",
582                              req->url().encoded_path());
583             asyncResp->res.result(boost::beast::http::status::not_found);
584             return;
585         }
586 
587         if (ruleIndex >= rules.size())
588         {
589             throw std::runtime_error("Trie internal structure corrupted!");
590         }
591 
592         BaseRule& rule = *rules[ruleIndex];
593 
594         BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}'", rule.rule);
595 
596         // TODO(ed) This should be able to use std::bind_front, but it doesn't
597         // appear to work with the std::move on adaptor.
598         validatePrivilege(
599             req, asyncResp, rule,
600             [req, &rule, asyncResp,
601              adaptor = std::forward<Adaptor>(adaptor)]() mutable {
602                 rule.handleUpgrade(*req, asyncResp, std::move(adaptor));
603             });
604     }
605 
606     void handle(const std::shared_ptr<Request>& req,
607                 const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
608     {
609         FindRouteResponse foundRoute = findRoute(*req);
610 
611         if (foundRoute.route.rule == nullptr)
612         {
613             // Couldn't find a normal route with any verb, try looking for a 404
614             // route
615             if (foundRoute.allowHeader.empty())
616             {
617                 foundRoute.route = findRouteByPerMethod(
618                     req->url().encoded_path(), notFoundRoutes);
619             }
620             else
621             {
622                 // See if we have a method not allowed (405) handler
623                 foundRoute.route = findRouteByPerMethod(
624                     req->url().encoded_path(), methodNotAllowedRoutes);
625             }
626         }
627 
628         // Fill in the allow header if it's valid
629         if (!foundRoute.allowHeader.empty())
630         {
631             asyncResp->res.addHeader(boost::beast::http::field::allow,
632                                      foundRoute.allowHeader);
633         }
634 
635         // If we couldn't find a real route or a 404 route, return a generic
636         // response
637         if (foundRoute.route.rule == nullptr)
638         {
639             if (foundRoute.allowHeader.empty())
640             {
641                 asyncResp->res.result(boost::beast::http::status::not_found);
642             }
643             else
644             {
645                 asyncResp->res.result(
646                     boost::beast::http::status::method_not_allowed);
647             }
648             return;
649         }
650 
651         BaseRule& rule = *foundRoute.route.rule;
652         std::vector<std::string> params = std::move(foundRoute.route.params);
653 
654         BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
655                          req->methodString(), rule.getMethods());
656 
657         if (req->session == nullptr)
658         {
659             rule.handle(*req, asyncResp, params);
660             return;
661         }
662         validatePrivilege(
663             req, asyncResp, rule,
664             [req, asyncResp, &rule, params = std::move(params)]() {
665                 rule.handle(*req, asyncResp, params);
666             });
667     }
668 
669     void debugPrint()
670     {
671         for (size_t i = 0; i < perMethods.size(); i++)
672         {
673             BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast<HttpVerb>(i)));
674             perMethods[i].trie.debugPrint();
675         }
676     }
677 
678     std::vector<const std::string*> getRoutes(const std::string& parent)
679     {
680         std::vector<const std::string*> ret;
681 
682         for (const PerMethod& pm : perMethods)
683         {
684             std::vector<unsigned> x;
685             pm.trie.findRouteIndexes(parent, x);
686             for (unsigned index : x)
687             {
688                 ret.push_back(&pm.rules[index]->rule);
689             }
690         }
691         return ret;
692     }
693 
694   private:
695     std::array<PerMethod, static_cast<size_t>(HttpVerb::Max)> perMethods;
696 
697     PerMethod notFoundRoutes;
698     PerMethod upgradeRoutes;
699     PerMethod methodNotAllowedRoutes;
700 
701     std::vector<std::unique_ptr<BaseRule>> allRules;
702 };
703 } // namespace crow
704