xref: /openbmc/bmcweb/http/websocket.hpp (revision 052bcbf48802da1fa9583c8c0990378304e29903)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_request.hpp"
4 
5 #include <boost/asio/buffer.hpp>
6 #include <boost/beast/core/multi_buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
/// Kind of payload carried by a websocket message: raw bytes or text.
enum class MessageType
{
    Binary = 0, ///< Payload is sent/received as a binary message.
    Text = 1,   ///< Payload is sent/received as a text message.
};
26 
27 struct Connection : std::enable_shared_from_this<Connection>
28 {
29   public:
30     explicit Connection(const crow::Request& reqIn) : req(reqIn.req) {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendEx(MessageType type, std::string_view msg,
40                         std::function<void()>&& onDone) = 0;
41     virtual void sendText(std::string_view msg) = 0;
42     virtual void sendText(std::string&& msg) = 0;
43     virtual void close(std::string_view msg = "quit") = 0;
44     virtual void deferRead() = 0;
45     virtual void resumeRead() = 0;
46     virtual boost::asio::io_context& getIoContext() = 0;
47     virtual ~Connection() = default;
48     virtual boost::urls::url_view url() = 0;
49     boost::beast::http::request<boost::beast::http::string_body> req;
50 };
51 
52 template <typename Adaptor>
53 class ConnectionImpl : public Connection
54 {
55   public:
56     ConnectionImpl(
57         const crow::Request& reqIn, boost::urls::url_view urlViewIn,
58         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
59         std::function<void(Connection&, const std::string&, bool)>
60             messageHandlerIn,
61         std::function<void(crow::websocket::Connection&, std::string_view,
62                            crow::websocket::MessageType type,
63                            std::function<void()>&& whenComplete)>
64             messageExHandlerIn,
65         std::function<void(Connection&, const std::string&)> closeHandlerIn,
66         std::function<void(Connection&)> errorHandlerIn) :
67         Connection(reqIn),
68         uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088),
69         openHandler(std::move(openHandlerIn)),
70         messageHandler(std::move(messageHandlerIn)),
71         messageExHandler(std::move(messageExHandlerIn)),
72         closeHandler(std::move(closeHandlerIn)),
73         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
74     {
75         /* Turn on the timeouts on websocket stream to server role */
76         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
77             boost::beast::role_type::server));
78         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
79     }
80 
81     boost::asio::io_context& getIoContext() override
82     {
83         return static_cast<boost::asio::io_context&>(
84             ws.get_executor().context());
85     }
86 
87     void start()
88     {
89         BMCWEB_LOG_DEBUG << "starting connection " << this;
90 
91         using bf = boost::beast::http::field;
92 
93         std::string_view protocol = req[bf::sec_websocket_protocol];
94 
95         ws.set_option(boost::beast::websocket::stream_base::decorator(
96             [session{session}, protocol{std::string(protocol)}](
97                 boost::beast::websocket::response_type& m) {
98 
99 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
100             if (session != nullptr)
101             {
102                 // use protocol for csrf checking
103                 if (!crow::utility::constantTimeStringCompare(
104                         protocol, session->csrfToken))
105                 {
106                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
107                     m.result(boost::beast::http::status::unauthorized);
108                     return;
109                 }
110             }
111 #endif
112             if (!protocol.empty())
113             {
114                 m.insert(bf::sec_websocket_protocol, protocol);
115             }
116 
117             m.insert(bf::strict_transport_security, "max-age=31536000; "
118                                                     "includeSubdomains; "
119                                                     "preload");
120             m.insert(bf::pragma, "no-cache");
121             m.insert(bf::cache_control, "no-Store,no-Cache");
122             m.insert("Content-Security-Policy", "default-src 'self'");
123             m.insert("X-XSS-Protection", "1; "
124                                          "mode=block");
125             m.insert("X-Content-Type-Options", "nosniff");
126         }));
127 
128         // Perform the websocket upgrade
129         ws.async_accept(req, [this, self(shared_from_this())](
130                                  const boost::system::error_code& ec) {
131             if (ec)
132             {
133                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
134                 return;
135             }
136             acceptDone();
137         });
138     }
139 
140     void sendBinary(std::string_view msg) override
141     {
142         ws.binary(true);
143         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
144                                                   boost::asio::buffer(msg)));
145         doWrite();
146     }
147 
148     void sendEx(MessageType type, std::string_view msg,
149                 std::function<void()>&& onDone) override
150     {
151         if (doingWrite)
152         {
153             BMCWEB_LOG_CRITICAL
154                 << "Cannot mix sendEx usage with sendBinary or sendText";
155             onDone();
156             return;
157         }
158         ws.binary(type == MessageType::Binary);
159 
160         ws.async_write(boost::asio::buffer(msg),
161                        [weak(weak_from_this()), onDone{std::move(onDone)}](
162                            const boost::beast::error_code& ec, size_t) {
163             std::shared_ptr<Connection> self = weak.lock();
164 
165             // Call the done handler regardless of whether we
166             // errored, but before we close things out
167             onDone();
168 
169             if (ec)
170             {
171                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
172                 self->close("write error");
173             }
174         });
175     }
176 
177     void sendBinary(std::string&& msg) override
178     {
179         ws.binary(true);
180         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
181                                                   boost::asio::buffer(msg)));
182         doWrite();
183     }
184 
185     void sendText(std::string_view msg) override
186     {
187         ws.text(true);
188         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
189                                                   boost::asio::buffer(msg)));
190         doWrite();
191     }
192 
193     void sendText(std::string&& msg) override
194     {
195         ws.text(true);
196         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
197                                                   boost::asio::buffer(msg)));
198         doWrite();
199     }
200 
201     void close(std::string_view msg) override
202     {
203         ws.async_close(
204             {boost::beast::websocket::close_code::normal, msg},
205             [self(shared_from_this())](const boost::system::error_code& ec) {
206             if (ec == boost::asio::error::operation_aborted)
207             {
208                 return;
209             }
210             if (ec)
211             {
212                 BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
213                 return;
214             }
215             });
216     }
217 
218     boost::urls::url_view url() override
219     {
220         return uri;
221     }
222 
223     void acceptDone()
224     {
225         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
226 
227         if (openHandler)
228         {
229             openHandler(*this);
230         }
231         doRead();
232     }
233 
234     void deferRead() override
235     {
236         readingDefered = true;
237 
238         // If we're not actively reading, we need to take ownership of
239         // ourselves for a small portion of time, do that, and clear when we
240         // resume.
241         selfOwned = shared_from_this();
242     }
243 
244     void resumeRead() override
245     {
246         readingDefered = false;
247         doRead();
248 
249         // No longer need to keep ourselves alive now that read is active.
250         selfOwned.reset();
251     }
252 
253     void doRead()
254     {
255         if (readingDefered)
256         {
257             return;
258         }
259         ws.async_read(inBuffer, [this, self(shared_from_this())](
260                                     const boost::beast::error_code& ec,
261                                     size_t bytesRead) {
262             if (ec)
263             {
264                 if (ec != boost::beast::websocket::error::closed)
265                 {
266                     BMCWEB_LOG_ERROR << "doRead error " << ec;
267                 }
268                 if (closeHandler)
269                 {
270                     std::string reason{ws.reason().reason.c_str()};
271                     closeHandler(*this, reason);
272                 }
273                 return;
274             }
275 
276             handleMessage(bytesRead);
277         });
278     }
279     void doWrite()
280     {
281         // If we're already doing a write, ignore the request, it will be picked
282         // up when the current write is complete
283         if (doingWrite)
284         {
285             return;
286         }
287 
288         if (outBuffer.size() == 0)
289         {
290             // Done for now
291             return;
292         }
293         doingWrite = true;
294         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
295                                              const boost::beast::error_code& ec,
296                                              size_t bytesSent) {
297             doingWrite = false;
298             outBuffer.consume(bytesSent);
299             if (ec == boost::beast::websocket::error::closed)
300             {
301                 // Do nothing here.  doRead handler will call the
302                 // closeHandler.
303                 close("Write error");
304                 return;
305             }
306             if (ec)
307             {
308                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
309                 return;
310             }
311             doWrite();
312         });
313     }
314 
315   private:
316     void handleMessage(size_t bytesRead)
317     {
318         if (messageExHandler)
319         {
320             // Note, because of the interactions with the read buffers,
321             // this message handler overrides the normal message handler
322             messageExHandler(*this, inString, MessageType::Binary,
323                              [this, self(shared_from_this()), bytesRead]() {
324                 if (self == nullptr)
325                 {
326                     return;
327                 }
328 
329                 inBuffer.consume(bytesRead);
330                 inString.clear();
331 
332                 doRead();
333             });
334             return;
335         }
336 
337         if (messageHandler)
338         {
339             messageHandler(*this, inString, ws.got_text());
340         }
341         inBuffer.consume(bytesRead);
342         inString.clear();
343         doRead();
344     }
345 
346     boost::urls::url uri;
347 
348     boost::beast::websocket::stream<Adaptor, false> ws;
349 
350     bool readingDefered = false;
351     std::string inString;
352     boost::asio::dynamic_string_buffer<std::string::value_type,
353                                        std::string::traits_type,
354                                        std::string::allocator_type>
355         inBuffer;
356 
357     boost::beast::multi_buffer outBuffer;
358     bool doingWrite = false;
359 
360     std::function<void(Connection&)> openHandler;
361     std::function<void(Connection&, const std::string&, bool)> messageHandler;
362     std::function<void(crow::websocket::Connection&, std::string_view,
363                        crow::websocket::MessageType type,
364                        std::function<void()>&& whenComplete)>
365         messageExHandler;
366     std::function<void(Connection&, const std::string&)> closeHandler;
367     std::function<void(Connection&)> errorHandler;
368     std::shared_ptr<persistent_data::UserSession> session;
369 
370     std::shared_ptr<Connection> selfOwned;
371 };
372 } // namespace websocket
373 } // namespace crow
374