xref: /openbmc/bmcweb/http/websocket.hpp (revision 1b8b02a4)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_request.hpp"
4 
5 #include <boost/asio/buffer.hpp>
6 #include <boost/beast/core/multi_buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
/// Kind of payload carried by a websocket message.
enum class MessageType
{
    Binary = 0, ///< Message is sent/received as a binary frame.
    Text = 1,   ///< Message is sent/received as a text frame.
};
26 
27 struct Connection : std::enable_shared_from_this<Connection>
28 {
29   public:
30     explicit Connection(const crow::Request& reqIn) : req(reqIn.req) {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendEx(MessageType type, std::string_view msg,
40                         std::function<void()>&& onDone) = 0;
41     virtual void sendText(std::string_view msg) = 0;
42     virtual void sendText(std::string&& msg) = 0;
43     virtual void close(std::string_view msg = "quit") = 0;
44     virtual void deferRead() = 0;
45     virtual void resumeRead() = 0;
46     virtual boost::asio::io_context& getIoContext() = 0;
47     virtual ~Connection() = default;
48     virtual boost::urls::url_view url() = 0;
49     boost::beast::http::request<boost::beast::http::string_body> req;
50 };
51 
52 template <typename Adaptor>
53 class ConnectionImpl : public Connection
54 {
55   public:
56     ConnectionImpl(
57         const crow::Request& reqIn, boost::urls::url_view urlViewIn,
58         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
59         std::function<void(Connection&, const std::string&, bool)>
60             messageHandlerIn,
61         std::function<void(crow::websocket::Connection&, std::string_view,
62                            crow::websocket::MessageType type,
63                            std::function<void()>&& whenComplete)>
64             messageExHandlerIn,
65         std::function<void(Connection&, const std::string&)> closeHandlerIn,
66         std::function<void(Connection&)> errorHandlerIn) :
67         Connection(reqIn),
68         uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088),
69         openHandler(std::move(openHandlerIn)),
70         messageHandler(std::move(messageHandlerIn)),
71         messageExHandler(std::move(messageExHandlerIn)),
72         closeHandler(std::move(closeHandlerIn)),
73         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
74     {
75         /* Turn on the timeouts on websocket stream to server role */
76         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
77             boost::beast::role_type::server));
78         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
79     }
80 
81     boost::asio::io_context& getIoContext() override
82     {
83         return static_cast<boost::asio::io_context&>(
84             ws.get_executor().context());
85     }
86 
87     void start()
88     {
89         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
90 
91         using bf = boost::beast::http::field;
92 
93         std::string_view protocol = req[bf::sec_websocket_protocol];
94 
95         ws.set_option(boost::beast::websocket::stream_base::decorator(
96             [session{session}, protocol{std::string(protocol)}](
97                 boost::beast::websocket::response_type& m) {
98 
99 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
100             if (session != nullptr)
101             {
102                 // use protocol for csrf checking
103                 if (session->cookieAuth &&
104                     !crow::utility::constantTimeStringCompare(
105                         protocol, session->csrfToken))
106                 {
107                     BMCWEB_LOG_ERROR("Websocket CSRF error");
108                     m.result(boost::beast::http::status::unauthorized);
109                     return;
110                 }
111             }
112 #endif
113             if (!protocol.empty())
114             {
115                 m.insert(bf::sec_websocket_protocol, protocol);
116             }
117 
118             m.insert(bf::strict_transport_security, "max-age=31536000; "
119                                                     "includeSubdomains; "
120                                                     "preload");
121             m.insert(bf::pragma, "no-cache");
122             m.insert(bf::cache_control, "no-Store,no-Cache");
123             m.insert("Content-Security-Policy", "default-src 'self'");
124             m.insert("X-XSS-Protection", "1; "
125                                          "mode=block");
126             m.insert("X-Content-Type-Options", "nosniff");
127         }));
128 
129         // Perform the websocket upgrade
130         ws.async_accept(req, [this, self(shared_from_this())](
131                                  const boost::system::error_code& ec) {
132             if (ec)
133             {
134                 BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
135                 return;
136             }
137             acceptDone();
138         });
139     }
140 
141     void sendBinary(std::string_view msg) override
142     {
143         ws.binary(true);
144         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
145                                                   boost::asio::buffer(msg)));
146         doWrite();
147     }
148 
149     void sendEx(MessageType type, std::string_view msg,
150                 std::function<void()>&& onDone) override
151     {
152         if (doingWrite)
153         {
154             BMCWEB_LOG_CRITICAL(
155                 "Cannot mix sendEx usage with sendBinary or sendText");
156             onDone();
157             return;
158         }
159         ws.binary(type == MessageType::Binary);
160 
161         ws.async_write(boost::asio::buffer(msg),
162                        [weak(weak_from_this()), onDone{std::move(onDone)}](
163                            const boost::beast::error_code& ec, size_t) {
164             std::shared_ptr<Connection> self = weak.lock();
165 
166             // Call the done handler regardless of whether we
167             // errored, but before we close things out
168             onDone();
169 
170             if (ec)
171             {
172                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
173                 self->close("write error");
174             }
175         });
176     }
177 
178     void sendBinary(std::string&& msg) override
179     {
180         ws.binary(true);
181         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
182                                                   boost::asio::buffer(msg)));
183         doWrite();
184     }
185 
186     void sendText(std::string_view msg) override
187     {
188         ws.text(true);
189         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
190                                                   boost::asio::buffer(msg)));
191         doWrite();
192     }
193 
194     void sendText(std::string&& msg) override
195     {
196         ws.text(true);
197         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
198                                                   boost::asio::buffer(msg)));
199         doWrite();
200     }
201 
202     void close(std::string_view msg) override
203     {
204         ws.async_close(
205             {boost::beast::websocket::close_code::normal, msg},
206             [self(shared_from_this())](const boost::system::error_code& ec) {
207             if (ec == boost::asio::error::operation_aborted)
208             {
209                 return;
210             }
211             if (ec)
212             {
213                 BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
214                 return;
215             }
216             });
217     }
218 
219     boost::urls::url_view url() override
220     {
221         return uri;
222     }
223 
224     void acceptDone()
225     {
226         BMCWEB_LOG_DEBUG("Websocket accepted connection");
227 
228         if (openHandler)
229         {
230             openHandler(*this);
231         }
232         doRead();
233     }
234 
235     void deferRead() override
236     {
237         readingDefered = true;
238 
239         // If we're not actively reading, we need to take ownership of
240         // ourselves for a small portion of time, do that, and clear when we
241         // resume.
242         selfOwned = shared_from_this();
243     }
244 
245     void resumeRead() override
246     {
247         readingDefered = false;
248         doRead();
249 
250         // No longer need to keep ourselves alive now that read is active.
251         selfOwned.reset();
252     }
253 
254     void doRead()
255     {
256         if (readingDefered)
257         {
258             return;
259         }
260         ws.async_read(inBuffer, [this, self(shared_from_this())](
261                                     const boost::beast::error_code& ec,
262                                     size_t bytesRead) {
263             if (ec)
264             {
265                 if (ec != boost::beast::websocket::error::closed)
266                 {
267                     BMCWEB_LOG_ERROR("doRead error {}", ec);
268                 }
269                 if (closeHandler)
270                 {
271                     std::string reason{ws.reason().reason.c_str()};
272                     closeHandler(*this, reason);
273                 }
274                 return;
275             }
276 
277             handleMessage(bytesRead);
278         });
279     }
280     void doWrite()
281     {
282         // If we're already doing a write, ignore the request, it will be picked
283         // up when the current write is complete
284         if (doingWrite)
285         {
286             return;
287         }
288 
289         if (outBuffer.size() == 0)
290         {
291             // Done for now
292             return;
293         }
294         doingWrite = true;
295         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
296                                              const boost::beast::error_code& ec,
297                                              size_t bytesSent) {
298             doingWrite = false;
299             outBuffer.consume(bytesSent);
300             if (ec == boost::beast::websocket::error::closed)
301             {
302                 // Do nothing here.  doRead handler will call the
303                 // closeHandler.
304                 close("Write error");
305                 return;
306             }
307             if (ec)
308             {
309                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
310                 return;
311             }
312             doWrite();
313         });
314     }
315 
316   private:
317     void handleMessage(size_t bytesRead)
318     {
319         if (messageExHandler)
320         {
321             // Note, because of the interactions with the read buffers,
322             // this message handler overrides the normal message handler
323             messageExHandler(*this, inString, MessageType::Binary,
324                              [this, self(shared_from_this()), bytesRead]() {
325                 if (self == nullptr)
326                 {
327                     return;
328                 }
329 
330                 inBuffer.consume(bytesRead);
331                 inString.clear();
332 
333                 doRead();
334             });
335             return;
336         }
337 
338         if (messageHandler)
339         {
340             messageHandler(*this, inString, ws.got_text());
341         }
342         inBuffer.consume(bytesRead);
343         inString.clear();
344         doRead();
345     }
346 
347     boost::urls::url uri;
348 
349     boost::beast::websocket::stream<Adaptor, false> ws;
350 
351     bool readingDefered = false;
352     std::string inString;
353     boost::asio::dynamic_string_buffer<std::string::value_type,
354                                        std::string::traits_type,
355                                        std::string::allocator_type>
356         inBuffer;
357 
358     boost::beast::multi_buffer outBuffer;
359     bool doingWrite = false;
360 
361     std::function<void(Connection&)> openHandler;
362     std::function<void(Connection&, const std::string&, bool)> messageHandler;
363     std::function<void(crow::websocket::Connection&, std::string_view,
364                        crow::websocket::MessageType type,
365                        std::function<void()>&& whenComplete)>
366         messageExHandler;
367     std::function<void(Connection&, const std::string&)> closeHandler;
368     std::function<void(Connection&)> errorHandler;
369     std::shared_ptr<persistent_data::UserSession> session;
370 
371     std::shared_ptr<Connection> selfOwned;
372 };
373 } // namespace websocket
374 } // namespace crow
375