xref: /openbmc/bmcweb/http/utility.hpp (revision 80d2ef31c343ec9c562778848b9d62fb6ce57c07)
1 // SPDX-License-Identifier: Apache-2.0
2 // SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 #pragma once
4 
5 #include "bmcweb_config.h"
6 
7 #include <sys/types.h>
8 
9 #include <boost/url/segments_view.hpp>
10 #include <boost/url/url.hpp>
11 #include <boost/url/url_view_base.hpp>
12 #include <nlohmann/json.hpp>
13 
14 #include <array>
15 #include <bit>
16 #include <concepts>
17 #include <cstddef>
18 #include <cstdint>
19 #include <ctime>
20 #include <functional>
21 #include <initializer_list>
22 #include <limits>
23 #include <string>
24 #include <string_view>
25 #include <type_traits>
26 #include <utility>
27 #include <variant>
28 
29 namespace crow
30 {
31 namespace utility
32 {
33 
// Counts the parameter placeholders in a route template.
//
// A placeholder is a "<...>" group; only "<str>", "<string>" and "<path>" are
// counted, any other well formed group is skipped.
//
// Returns the number of recognized placeholders. Returns 0 when the brackets
// are malformed: a '<' inside an open group, a '>' without a matching '<', or
// a group that is never closed. Note that 0 is therefore also the result for
// a template without any recognized placeholder.
constexpr uint64_t getParameterTag(std::string_view url)
{
    constexpr size_t noOpenTag = std::string_view::npos;

    uint64_t paramCount = 0;
    // Position of the '<' of the group currently being scanned
    size_t tagStart = noOpenTag;

    for (size_t pos = 0; pos < url.size(); ++pos)
    {
        switch (url[pos])
        {
            case '<':
            {
                if (tagStart != noOpenTag)
                {
                    // Nested '<'
                    return 0;
                }
                tagStart = pos;
                break;
            }
            case '>':
            {
                if (tagStart == noOpenTag)
                {
                    // '>' with nothing to close
                    return 0;
                }
                // The tag text includes both angle brackets
                const std::string_view tag =
                    url.substr(tagStart, pos - tagStart + 1);
                const bool recognized =
                    tag == "<str>" || tag == "<string>" || tag == "<path>";
                if (recognized)
                {
                    ++paramCount;
                }
                tagStart = noOpenTag;
                break;
            }
            default:
                break;
        }
    }

    // A group still open at the end of the string is malformed
    return (tagStart == noOpenTag) ? paramCount : 0;
}
76 
// The standard base64 alphabet; the position of a character is the 6 bit
// value it encodes.
//
// Declared "inline constexpr" rather than "static constexpr": this header
// defines inline functions that index into the table, and an inline function
// with external linkage must refer to the same object in every translation
// unit. A "static" table gives each translation unit its own copy.
inline constexpr std::array<char, 64> base64key = {
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};

// Marker stored in decode tables for bytes that are not part of the alphabet
inline constexpr char nop = static_cast<char>(-1);
getDecodeTable(bool urlSafe)85 constexpr std::array<char, 256> getDecodeTable(bool urlSafe)
86 {
87     std::array<char, 256> decodeTable{};
88     decodeTable.fill(nop);
89 
90     for (size_t index = 0; index < base64key.size(); index++)
91     {
92         char character = base64key[index];
93         decodeTable[std::bit_cast<uint8_t>(character)] =
94             static_cast<char>(index);
95     }
96 
97     if (urlSafe)
98     {
99         // Urlsafe decode tables replace the last two characters with - and _
100         decodeTable['+'] = nop;
101         decodeTable['/'] = nop;
102         decodeTable['-'] = 62;
103         decodeTable['_'] = 63;
104     }
105 
106     return decodeTable;
107 }
108 
109 class Base64Encoder
110 {
111     char overflow1 = '\0';
112     char overflow2 = '\0';
113     uint8_t overflowCount = 0;
114 
115     // Takes 3 ascii chars, and encodes them as 4 base64 chars
encodeTriple(char first,char second,char third,std::string & output)116     static void encodeTriple(char first, char second, char third,
117                              std::string& output)
118     {
119         size_t keyIndex = 0;
120 
121         keyIndex = static_cast<size_t>(first & 0xFC) >> 2;
122         output += base64key[keyIndex];
123 
124         keyIndex = static_cast<size_t>(first & 0x03) << 4;
125         keyIndex += static_cast<size_t>(second & 0xF0) >> 4;
126         output += base64key[keyIndex];
127 
128         keyIndex = static_cast<size_t>(second & 0x0F) << 2;
129         keyIndex += static_cast<size_t>(third & 0xC0) >> 6;
130         output += base64key[keyIndex];
131 
132         keyIndex = static_cast<size_t>(third & 0x3F);
133         output += base64key[keyIndex];
134     }
135 
136   public:
137     // Accepts a partial string to encode, and writes the encoded characters to
138     // the output stream. requires subsequently calling finalize to complete
139     // stream.
encode(std::string_view data,std::string & output)140     void encode(std::string_view data, std::string& output)
141     {
142         // Encode the last round of overflow chars first
143         if (overflowCount == 2)
144         {
145             if (!data.empty())
146             {
147                 encodeTriple(overflow1, overflow2, data[0], output);
148                 overflowCount = 0;
149                 data.remove_prefix(1);
150             }
151         }
152         else if (overflowCount == 1)
153         {
154             if (data.size() >= 2)
155             {
156                 encodeTriple(overflow1, data[0], data[1], output);
157                 overflowCount = 0;
158                 data.remove_prefix(2);
159             }
160         }
161 
162         while (data.size() >= 3)
163         {
164             encodeTriple(data[0], data[1], data[2], output);
165             data.remove_prefix(3);
166         }
167 
168         if (!data.empty() && overflowCount == 0)
169         {
170             overflow1 = data[0];
171             overflowCount++;
172             data.remove_prefix(1);
173         }
174 
175         if (!data.empty() && overflowCount == 1)
176         {
177             overflow2 = data[0];
178             overflowCount++;
179             data.remove_prefix(1);
180         }
181     }
182 
183     // Completes a base64 output, by writing any MOD(3) characters to the
184     // output, as well as any required trailing =
finalize(std::string & output)185     void finalize(std::string& output)
186     {
187         if (overflowCount == 0)
188         {
189             return;
190         }
191         size_t keyIndex = static_cast<size_t>(overflow1 & 0xFC) >> 2;
192         output += base64key[keyIndex];
193 
194         keyIndex = static_cast<size_t>(overflow1 & 0x03) << 4;
195         if (overflowCount == 2)
196         {
197             keyIndex += static_cast<size_t>(overflow2 & 0xF0) >> 4;
198             output += base64key[keyIndex];
199             keyIndex = static_cast<size_t>(overflow2 & 0x0F) << 2;
200             output += base64key[keyIndex];
201         }
202         else
203         {
204             output += base64key[keyIndex];
205             output += '=';
206         }
207         output += '=';
208         overflowCount = 0;
209     }
210 
211     // Returns the required output buffer in characters for an input of size
212     // inputSize
encodedSize(size_t inputSize)213     static size_t constexpr encodedSize(size_t inputSize)
214     {
215         // Base64 encodes 3 character blocks as 4 character blocks
216         // With a possibility of 2 trailing = characters
217         return (inputSize + 2) / 3 * 4;
218     }
219 };
220 
base64encode(std::string_view data)221 inline std::string base64encode(std::string_view data)
222 {
223     // Encodes a 3 character stream into a 4 character stream
224     std::string out;
225     Base64Encoder base64;
226     out.reserve(Base64Encoder::encodedSize(data.size()));
227     base64.encode(data, out);
228     base64.finalize(out);
229     return out;
230 }
231 
232 template <bool urlsafe = false>
base64Decode(std::string_view input,std::string & output)233 inline bool base64Decode(std::string_view input, std::string& output)
234 {
235     size_t inputLength = input.size();
236 
237     // allocate space for output string
238     output.clear();
239     output.reserve(((inputLength + 2) / 3) * 4);
240 
241     static constexpr auto decodingData = getDecodeTable(urlsafe);
242 
243     auto getCodeValue = [](char c) {
244         auto code = static_cast<unsigned char>(c);
245         // Ensure we cannot index outside the bounds of the decoding array
246         static_assert(
247             std::numeric_limits<decltype(code)>::max() < decodingData.size());
248         return decodingData[code];
249     };
250 
251     // for each 4-bytes sequence from the input, extract 4 6-bits sequences by
252     // dropping first two bits
253     // and regenerate into 3 8-bits sequences
254 
255     for (size_t i = 0; i < inputLength; i++)
256     {
257         char base64code0 = 0;
258         char base64code1 = 0;
259         char base64code2 = 0; // initialized to 0 to suppress warnings
260 
261         base64code0 = getCodeValue(input[i]);
262         if (base64code0 == nop)
263         {
264             // non base64 character
265             return false;
266         }
267         if (!(++i < inputLength))
268         {
269             // we need at least two input bytes for first byte output
270             return false;
271         }
272         base64code1 = getCodeValue(input[i]);
273         if (base64code1 == nop)
274         {
275             // non base64 character
276             return false;
277         }
278         output +=
279             static_cast<char>((base64code0 << 2) | ((base64code1 >> 4) & 0x3));
280 
281         if (++i < inputLength)
282         {
283             char c = input[i];
284             if (c == '=')
285             {
286                 // padding , end of input
287                 return (base64code1 & 0x0f) == 0;
288             }
289             base64code2 = getCodeValue(input[i]);
290             if (base64code2 == nop)
291             {
292                 // non base64 character
293                 return false;
294             }
295             output += static_cast<char>(
296                 ((base64code1 << 4) & 0xf0) | ((base64code2 >> 2) & 0x0f));
297         }
298 
299         if (++i < inputLength)
300         {
301             char c = input[i];
302             if (c == '=')
303             {
304                 // padding , end of input
305                 return (base64code2 & 0x03) == 0;
306             }
307             char base64code3 = getCodeValue(input[i]);
308             if (base64code3 == nop)
309             {
310                 // non base64 character
311                 return false;
312             }
313             output +=
314                 static_cast<char>((((base64code2 << 6) & 0xc0) | base64code3));
315         }
316     }
317 
318     return true;
319 }
320 
// Tag type for readUrlSegments() patterns: stands for "this segment and
// everything after it, if present". The type carries no data; only its
// presence in the pattern matters.
class OrMorePaths
{};
323 
324 template <typename... AV>
appendUrlPieces(boost::urls::url & url,AV &&...args)325 inline void appendUrlPieces(boost::urls::url& url, AV&&... args)
326 {
327     // Unclear the correct fix here.
328     // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay)
329     for (const std::string_view arg : {args...})
330     {
331         url.segments().push_back(arg);
332     }
333 }
334 
335 namespace details
336 {
337 
338 // std::reference_wrapper<std::string> - extracts segment to variable
339 //                    std::string_view - checks if segment is equal to variable
340 using UrlSegment = std::variant<std::reference_wrapper<std::string>,
341                                 std::string_view, OrMorePaths>;
342 
// Outcome of comparing a single URL segment with a single pattern element
enum class UrlParseResult
{
    // The segment was accepted; go on with the next segment
    Continue,
    // The segment does not fit the pattern; the URL is rejected
    Fail,
    // The pattern accepts everything from here on; the URL is accepted
    Done,
};
349 
350 class UrlSegmentMatcherVisitor
351 {
352   public:
operator ()(std::string & output)353     UrlParseResult operator()(std::string& output)
354     {
355         output = segment;
356         return UrlParseResult::Continue;
357     }
358 
operator ()(std::string_view expected)359     UrlParseResult operator()(std::string_view expected)
360     {
361         if (segment == expected)
362         {
363             return UrlParseResult::Continue;
364         }
365         return UrlParseResult::Fail;
366     }
367 
operator ()(OrMorePaths)368     UrlParseResult operator()(OrMorePaths /*unused*/)
369     {
370         return UrlParseResult::Done;
371     }
372 
UrlSegmentMatcherVisitor(std::string_view segmentIn)373     explicit UrlSegmentMatcherVisitor(std::string_view segmentIn) :
374         segment(segmentIn)
375     {}
376 
377   private:
378     std::string_view segment;
379 };
380 
readUrlSegments(const boost::urls::url_view_base & url,std::initializer_list<UrlSegment> segments)381 inline bool readUrlSegments(const boost::urls::url_view_base& url,
382                             std::initializer_list<UrlSegment> segments)
383 {
384     const boost::urls::segments_view& urlSegments = url.segments();
385 
386     if (!urlSegments.is_absolute())
387     {
388         return false;
389     }
390 
391     boost::urls::segments_view::const_iterator it = urlSegments.begin();
392     boost::urls::segments_view::const_iterator end = urlSegments.end();
393 
394     for (const auto& segment : segments)
395     {
396         if (it == end)
397         {
398             // If the request ends with an "any" path, this was successful
399             return std::holds_alternative<OrMorePaths>(segment);
400         }
401         UrlParseResult res = std::visit(UrlSegmentMatcherVisitor(*it), segment);
402         if (res == UrlParseResult::Done)
403         {
404             return true;
405         }
406         if (res == UrlParseResult::Fail)
407         {
408             return false;
409         }
410         it++;
411     }
412 
413     // There will be an empty segment at the end if the URI ends with a "/"
414     // e.g. /redfish/v1/Chassis/
415     if ((it != end) && urlSegments.back().empty())
416     {
417         it++;
418     }
419     return it == end;
420 }
421 
422 } // namespace details
423 
424 template <typename... Args>
readUrlSegments(const boost::urls::url_view_base & url,Args &&...args)425 inline bool readUrlSegments(const boost::urls::url_view_base& url,
426                             Args&&... args)
427 {
428     return details::readUrlSegments(url, {std::forward<Args>(args)...});
429 }
430 
// Returns a new URL made of the path segments of urlView, with the segment at
// zero based position replaceLoc replaced by newSegment.
//
// Only the path is carried over into the result. When the path of urlView is
// not absolute the result is "/"; when replaceLoc is past the last segment
// the path is copied unchanged.
inline boost::urls::url replaceUrlSegment(
    const boost::urls::url_view_base& urlView, const uint replaceLoc,
    std::string_view newSegment)
{
    boost::urls::url result("/");

    const boost::urls::segments_view& sourceSegments = urlView.segments();
    if (!sourceSegments.is_absolute())
    {
        return result;
    }

    uint position = 0;
    for (auto&& piece : sourceSegments)
    {
        if (position == replaceLoc)
        {
            result.segments().push_back(newSegment);
        }
        else
        {
            result.segments().push_back(piece);
        }
        ++position;
    }

    return result;
}
460 
setProtocolDefaults(boost::urls::url & url,std::string_view protocol)461 inline void setProtocolDefaults(boost::urls::url& url,
462                                 std::string_view protocol)
463 {
464     if (url.has_scheme())
465     {
466         return;
467     }
468     if (protocol == "Redfish" || protocol.empty())
469     {
470         if (url.port_number() == 443)
471         {
472             url.set_scheme("https");
473         }
474         if (url.port_number() == 80)
475         {
476             if constexpr (BMCWEB_INSECURE_PUSH_STYLE_NOTIFICATION)
477             {
478                 url.set_scheme("http");
479             }
480         }
481     }
482     else if (protocol == "SNMPv2c")
483     {
484         url.set_scheme("snmp");
485     }
486 }
487 
setPortDefaults(boost::urls::url & url)488 inline void setPortDefaults(boost::urls::url& url)
489 {
490     uint16_t port = url.port_number();
491     if (port != 0)
492     {
493         return;
494     }
495 
496     // If the user hasn't explicitly stated a port, pick one explicitly for them
497     // based on the protocol defaults
498     if (url.scheme() == "http")
499     {
500         url.set_port_number(80);
501     }
502     if (url.scheme() == "https")
503     {
504         url.set_port_number(443);
505     }
506     if (url.scheme() == "snmp")
507     {
508         url.set_port_number(162);
509     }
510 }
511 
512 } // namespace utility
513 } // namespace crow
514 
515 namespace nlohmann
516 {
517 template <std::derived_from<boost::urls::url_view_base> URL>
518 struct adl_serializer<URL>
519 {
520     // NOLINTNEXTLINE(readability-identifier-naming)
to_jsonnlohmann::adl_serializer521     static void to_json(json& j, const URL& url)
522     {
523         j = url.buffer();
524     }
525 };
526 } // namespace nlohmann
527