xref: /openbmc/bmcweb/http/websocket.hpp (revision 1873a04f43bb414408d1da9a5a775c05474603d2)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_body.hpp"
4 #include "http_request.hpp"
5 
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/core/multi_buffer.hpp>
8 #include <boost/beast/websocket.hpp>
9 
10 #include <array>
11 #include <functional>
12 
13 #ifdef BMCWEB_ENABLE_SSL
14 #include <boost/beast/websocket/ssl.hpp>
15 #endif
16 
17 namespace crow
18 {
19 namespace websocket
20 {
21 
/**
 * Kind of payload carried by a websocket message: binary data or text.
 */
enum class MessageType
{
    Binary = 0,
    Text = 1,
};
27 
28 struct Connection : std::enable_shared_from_this<Connection>
29 {
30   public:
31     Connection() = default;
32 
33     Connection(const Connection&) = delete;
34     Connection(Connection&&) = delete;
35     Connection& operator=(const Connection&) = delete;
36     Connection& operator=(const Connection&&) = delete;
37 
38     virtual void sendBinary(std::string_view msg) = 0;
39     virtual void sendBinary(std::string&& msg) = 0;
40     virtual void sendEx(MessageType type, std::string_view msg,
41                         std::function<void()>&& onDone) = 0;
42     virtual void sendText(std::string_view msg) = 0;
43     virtual void sendText(std::string&& msg) = 0;
44     virtual void close(std::string_view msg = "quit") = 0;
45     virtual void deferRead() = 0;
46     virtual void resumeRead() = 0;
47     virtual boost::asio::io_context& getIoContext() = 0;
48     virtual ~Connection() = default;
49     virtual boost::urls::url_view url() = 0;
50 };
51 
52 template <typename Adaptor>
53 class ConnectionImpl : public Connection
54 {
55     using self_t = ConnectionImpl<Adaptor>;
56 
57   public:
58     ConnectionImpl(
59         const boost::urls::url_view& urlViewIn,
60         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
61         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
62         std::function<void(Connection&, const std::string&, bool)>
63             messageHandlerIn,
64         std::function<void(crow::websocket::Connection&, std::string_view,
65                            crow::websocket::MessageType type,
66                            std::function<void()>&& whenComplete)>
67             messageExHandlerIn,
68         std::function<void(Connection&, const std::string&)> closeHandlerIn,
69         std::function<void(Connection&)> errorHandlerIn) :
70         uri(urlViewIn),
71         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
72         openHandler(std::move(openHandlerIn)),
73         messageHandler(std::move(messageHandlerIn)),
74         messageExHandler(std::move(messageExHandlerIn)),
75         closeHandler(std::move(closeHandlerIn)),
76         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
77     {
78         /* Turn on the timeouts on websocket stream to server role */
79         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
80             boost::beast::role_type::server));
81         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
82     }
83 
84     boost::asio::io_context& getIoContext() override
85     {
86         return static_cast<boost::asio::io_context&>(
87             ws.get_executor().context());
88     }
89 
90     void start(const crow::Request& req)
91     {
92         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
93 
94         using bf = boost::beast::http::field;
95         std::string protocolHeader{
96             req.getHeaderValue(bf::sec_websocket_protocol)};
97 
98         ws.set_option(boost::beast::websocket::stream_base::decorator(
99             [session{session},
100              protocolHeader](boost::beast::websocket::response_type& m) {
101 
102 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
103             if (session != nullptr)
104             {
105                 // use protocol for csrf checking
106                 if (session->cookieAuth &&
107                     !crow::utility::constantTimeStringCompare(
108                         protocolHeader, session->csrfToken))
109                 {
110                     BMCWEB_LOG_ERROR("Websocket CSRF error");
111                     m.result(boost::beast::http::status::unauthorized);
112                     return;
113                 }
114             }
115 #endif
116             if (!protocolHeader.empty())
117             {
118                 m.insert(bf::sec_websocket_protocol, protocolHeader);
119             }
120 
121             m.insert(bf::strict_transport_security, "max-age=31536000; "
122                                                     "includeSubdomains; "
123                                                     "preload");
124             m.insert(bf::pragma, "no-cache");
125             m.insert(bf::cache_control, "no-Store,no-Cache");
126             m.insert("Content-Security-Policy", "default-src 'self'");
127             m.insert("X-XSS-Protection", "1; "
128                                          "mode=block");
129             m.insert("X-Content-Type-Options", "nosniff");
130         }));
131 
132         // Make a pointer to keep the req alive while we accept it.
133         using Body = boost::beast::http::request<bmcweb::HttpBody>;
134         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
135         Body* ptr = mobile.get();
136         // Perform the websocket upgrade
137         ws.async_accept(*ptr,
138                         std::bind_front(&self_t::acceptDone, this,
139                                         shared_from_this(), std::move(mobile)));
140     }
141 
142     void sendBinary(std::string_view msg) override
143     {
144         ws.binary(true);
145         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
146                                                   boost::asio::buffer(msg)));
147         doWrite();
148     }
149 
150     void sendEx(MessageType type, std::string_view msg,
151                 std::function<void()>&& onDone) override
152     {
153         if (doingWrite)
154         {
155             BMCWEB_LOG_CRITICAL(
156                 "Cannot mix sendEx usage with sendBinary or sendText");
157             onDone();
158             return;
159         }
160         ws.binary(type == MessageType::Binary);
161 
162         ws.async_write(boost::asio::buffer(msg),
163                        [weak(weak_from_this()), onDone{std::move(onDone)}](
164                            const boost::beast::error_code& ec, size_t) {
165             std::shared_ptr<Connection> self = weak.lock();
166             if (!self)
167             {
168                 BMCWEB_LOG_ERROR("Connection went away");
169                 return;
170             }
171 
172             // Call the done handler regardless of whether we
173             // errored, but before we close things out
174             onDone();
175 
176             if (ec)
177             {
178                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
179                 self->close("write error");
180             }
181         });
182     }
183 
184     void sendBinary(std::string&& msg) override
185     {
186         ws.binary(true);
187         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
188                                                   boost::asio::buffer(msg)));
189         doWrite();
190     }
191 
192     void sendText(std::string_view msg) override
193     {
194         ws.text(true);
195         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
196                                                   boost::asio::buffer(msg)));
197         doWrite();
198     }
199 
200     void sendText(std::string&& msg) override
201     {
202         ws.text(true);
203         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
204                                                   boost::asio::buffer(msg)));
205         doWrite();
206     }
207 
208     void close(std::string_view msg) override
209     {
210         ws.async_close(
211             {boost::beast::websocket::close_code::normal, msg},
212             [self(shared_from_this())](const boost::system::error_code& ec) {
213             if (ec == boost::asio::error::operation_aborted)
214             {
215                 return;
216             }
217             if (ec)
218             {
219                 BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
220                 return;
221             }
222         });
223     }
224 
225     boost::urls::url_view url() override
226     {
227         return uri;
228     }
229 
230     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
231                     const std::unique_ptr<
232                         boost::beast::http::request<bmcweb::HttpBody>>& /*req*/,
233                     const boost::system::error_code& ec)
234     {
235         if (ec)
236         {
237             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
238             return;
239         }
240         BMCWEB_LOG_DEBUG("Websocket accepted connection");
241 
242         if (openHandler)
243         {
244             openHandler(*this);
245         }
246         doRead();
247     }
248 
249     void deferRead() override
250     {
251         readingDefered = true;
252 
253         // If we're not actively reading, we need to take ownership of
254         // ourselves for a small portion of time, do that, and clear when we
255         // resume.
256         selfOwned = shared_from_this();
257     }
258 
259     void resumeRead() override
260     {
261         readingDefered = false;
262         doRead();
263 
264         // No longer need to keep ourselves alive now that read is active.
265         selfOwned.reset();
266     }
267 
268     void doRead()
269     {
270         if (readingDefered)
271         {
272             return;
273         }
274         ws.async_read(inBuffer, [this, self(shared_from_this())](
275                                     const boost::beast::error_code& ec,
276                                     size_t bytesRead) {
277             if (ec)
278             {
279                 if (ec != boost::beast::websocket::error::closed)
280                 {
281                     BMCWEB_LOG_ERROR("doRead error {}", ec);
282                 }
283                 if (closeHandler)
284                 {
285                     std::string reason{ws.reason().reason.c_str()};
286                     closeHandler(*this, reason);
287                 }
288                 return;
289             }
290 
291             handleMessage(bytesRead);
292         });
293     }
294     void doWrite()
295     {
296         // If we're already doing a write, ignore the request, it will be picked
297         // up when the current write is complete
298         if (doingWrite)
299         {
300             return;
301         }
302 
303         if (outBuffer.size() == 0)
304         {
305             // Done for now
306             return;
307         }
308         doingWrite = true;
309         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
310                                              const boost::beast::error_code& ec,
311                                              size_t bytesSent) {
312             doingWrite = false;
313             outBuffer.consume(bytesSent);
314             if (ec == boost::beast::websocket::error::closed)
315             {
316                 // Do nothing here.  doRead handler will call the
317                 // closeHandler.
318                 close("Write error");
319                 return;
320             }
321             if (ec)
322             {
323                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
324                 return;
325             }
326             doWrite();
327         });
328     }
329 
330   private:
331     void handleMessage(size_t bytesRead)
332     {
333         if (messageExHandler)
334         {
335             // Note, because of the interactions with the read buffers,
336             // this message handler overrides the normal message handler
337             messageExHandler(*this, inString, MessageType::Binary,
338                              [this, self(shared_from_this()), bytesRead]() {
339                 if (self == nullptr)
340                 {
341                     return;
342                 }
343 
344                 inBuffer.consume(bytesRead);
345                 inString.clear();
346 
347                 doRead();
348             });
349             return;
350         }
351 
352         if (messageHandler)
353         {
354             messageHandler(*this, inString, ws.got_text());
355         }
356         inBuffer.consume(bytesRead);
357         inString.clear();
358         doRead();
359     }
360 
361     boost::urls::url uri;
362 
363     boost::beast::websocket::stream<Adaptor, false> ws;
364 
365     bool readingDefered = false;
366     std::string inString;
367     boost::asio::dynamic_string_buffer<std::string::value_type,
368                                        std::string::traits_type,
369                                        std::string::allocator_type>
370         inBuffer;
371 
372     boost::beast::multi_buffer outBuffer;
373     bool doingWrite = false;
374 
375     std::function<void(Connection&)> openHandler;
376     std::function<void(Connection&, const std::string&, bool)> messageHandler;
377     std::function<void(crow::websocket::Connection&, std::string_view,
378                        crow::websocket::MessageType type,
379                        std::function<void()>&& whenComplete)>
380         messageExHandler;
381     std::function<void(Connection&, const std::string&)> closeHandler;
382     std::function<void(Connection&)> errorHandler;
383     std::shared_ptr<persistent_data::UserSession> session;
384 
385     std::shared_ptr<Connection> selfOwned;
386 };
387 } // namespace websocket
388 } // namespace crow
389