xref: /openbmc/bmcweb/http/utility.hpp (revision 8873f3226c157d26201e60ed9c11f2b3737d8f9b)
1 // SPDX-License-Identifier: Apache-2.0
2 // SPDX-FileCopyrightText: Copyright OpenBMC Authors
3 #pragma once
4 
5 #include "bmcweb_config.h"
6 
7 #include <sys/types.h>
8 
9 #include <boost/url/segments_view.hpp>
10 #include <boost/url/url.hpp>
11 #include <boost/url/url_view_base.hpp>
12 #include <nlohmann/adl_serializer.hpp>
13 #include <nlohmann/json.hpp>
14 
15 #include <array>
16 #include <bit>
17 #include <concepts>
18 #include <cstddef>
19 #include <cstdint>
20 #include <ctime>
21 #include <functional>
22 #include <initializer_list>
23 #include <limits>
24 #include <string>
25 #include <string_view>
26 #include <type_traits>
27 #include <utility>
28 #include <variant>
29 
30 namespace crow
31 {
32 namespace utility
33 {
34 
// Counts the parameter placeholders in a URL pattern string such as
// "/redfish/v1/Chassis/<str>/". Only the tags <str>, <string> and <path> are
// counted; any other well-formed <...> tag is skipped.
// Returns 0 when the pattern is malformed:
//  - a '<' appears while another tag is still open,
//  - a '>' appears with no open tag,
//  - a tag is still open at the end of the string.
constexpr uint64_t getParameterTag(std::string_view url)
{
    constexpr size_t noOpenTag = std::string_view::npos;
    size_t openPos = noOpenTag;
    uint64_t count = 0;

    for (size_t pos = 0; pos < url.size(); ++pos)
    {
        const char c = url[pos];
        if (c == '<')
        {
            if (openPos != noOpenTag)
            {
                // Nested tags are not allowed.
                return 0;
            }
            openPos = pos;
            continue;
        }
        if (c != '>')
        {
            continue;
        }
        if (openPos == noOpenTag)
        {
            // Closing bracket without a matching '<'.
            return 0;
        }
        const std::string_view name = url.substr(openPos, pos - openPos + 1);
        const bool isParameter =
            name == "<str>" || name == "<string>" || name == "<path>";
        if (isParameter)
        {
            ++count;
        }
        openPos = noOpenTag;
    }
    return openPos == noOpenTag ? count : 0;
}
77 
// Standard base64 alphabet (RFC 4648, Table 1). The 6-bit value of each
// character is its index in this table.
static constexpr std::array<char, 64> base64key = {
    // values 0-25
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',
    'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    // values 26-51
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
    'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    // values 52-63
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};

// Sentinel stored in a decode table for bytes outside the alphabet.
static constexpr char nop = static_cast<char>(-1);

// Builds a 256-entry reverse lookup table that maps every byte value to its
// 6-bit base64 value, or to nop if the byte is not part of the alphabet.
// With urlSafe set, the RFC 4648 section 5 ("base64url") alphabet is used:
// '-' and '_' take values 62 and 63, and '+' and '/' become invalid.
constexpr std::array<char, 256> getDecodeTable(bool urlSafe)
{
    std::array<char, 256> table{};
    for (char& entry : table)
    {
        entry = nop;
    }

    for (size_t value = 0; value < base64key.size(); ++value)
    {
        const auto slot = static_cast<unsigned char>(base64key[value]);
        table[slot] = static_cast<char>(value);
    }

    if (!urlSafe)
    {
        return table;
    }

    // base64url swaps out the last two characters of the alphabet.
    table[static_cast<unsigned char>('+')] = nop;
    table[static_cast<unsigned char>('/')] = nop;
    table[static_cast<unsigned char>('-')] = 62;
    table[static_cast<unsigned char>('_')] = 63;
    return table;
}
109 
110 class Base64Encoder
111 {
112     char overflow1 = '\0';
113     char overflow2 = '\0';
114     uint8_t overflowCount = 0;
115 
116     // Takes 3 ascii chars, and encodes them as 4 base64 chars
encodeTriple(char first,char second,char third,std::string & output)117     static void encodeTriple(char first, char second, char third,
118                              std::string& output)
119     {
120         size_t keyIndex = 0;
121 
122         keyIndex = static_cast<size_t>(first & 0xFC) >> 2;
123         output += base64key[keyIndex];
124 
125         keyIndex = static_cast<size_t>(first & 0x03) << 4;
126         keyIndex += static_cast<size_t>(second & 0xF0) >> 4;
127         output += base64key[keyIndex];
128 
129         keyIndex = static_cast<size_t>(second & 0x0F) << 2;
130         keyIndex += static_cast<size_t>(third & 0xC0) >> 6;
131         output += base64key[keyIndex];
132 
133         keyIndex = static_cast<size_t>(third & 0x3F);
134         output += base64key[keyIndex];
135     }
136 
137   public:
138     // Accepts a partial string to encode, and writes the encoded characters to
139     // the output stream. requires subsequently calling finalize to complete
140     // stream.
encode(std::string_view data,std::string & output)141     void encode(std::string_view data, std::string& output)
142     {
143         // Encode the last round of overflow chars first
144         if (overflowCount == 2)
145         {
146             if (!data.empty())
147             {
148                 encodeTriple(overflow1, overflow2, data[0], output);
149                 overflowCount = 0;
150                 data.remove_prefix(1);
151             }
152         }
153         else if (overflowCount == 1)
154         {
155             if (data.size() >= 2)
156             {
157                 encodeTriple(overflow1, data[0], data[1], output);
158                 overflowCount = 0;
159                 data.remove_prefix(2);
160             }
161         }
162 
163         while (data.size() >= 3)
164         {
165             encodeTriple(data[0], data[1], data[2], output);
166             data.remove_prefix(3);
167         }
168 
169         if (!data.empty() && overflowCount == 0)
170         {
171             overflow1 = data[0];
172             overflowCount++;
173             data.remove_prefix(1);
174         }
175 
176         if (!data.empty() && overflowCount == 1)
177         {
178             overflow2 = data[0];
179             overflowCount++;
180             data.remove_prefix(1);
181         }
182     }
183 
184     // Completes a base64 output, by writing any MOD(3) characters to the
185     // output, as well as any required trailing =
finalize(std::string & output)186     void finalize(std::string& output)
187     {
188         if (overflowCount == 0)
189         {
190             return;
191         }
192         size_t keyIndex = static_cast<size_t>(overflow1 & 0xFC) >> 2;
193         output += base64key[keyIndex];
194 
195         keyIndex = static_cast<size_t>(overflow1 & 0x03) << 4;
196         if (overflowCount == 2)
197         {
198             keyIndex += static_cast<size_t>(overflow2 & 0xF0) >> 4;
199             output += base64key[keyIndex];
200             keyIndex = static_cast<size_t>(overflow2 & 0x0F) << 2;
201             output += base64key[keyIndex];
202         }
203         else
204         {
205             output += base64key[keyIndex];
206             output += '=';
207         }
208         output += '=';
209         overflowCount = 0;
210     }
211 
212     // Returns the required output buffer in characters for an input of size
213     // inputSize
encodedSize(size_t inputSize)214     static size_t constexpr encodedSize(size_t inputSize)
215     {
216         // Base64 encodes 3 character blocks as 4 character blocks
217         // With a possibility of 2 trailing = characters
218         return (inputSize + 2) / 3 * 4;
219     }
220 };
221 
base64encode(std::string_view data)222 inline std::string base64encode(std::string_view data)
223 {
224     // Encodes a 3 character stream into a 4 character stream
225     std::string out;
226     Base64Encoder base64;
227     out.reserve(Base64Encoder::encodedSize(data.size()));
228     base64.encode(data, out);
229     base64.finalize(out);
230     return out;
231 }
232 
233 template <bool urlsafe = false>
base64Decode(std::string_view input,std::string & output)234 inline bool base64Decode(std::string_view input, std::string& output)
235 {
236     size_t inputLength = input.size();
237 
238     // allocate space for output string
239     output.clear();
240     output.reserve(((inputLength + 2) / 3) * 4);
241 
242     static constexpr auto decodingData = getDecodeTable(urlsafe);
243 
244     auto getCodeValue = [](char c) {
245         auto code = static_cast<unsigned char>(c);
246         // Ensure we cannot index outside the bounds of the decoding array
247         static_assert(
248             std::numeric_limits<decltype(code)>::max() < decodingData.size());
249         return decodingData[code];
250     };
251 
252     // for each 4-bytes sequence from the input, extract 4 6-bits sequences by
253     // dropping first two bits
254     // and regenerate into 3 8-bits sequences
255 
256     for (size_t i = 0; i < inputLength; i++)
257     {
258         char base64code0 = 0;
259         char base64code1 = 0;
260         char base64code2 = 0; // initialized to 0 to suppress warnings
261 
262         base64code0 = getCodeValue(input[i]);
263         if (base64code0 == nop)
264         {
265             // non base64 character
266             return false;
267         }
268         if (!(++i < inputLength))
269         {
270             // we need at least two input bytes for first byte output
271             return false;
272         }
273         base64code1 = getCodeValue(input[i]);
274         if (base64code1 == nop)
275         {
276             // non base64 character
277             return false;
278         }
279         output +=
280             static_cast<char>((base64code0 << 2) | ((base64code1 >> 4) & 0x3));
281 
282         if (++i < inputLength)
283         {
284             char c = input[i];
285             if (c == '=')
286             {
287                 // padding , end of input
288                 return (base64code1 & 0x0f) == 0;
289             }
290             base64code2 = getCodeValue(input[i]);
291             if (base64code2 == nop)
292             {
293                 // non base64 character
294                 return false;
295             }
296             output += static_cast<char>(
297                 ((base64code1 << 4) & 0xf0) | ((base64code2 >> 2) & 0x0f));
298         }
299 
300         if (++i < inputLength)
301         {
302             char c = input[i];
303             if (c == '=')
304             {
305                 // padding , end of input
306                 return (base64code2 & 0x03) == 0;
307             }
308             char base64code3 = getCodeValue(input[i]);
309             if (base64code3 == nop)
310             {
311                 // non base64 character
312                 return false;
313             }
314             output +=
315                 static_cast<char>((((base64code2 << 6) & 0xc0) | base64code3));
316         }
317     }
318 
319     return true;
320 }
321 
// Marker type used as a URL pattern element for readUrlSegments. When it is
// reached, matching ends successfully. This happens whether or not more path
// segments remain, and any pattern elements after it are ignored.
class OrMorePaths
{};
324 
325 template <typename... AV>
appendUrlPieces(boost::urls::url & url,AV &&...args)326 inline void appendUrlPieces(boost::urls::url& url, AV&&... args)
327 {
328     // Unclear the correct fix here.
329     // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay)
330     for (const std::string_view arg : {args...})
331     {
332         url.segments().push_back(arg);
333     }
334 }
335 
336 namespace details
337 {
338 
// One element of a URL match pattern for readUrlSegments:
// std::reference_wrapper<std::string> - copies the path segment into the
//                                       referenced string (always matches)
//                    std::string_view - path segment must equal it exactly
//                         OrMorePaths - ends matching successfully, whatever
//                                       segments remain
using UrlSegment = std::variant<std::reference_wrapper<std::string>,
                                std::string_view, OrMorePaths>;
343 
// Outcome of matching one URL path segment against one pattern element.
enum class UrlParseResult : int
{
    Continue = 0, // segment matched; go on to the next segment
    Fail = 1,     // segment did not match; the whole match fails
    Done = 2,     // wildcard reached; the whole match succeeds
};
350 
351 class UrlSegmentMatcherVisitor
352 {
353   public:
operator ()(std::string & output)354     UrlParseResult operator()(std::string& output)
355     {
356         output = segment;
357         return UrlParseResult::Continue;
358     }
359 
operator ()(std::string_view expected)360     UrlParseResult operator()(std::string_view expected)
361     {
362         if (segment == expected)
363         {
364             return UrlParseResult::Continue;
365         }
366         return UrlParseResult::Fail;
367     }
368 
operator ()(OrMorePaths)369     UrlParseResult operator()(OrMorePaths /*unused*/)
370     {
371         return UrlParseResult::Done;
372     }
373 
UrlSegmentMatcherVisitor(std::string_view segmentIn)374     explicit UrlSegmentMatcherVisitor(std::string_view segmentIn) :
375         segment(segmentIn)
376     {}
377 
378   private:
379     std::string_view segment;
380 };
381 
readUrlSegments(const boost::urls::url_view_base & url,std::initializer_list<UrlSegment> segments)382 inline bool readUrlSegments(const boost::urls::url_view_base& url,
383                             std::initializer_list<UrlSegment> segments)
384 {
385     const boost::urls::segments_view& urlSegments = url.segments();
386 
387     if (!urlSegments.is_absolute())
388     {
389         return false;
390     }
391 
392     boost::urls::segments_view::const_iterator it = urlSegments.begin();
393     boost::urls::segments_view::const_iterator end = urlSegments.end();
394 
395     for (const auto& segment : segments)
396     {
397         if (it == end)
398         {
399             // If the request ends with an "any" path, this was successful
400             return std::holds_alternative<OrMorePaths>(segment);
401         }
402         UrlParseResult res = std::visit(UrlSegmentMatcherVisitor(*it), segment);
403         if (res == UrlParseResult::Done)
404         {
405             return true;
406         }
407         if (res == UrlParseResult::Fail)
408         {
409             return false;
410         }
411         it++;
412     }
413 
414     // There will be an empty segment at the end if the URI ends with a "/"
415     // e.g. /redfish/v1/Chassis/
416     if ((it != end) && urlSegments.back().empty())
417     {
418         it++;
419     }
420     return it == end;
421 }
422 
423 } // namespace details
424 
425 template <typename... Args>
readUrlSegments(const boost::urls::url_view_base & url,Args &&...args)426 inline bool readUrlSegments(const boost::urls::url_view_base& url,
427                             Args&&... args)
428 {
429     return details::readUrlSegments(url, {std::forward<Args>(args)...});
430 }
431 
// Builds a new URL containing the path of urlView, with the segment at
// zero-based index replaceLoc replaced by newSegment. Only the path is
// carried over; scheme, authority, query and fragment are not. If replaceLoc
// is past the last segment, the path is copied unchanged. A relative urlView
// yields "/".
inline boost::urls::url replaceUrlSegment(
    const boost::urls::url_view_base& urlView, const uint replaceLoc,
    std::string_view newSegment)
{
    boost::urls::url result("/");
    const boost::urls::segments_view& source = urlView.segments();
    if (!source.is_absolute())
    {
        return result;
    }

    uint index = 0;
    for (const std::string& segment : source)
    {
        if (index == replaceLoc)
        {
            result.segments().push_back(newSegment);
        }
        else
        {
            result.segments().push_back(segment);
        }
        ++index;
    }
    return result;
}
461 
setProtocolDefaults(boost::urls::url & url,std::string_view protocol)462 inline void setProtocolDefaults(boost::urls::url& url,
463                                 std::string_view protocol)
464 {
465     if (url.has_scheme())
466     {
467         return;
468     }
469     if (protocol == "Redfish" || protocol.empty())
470     {
471         if (url.port_number() == 443)
472         {
473             url.set_scheme("https");
474         }
475         if (url.port_number() == 80)
476         {
477             if constexpr (BMCWEB_INSECURE_PUSH_STYLE_NOTIFICATION)
478             {
479                 url.set_scheme("http");
480             }
481         }
482     }
483     else if (protocol == "SNMPv2c")
484     {
485         url.set_scheme("snmp");
486     }
487 }
488 
setPortDefaults(boost::urls::url & url)489 inline void setPortDefaults(boost::urls::url& url)
490 {
491     uint16_t port = url.port_number();
492     if (port != 0)
493     {
494         return;
495     }
496 
497     // If the user hasn't explicitly stated a port, pick one explicitly for them
498     // based on the protocol defaults
499     if (url.scheme() == "http")
500     {
501         url.set_port_number(80);
502     }
503     if (url.scheme() == "https")
504     {
505         url.set_port_number(443);
506     }
507     if (url.scheme() == "snmp")
508     {
509         url.set_port_number(162);
510     }
511 }
512 
513 } // namespace utility
514 } // namespace crow
515 
516 namespace nlohmann
517 {
518 template <std::derived_from<boost::urls::url_view_base> URL>
519 struct adl_serializer<URL>
520 {
521     // NOLINTNEXTLINE(readability-identifier-naming)
to_jsonnlohmann::adl_serializer522     static void to_json(json& j, const URL& url)
523     {
524         j = url.buffer();
525     }
526 };
527 } // namespace nlohmann
528