xref: /openbmc/bmcweb/http/websocket.hpp (revision f263e09c)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_request.hpp"
4 
5 #include <boost/asio/buffer.hpp>
6 #include <boost/beast/core/multi_buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
/// Payload kind of a websocket message, as passed to
/// Connection::sendEx and to the extended message handler.
enum class MessageType
{
    Binary = 0, ///< Message is sent/received as a binary frame.
    Text = 1,   ///< Message is sent/received as a text frame.
};
26 
27 struct Connection : std::enable_shared_from_this<Connection>
28 {
29   public:
30     explicit Connection(const crow::Request& reqIn) : req(reqIn.req) {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendEx(MessageType type, std::string_view msg,
40                         std::function<void()>&& onDone) = 0;
41     virtual void sendText(std::string_view msg) = 0;
42     virtual void sendText(std::string&& msg) = 0;
43     virtual void close(std::string_view msg = "quit") = 0;
44     virtual void deferRead() = 0;
45     virtual void resumeRead() = 0;
46     virtual boost::asio::io_context& getIoContext() = 0;
47     virtual ~Connection() = default;
48 
49     boost::beast::http::request<boost::beast::http::string_body> req;
50 };
51 
52 template <typename Adaptor>
53 class ConnectionImpl : public Connection
54 {
55   public:
56     ConnectionImpl(
57         const crow::Request& reqIn, Adaptor adaptorIn,
58         std::function<void(Connection&)> openHandlerIn,
59         std::function<void(Connection&, const std::string&, bool)>
60             messageHandlerIn,
61         std::function<void(crow::websocket::Connection&, std::string_view,
62                            crow::websocket::MessageType type,
63                            std::function<void()>&& whenComplete)>
64             messageExHandlerIn,
65         std::function<void(Connection&, const std::string&)> closeHandlerIn,
66         std::function<void(Connection&)> errorHandlerIn) :
67         Connection(reqIn),
68         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
69         openHandler(std::move(openHandlerIn)),
70         messageHandler(std::move(messageHandlerIn)),
71         messageExHandler(std::move(messageExHandlerIn)),
72         closeHandler(std::move(closeHandlerIn)),
73         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
74     {
75         /* Turn on the timeouts on websocket stream to server role */
76         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
77             boost::beast::role_type::server));
78         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
79     }
80 
81     boost::asio::io_context& getIoContext() override
82     {
83         return static_cast<boost::asio::io_context&>(
84             ws.get_executor().context());
85     }
86 
87     void start()
88     {
89         BMCWEB_LOG_DEBUG << "starting connection " << this;
90 
91         using bf = boost::beast::http::field;
92 
93         std::string_view protocol = req[bf::sec_websocket_protocol];
94 
95         ws.set_option(boost::beast::websocket::stream_base::decorator(
96             [session{session}, protocol{std::string(protocol)}](
97                 boost::beast::websocket::response_type& m) {
98 
99 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
100             if (session != nullptr)
101             {
102                 // use protocol for csrf checking
103                 if (!crow::utility::constantTimeStringCompare(
104                         protocol, session->csrfToken))
105                 {
106                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
107                     m.result(boost::beast::http::status::unauthorized);
108                     return;
109                 }
110             }
111 #endif
112             if (!protocol.empty())
113             {
114                 m.insert(bf::sec_websocket_protocol, protocol);
115             }
116 
117             m.insert(bf::strict_transport_security, "max-age=31536000; "
118                                                     "includeSubdomains; "
119                                                     "preload");
120             m.insert(bf::pragma, "no-cache");
121             m.insert(bf::cache_control, "no-Store,no-Cache");
122             m.insert("Content-Security-Policy", "default-src 'self'");
123             m.insert("X-XSS-Protection", "1; "
124                                          "mode=block");
125             m.insert("X-Content-Type-Options", "nosniff");
126         }));
127 
128         // Perform the websocket upgrade
129         ws.async_accept(req, [this, self(shared_from_this())](
130                                  const boost::system::error_code& ec) {
131             if (ec)
132             {
133                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
134                 return;
135             }
136             acceptDone();
137         });
138     }
139 
140     void sendBinary(std::string_view msg) override
141     {
142         ws.binary(true);
143         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
144                                                   boost::asio::buffer(msg)));
145         doWrite();
146     }
147 
148     void sendEx(MessageType type, std::string_view msg,
149                 std::function<void()>&& onDone) override
150     {
151         if (doingWrite)
152         {
153             BMCWEB_LOG_CRITICAL
154                 << "Cannot mix sendEx usage with sendBinary or sendText";
155             onDone();
156             return;
157         }
158         ws.binary(type == MessageType::Binary);
159 
160         ws.async_write(boost::asio::buffer(msg),
161                        [weak(weak_from_this()), onDone{std::move(onDone)}](
162                            const boost::beast::error_code& ec, size_t) {
163             std::shared_ptr<Connection> self = weak.lock();
164 
165             // Call the done handler regardless of whether we
166             // errored, but before we close things out
167             onDone();
168 
169             if (ec)
170             {
171                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
172                 self->close("write error");
173             }
174         });
175     }
176 
177     void sendBinary(std::string&& msg) override
178     {
179         ws.binary(true);
180         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
181                                                   boost::asio::buffer(msg)));
182         doWrite();
183     }
184 
185     void sendText(std::string_view msg) override
186     {
187         ws.text(true);
188         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
189                                                   boost::asio::buffer(msg)));
190         doWrite();
191     }
192 
193     void sendText(std::string&& msg) override
194     {
195         ws.text(true);
196         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
197                                                   boost::asio::buffer(msg)));
198         doWrite();
199     }
200 
201     void close(std::string_view msg) override
202     {
203         ws.async_close(
204             {boost::beast::websocket::close_code::normal, msg},
205             [self(shared_from_this())](const boost::system::error_code& ec) {
206             if (ec == boost::asio::error::operation_aborted)
207             {
208                 return;
209             }
210             if (ec)
211             {
212                 BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
213                 return;
214             }
215             });
216     }
217 
218     void acceptDone()
219     {
220         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
221 
222         if (openHandler)
223         {
224             openHandler(*this);
225         }
226         doRead();
227     }
228 
229     void deferRead() override
230     {
231         readingDefered = true;
232 
233         // If we're not actively reading, we need to take ownership of
234         // ourselves for a small portion of time, do that, and clear when we
235         // resume.
236         selfOwned = shared_from_this();
237     }
238 
239     void resumeRead() override
240     {
241         readingDefered = false;
242         doRead();
243 
244         // No longer need to keep ourselves alive now that read is active.
245         selfOwned.reset();
246     }
247 
248     void doRead()
249     {
250         if (readingDefered)
251         {
252             return;
253         }
254         ws.async_read(inBuffer, [this, self(shared_from_this())](
255                                     const boost::beast::error_code& ec,
256                                     size_t bytesRead) {
257             if (ec)
258             {
259                 if (ec != boost::beast::websocket::error::closed)
260                 {
261                     BMCWEB_LOG_ERROR << "doRead error " << ec;
262                 }
263                 if (closeHandler)
264                 {
265                     std::string reason{ws.reason().reason.c_str()};
266                     closeHandler(*this, reason);
267                 }
268                 return;
269             }
270 
271             handleMessage(bytesRead);
272         });
273     }
274     void doWrite()
275     {
276         // If we're already doing a write, ignore the request, it will be picked
277         // up when the current write is complete
278         if (doingWrite)
279         {
280             return;
281         }
282 
283         if (outBuffer.size() == 0)
284         {
285             // Done for now
286             return;
287         }
288         doingWrite = true;
289         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
290                                              const boost::beast::error_code& ec,
291                                              size_t bytesSent) {
292             doingWrite = false;
293             outBuffer.consume(bytesSent);
294             if (ec == boost::beast::websocket::error::closed)
295             {
296                 // Do nothing here.  doRead handler will call the
297                 // closeHandler.
298                 close("Write error");
299                 return;
300             }
301             if (ec)
302             {
303                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
304                 return;
305             }
306             doWrite();
307         });
308     }
309 
310   private:
311     void handleMessage(size_t bytesRead)
312     {
313         if (messageExHandler)
314         {
315             // Note, because of the interactions with the read buffers,
316             // this message handler overrides the normal message handler
317             messageExHandler(*this, inString, MessageType::Binary,
318                              [this, self(shared_from_this()), bytesRead]() {
319                 if (self == nullptr)
320                 {
321                     return;
322                 }
323 
324                 inBuffer.consume(bytesRead);
325                 inString.clear();
326 
327                 doRead();
328             });
329             return;
330         }
331 
332         if (messageHandler)
333         {
334             messageHandler(*this, inString, ws.got_text());
335         }
336         inBuffer.consume(bytesRead);
337         inString.clear();
338         doRead();
339     }
340 
341     boost::beast::websocket::stream<Adaptor, false> ws;
342 
343     bool readingDefered = false;
344     std::string inString;
345     boost::asio::dynamic_string_buffer<std::string::value_type,
346                                        std::string::traits_type,
347                                        std::string::allocator_type>
348         inBuffer;
349 
350     boost::beast::multi_buffer outBuffer;
351     bool doingWrite = false;
352 
353     std::function<void(Connection&)> openHandler;
354     std::function<void(Connection&, const std::string&, bool)> messageHandler;
355     std::function<void(crow::websocket::Connection&, std::string_view,
356                        crow::websocket::MessageType type,
357                        std::function<void()>&& whenComplete)>
358         messageExHandler;
359     std::function<void(Connection&, const std::string&)> closeHandler;
360     std::function<void(Connection&)> errorHandler;
361     std::shared_ptr<persistent_data::UserSession> session;
362 
363     std::shared_ptr<Connection> selfOwned;
364 };
365 } // namespace websocket
366 } // namespace crow
367