xref: /openbmc/bmcweb/http/websocket.hpp (revision 4cee35e7)
1 #pragma once
#include "http_request.hpp"

#include <async_resp.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendText(std::string_view msg) = 0;
40     virtual void sendText(std::string&& msg) = 0;
41     virtual void close(std::string_view msg = "quit") = 0;
42     virtual boost::asio::io_context& getIoContext() = 0;
43     virtual ~Connection() = default;
44 
45     void userdata(void* u)
46     {
47         userdataPtr = u;
48     }
49     void* userdata()
50     {
51         return userdataPtr;
52     }
53 
54     const std::string& getUserName() const
55     {
56         return userName;
57     }
58 
59     boost::beast::http::request<boost::beast::http::string_body> req;
60     crow::Response res;
61 
62   private:
63     std::string userName{};
64     void* userdataPtr;
65 };
66 
67 template <typename Adaptor>
68 class ConnectionImpl : public Connection
69 {
70   public:
71     ConnectionImpl(
72         const crow::Request& reqIn, Adaptor adaptorIn,
73         std::function<void(Connection&)> openHandler,
74         std::function<void(Connection&, const std::string&, bool)>
75             messageHandler,
76         std::function<void(Connection&, const std::string&)> closeHandler,
77         std::function<void(Connection&)> errorHandler) :
78         Connection(reqIn, reqIn.session == nullptr ? std::string{}
79                                                    : reqIn.session->username),
80         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
81         openHandler(std::move(openHandler)),
82         messageHandler(std::move(messageHandler)),
83         closeHandler(std::move(closeHandler)),
84         errorHandler(std::move(errorHandler)), session(reqIn.session)
85     {
86         /* Turn on the timeouts on websocket stream to server role */
87         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
88             boost::beast::role_type::server));
89         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
90     }
91 
92     boost::asio::io_context& getIoContext() override
93     {
94         return static_cast<boost::asio::io_context&>(
95             ws.get_executor().context());
96     }
97 
98     void start()
99     {
100         BMCWEB_LOG_DEBUG << "starting connection " << this;
101 
102         using bf = boost::beast::http::field;
103 
104         std::string_view protocol = req[bf::sec_websocket_protocol];
105 
106         ws.set_option(boost::beast::websocket::stream_base::decorator(
107             [session{session}, protocol{std::string(protocol)}](
108                 boost::beast::websocket::response_type& m) {
109 
110 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
111                 if (session != nullptr)
112                 {
113                     // use protocol for csrf checking
114                     if (session->cookieAuth &&
115                         !crow::utility::constantTimeStringCompare(
116                             protocol, session->csrfToken))
117                     {
118                         BMCWEB_LOG_ERROR << "Websocket CSRF error";
119                         m.result(boost::beast::http::status::unauthorized);
120                         return;
121                     }
122                 }
123 #endif
124                 if (!protocol.empty())
125                 {
126                     m.insert(bf::sec_websocket_protocol, protocol);
127                 }
128 
129                 m.insert(bf::strict_transport_security, "max-age=31536000; "
130                                                         "includeSubdomains; "
131                                                         "preload");
132                 m.insert(bf::pragma, "no-cache");
133                 m.insert(bf::cache_control, "no-Store,no-Cache");
134                 m.insert("Content-Security-Policy", "default-src 'self'");
135                 m.insert("X-XSS-Protection", "1; "
136                                              "mode=block");
137                 m.insert("X-Content-Type-Options", "nosniff");
138             }));
139 
140         // Perform the websocket upgrade
141         ws.async_accept(req, [this, self(shared_from_this())](
142                                  boost::system::error_code ec) {
143             if (ec)
144             {
145                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
146                 return;
147             }
148             acceptDone();
149         });
150     }
151 
152     void sendBinary(const std::string_view msg) override
153     {
154         ws.binary(true);
155         outBuffer.emplace_back(msg);
156         doWrite();
157     }
158 
159     void sendBinary(std::string&& msg) override
160     {
161         ws.binary(true);
162         outBuffer.emplace_back(std::move(msg));
163         doWrite();
164     }
165 
166     void sendText(const std::string_view msg) override
167     {
168         ws.text(true);
169         outBuffer.emplace_back(msg);
170         doWrite();
171     }
172 
173     void sendText(std::string&& msg) override
174     {
175         ws.text(true);
176         outBuffer.emplace_back(std::move(msg));
177         doWrite();
178     }
179 
180     void close(const std::string_view msg) override
181     {
182         ws.async_close(
183             {boost::beast::websocket::close_code::normal, msg},
184             [self(shared_from_this())](boost::system::error_code ec) {
185                 if (ec == boost::asio::error::operation_aborted)
186                 {
187                     return;
188                 }
189                 if (ec)
190                 {
191                     BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
192                     return;
193                 }
194             });
195     }
196 
197     void acceptDone()
198     {
199         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
200 
201         doRead();
202 
203         if (openHandler)
204         {
205             openHandler(*this);
206         }
207     }
208 
209     void doRead()
210     {
211         ws.async_read(inBuffer,
212                       [this, self(shared_from_this())](
213                           boost::beast::error_code ec, std::size_t bytesRead) {
214                           if (ec)
215                           {
216                               if (ec != boost::beast::websocket::error::closed)
217                               {
218                                   BMCWEB_LOG_ERROR << "doRead error " << ec;
219                               }
220                               if (closeHandler)
221                               {
222                                   std::string_view reason = ws.reason().reason;
223                                   closeHandler(*this, std::string(reason));
224                               }
225                               return;
226                           }
227                           if (messageHandler)
228                           {
229                               messageHandler(*this, inString, ws.got_text());
230                           }
231                           inBuffer.consume(bytesRead);
232                           inString.clear();
233                           doRead();
234                       });
235     }
236 
237     void doWrite()
238     {
239         // If we're already doing a write, ignore the request, it will be picked
240         // up when the current write is complete
241         if (doingWrite)
242         {
243             return;
244         }
245 
246         if (outBuffer.empty())
247         {
248             // Done for now
249             return;
250         }
251         doingWrite = true;
252         ws.async_write(boost::asio::buffer(outBuffer.front()),
253                        [this, self(shared_from_this())](
254                            boost::beast::error_code ec, std::size_t) {
255                            doingWrite = false;
256                            outBuffer.erase(outBuffer.begin());
257                            if (ec == boost::beast::websocket::error::closed)
258                            {
259                                // Do nothing here.  doRead handler will call the
260                                // closeHandler.
261                                close("Write error");
262                                return;
263                            }
264                            if (ec)
265                            {
266                                BMCWEB_LOG_ERROR << "Error in ws.async_write "
267                                                 << ec;
268                                return;
269                            }
270                            doWrite();
271                        });
272     }
273 
274   private:
275     boost::beast::websocket::stream<Adaptor, false> ws;
276 
277     std::string inString;
278     boost::asio::dynamic_string_buffer<std::string::value_type,
279                                        std::string::traits_type,
280                                        std::string::allocator_type>
281         inBuffer;
282     std::vector<std::string> outBuffer;
283     bool doingWrite = false;
284 
285     std::function<void(Connection&)> openHandler;
286     std::function<void(Connection&, const std::string&, bool)> messageHandler;
287     std::function<void(Connection&, const std::string&)> closeHandler;
288     std::function<void(Connection&)> errorHandler;
289     std::shared_ptr<persistent_data::UserSession> session;
290 };
291 } // namespace websocket
292 } // namespace crow
293