xref: /openbmc/bmcweb/http/websocket.hpp (revision 2ae81db9)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_request.hpp"
4 
5 #include <boost/asio/buffer.hpp>
6 #include <boost/beast/core/multi_buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
/// Selects whether a websocket message is sent/received as a binary or a
/// text message (see ConnectionImpl::sendEx, which maps this onto
/// ws.binary()).
enum class MessageType
{
    Binary = 0, // binary message
    Text = 1    // text message
};
26 
27 struct Connection : std::enable_shared_from_this<Connection>
28 {
29   public:
30     Connection() = default;
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendEx(MessageType type, std::string_view msg,
40                         std::function<void()>&& onDone) = 0;
41     virtual void sendText(std::string_view msg) = 0;
42     virtual void sendText(std::string&& msg) = 0;
43     virtual void close(std::string_view msg = "quit") = 0;
44     virtual void deferRead() = 0;
45     virtual void resumeRead() = 0;
46     virtual boost::asio::io_context& getIoContext() = 0;
47     virtual ~Connection() = default;
48     virtual boost::urls::url_view url() = 0;
49 };
50 
51 template <typename Adaptor>
52 class ConnectionImpl : public Connection
53 {
54     using self_t = ConnectionImpl<Adaptor>;
55 
56   public:
57     ConnectionImpl(
58         const boost::urls::url_view& urlViewIn,
59         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
60         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
61         std::function<void(Connection&, const std::string&, bool)>
62             messageHandlerIn,
63         std::function<void(crow::websocket::Connection&, std::string_view,
64                            crow::websocket::MessageType type,
65                            std::function<void()>&& whenComplete)>
66             messageExHandlerIn,
67         std::function<void(Connection&, const std::string&)> closeHandlerIn,
68         std::function<void(Connection&)> errorHandlerIn) :
69         uri(urlViewIn),
70         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
71         openHandler(std::move(openHandlerIn)),
72         messageHandler(std::move(messageHandlerIn)),
73         messageExHandler(std::move(messageExHandlerIn)),
74         closeHandler(std::move(closeHandlerIn)),
75         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
76     {
77         /* Turn on the timeouts on websocket stream to server role */
78         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
79             boost::beast::role_type::server));
80         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
81     }
82 
83     boost::asio::io_context& getIoContext() override
84     {
85         return static_cast<boost::asio::io_context&>(
86             ws.get_executor().context());
87     }
88 
89     void start(const crow::Request& req)
90     {
91         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
92 
93         using bf = boost::beast::http::field;
94         std::string protocolHeader = req.req[bf::sec_websocket_protocol];
95 
96         ws.set_option(boost::beast::websocket::stream_base::decorator(
97             [session{session},
98              protocolHeader](boost::beast::websocket::response_type& m) {
99 
100 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
101             if (session != nullptr)
102             {
103                 // use protocol for csrf checking
104                 if (session->cookieAuth &&
105                     !crow::utility::constantTimeStringCompare(
106                         protocolHeader, session->csrfToken))
107                 {
108                     BMCWEB_LOG_ERROR("Websocket CSRF error");
109                     m.result(boost::beast::http::status::unauthorized);
110                     return;
111                 }
112             }
113 #endif
114             if (!protocolHeader.empty())
115             {
116                 m.insert(bf::sec_websocket_protocol, protocolHeader);
117             }
118 
119             m.insert(bf::strict_transport_security, "max-age=31536000; "
120                                                     "includeSubdomains; "
121                                                     "preload");
122             m.insert(bf::pragma, "no-cache");
123             m.insert(bf::cache_control, "no-Store,no-Cache");
124             m.insert("Content-Security-Policy", "default-src 'self'");
125             m.insert("X-XSS-Protection", "1; "
126                                          "mode=block");
127             m.insert("X-Content-Type-Options", "nosniff");
128         }));
129 
130         // Make a pointer to keep the req alive while we accept it.
131         using Body =
132             boost::beast::http::request<boost::beast::http::string_body>;
133         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
134         Body* ptr = mobile.get();
135         // Perform the websocket upgrade
136         ws.async_accept(*ptr,
137                         std::bind_front(&self_t::acceptDone, this,
138                                         shared_from_this(), std::move(mobile)));
139     }
140 
141     void sendBinary(std::string_view msg) override
142     {
143         ws.binary(true);
144         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
145                                                   boost::asio::buffer(msg)));
146         doWrite();
147     }
148 
149     void sendEx(MessageType type, std::string_view msg,
150                 std::function<void()>&& onDone) override
151     {
152         if (doingWrite)
153         {
154             BMCWEB_LOG_CRITICAL(
155                 "Cannot mix sendEx usage with sendBinary or sendText");
156             onDone();
157             return;
158         }
159         ws.binary(type == MessageType::Binary);
160 
161         ws.async_write(boost::asio::buffer(msg),
162                        [weak(weak_from_this()), onDone{std::move(onDone)}](
163                            const boost::beast::error_code& ec, size_t) {
164             std::shared_ptr<Connection> self = weak.lock();
165             if (!self)
166             {
167                 BMCWEB_LOG_ERROR("Connection went away");
168                 return;
169             }
170 
171             // Call the done handler regardless of whether we
172             // errored, but before we close things out
173             onDone();
174 
175             if (ec)
176             {
177                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
178                 self->close("write error");
179             }
180         });
181     }
182 
183     void sendBinary(std::string&& msg) override
184     {
185         ws.binary(true);
186         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
187                                                   boost::asio::buffer(msg)));
188         doWrite();
189     }
190 
191     void sendText(std::string_view msg) override
192     {
193         ws.text(true);
194         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
195                                                   boost::asio::buffer(msg)));
196         doWrite();
197     }
198 
199     void sendText(std::string&& msg) override
200     {
201         ws.text(true);
202         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
203                                                   boost::asio::buffer(msg)));
204         doWrite();
205     }
206 
207     void close(std::string_view msg) override
208     {
209         ws.async_close(
210             {boost::beast::websocket::close_code::normal, msg},
211             [self(shared_from_this())](const boost::system::error_code& ec) {
212             if (ec == boost::asio::error::operation_aborted)
213             {
214                 return;
215             }
216             if (ec)
217             {
218                 BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
219                 return;
220             }
221         });
222     }
223 
224     boost::urls::url_view url() override
225     {
226         return uri;
227     }
228 
229     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
230                     const std::unique_ptr<boost::beast::http::request<
231                         boost::beast::http::string_body>>& /*req*/,
232                     const boost::system::error_code& ec)
233     {
234         if (ec)
235         {
236             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
237             return;
238         }
239         BMCWEB_LOG_DEBUG("Websocket accepted connection");
240 
241         if (openHandler)
242         {
243             openHandler(*this);
244         }
245         doRead();
246     }
247 
248     void deferRead() override
249     {
250         readingDefered = true;
251 
252         // If we're not actively reading, we need to take ownership of
253         // ourselves for a small portion of time, do that, and clear when we
254         // resume.
255         selfOwned = shared_from_this();
256     }
257 
258     void resumeRead() override
259     {
260         readingDefered = false;
261         doRead();
262 
263         // No longer need to keep ourselves alive now that read is active.
264         selfOwned.reset();
265     }
266 
267     void doRead()
268     {
269         if (readingDefered)
270         {
271             return;
272         }
273         ws.async_read(inBuffer, [this, self(shared_from_this())](
274                                     const boost::beast::error_code& ec,
275                                     size_t bytesRead) {
276             if (ec)
277             {
278                 if (ec != boost::beast::websocket::error::closed)
279                 {
280                     BMCWEB_LOG_ERROR("doRead error {}", ec);
281                 }
282                 if (closeHandler)
283                 {
284                     std::string reason{ws.reason().reason.c_str()};
285                     closeHandler(*this, reason);
286                 }
287                 return;
288             }
289 
290             handleMessage(bytesRead);
291         });
292     }
293     void doWrite()
294     {
295         // If we're already doing a write, ignore the request, it will be picked
296         // up when the current write is complete
297         if (doingWrite)
298         {
299             return;
300         }
301 
302         if (outBuffer.size() == 0)
303         {
304             // Done for now
305             return;
306         }
307         doingWrite = true;
308         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
309                                              const boost::beast::error_code& ec,
310                                              size_t bytesSent) {
311             doingWrite = false;
312             outBuffer.consume(bytesSent);
313             if (ec == boost::beast::websocket::error::closed)
314             {
315                 // Do nothing here.  doRead handler will call the
316                 // closeHandler.
317                 close("Write error");
318                 return;
319             }
320             if (ec)
321             {
322                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
323                 return;
324             }
325             doWrite();
326         });
327     }
328 
329   private:
330     void handleMessage(size_t bytesRead)
331     {
332         if (messageExHandler)
333         {
334             // Note, because of the interactions with the read buffers,
335             // this message handler overrides the normal message handler
336             messageExHandler(*this, inString, MessageType::Binary,
337                              [this, self(shared_from_this()), bytesRead]() {
338                 if (self == nullptr)
339                 {
340                     return;
341                 }
342 
343                 inBuffer.consume(bytesRead);
344                 inString.clear();
345 
346                 doRead();
347             });
348             return;
349         }
350 
351         if (messageHandler)
352         {
353             messageHandler(*this, inString, ws.got_text());
354         }
355         inBuffer.consume(bytesRead);
356         inString.clear();
357         doRead();
358     }
359 
360     boost::urls::url uri;
361 
362     boost::beast::websocket::stream<Adaptor, false> ws;
363 
364     bool readingDefered = false;
365     std::string inString;
366     boost::asio::dynamic_string_buffer<std::string::value_type,
367                                        std::string::traits_type,
368                                        std::string::allocator_type>
369         inBuffer;
370 
371     boost::beast::multi_buffer outBuffer;
372     bool doingWrite = false;
373 
374     std::function<void(Connection&)> openHandler;
375     std::function<void(Connection&, const std::string&, bool)> messageHandler;
376     std::function<void(crow::websocket::Connection&, std::string_view,
377                        crow::websocket::MessageType type,
378                        std::function<void()>&& whenComplete)>
379         messageExHandler;
380     std::function<void(Connection&, const std::string&)> closeHandler;
381     std::function<void(Connection&)> errorHandler;
382     std::shared_ptr<persistent_data::UserSession> session;
383 
384     std::shared_ptr<Connection> selfOwned;
385 };
386 } // namespace websocket
387 } // namespace crow
388