xref: /openbmc/bmcweb/http/websocket.hpp (revision 3cd7072b)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_body.hpp"
4 #include "http_request.hpp"
5 
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/core/multi_buffer.hpp>
8 #include <boost/beast/websocket.hpp>
9 
10 #include <array>
11 #include <functional>
12 
13 #ifdef BMCWEB_ENABLE_SSL
14 #include <boost/beast/websocket/ssl.hpp>
15 #endif
16 
17 namespace crow
18 {
19 namespace websocket
20 {
21 
/// Payload kind of a single websocket message.
enum class MessageType
{
    Binary = 0, ///< Opaque binary payload.
    Text = 1,   ///< Text payload.
};
27 
28 struct Connection : std::enable_shared_from_this<Connection>
29 {
30   public:
31     Connection() = default;
32 
33     Connection(const Connection&) = delete;
34     Connection(Connection&&) = delete;
35     Connection& operator=(const Connection&) = delete;
36     Connection& operator=(const Connection&&) = delete;
37 
38     virtual void sendBinary(std::string_view msg) = 0;
39     virtual void sendEx(MessageType type, std::string_view msg,
40                         std::function<void()>&& onDone) = 0;
41     virtual void sendText(std::string_view msg) = 0;
42     virtual void close(std::string_view msg = "quit") = 0;
43     virtual void deferRead() = 0;
44     virtual void resumeRead() = 0;
45     virtual boost::asio::io_context& getIoContext() = 0;
46     virtual ~Connection() = default;
47     virtual boost::urls::url_view url() = 0;
48 };
49 
50 template <typename Adaptor>
51 class ConnectionImpl : public Connection
52 {
53     using self_t = ConnectionImpl<Adaptor>;
54 
55   public:
56     ConnectionImpl(
57         const boost::urls::url_view& urlViewIn,
58         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
59         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
60         std::function<void(Connection&, const std::string&, bool)>
61             messageHandlerIn,
62         std::function<void(crow::websocket::Connection&, std::string_view,
63                            crow::websocket::MessageType type,
64                            std::function<void()>&& whenComplete)>
65             messageExHandlerIn,
66         std::function<void(Connection&, const std::string&)> closeHandlerIn,
67         std::function<void(Connection&)> errorHandlerIn) :
68         uri(urlViewIn),
69         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
70         openHandler(std::move(openHandlerIn)),
71         messageHandler(std::move(messageHandlerIn)),
72         messageExHandler(std::move(messageExHandlerIn)),
73         closeHandler(std::move(closeHandlerIn)),
74         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
75     {
76         /* Turn on the timeouts on websocket stream to server role */
77         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
78             boost::beast::role_type::server));
79         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
80     }
81 
82     boost::asio::io_context& getIoContext() override
83     {
84         return static_cast<boost::asio::io_context&>(
85             ws.get_executor().context());
86     }
87 
88     void start(const crow::Request& req)
89     {
90         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
91 
92         using bf = boost::beast::http::field;
93         std::string protocolHeader{
94             req.getHeaderValue(bf::sec_websocket_protocol)};
95 
96         ws.set_option(boost::beast::websocket::stream_base::decorator(
97             [session{session},
98              protocolHeader](boost::beast::websocket::response_type& m) {
99 
100 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
101             if (session != nullptr)
102             {
103                 // use protocol for csrf checking
104                 if (session->cookieAuth &&
105                     !crow::utility::constantTimeStringCompare(
106                         protocolHeader, session->csrfToken))
107                 {
108                     BMCWEB_LOG_ERROR("Websocket CSRF error");
109                     m.result(boost::beast::http::status::unauthorized);
110                     return;
111                 }
112             }
113 #endif
114             if (!protocolHeader.empty())
115             {
116                 m.insert(bf::sec_websocket_protocol, protocolHeader);
117             }
118 
119             m.insert(bf::strict_transport_security, "max-age=31536000; "
120                                                     "includeSubdomains; "
121                                                     "preload");
122             m.insert(bf::pragma, "no-cache");
123             m.insert(bf::cache_control, "no-Store,no-Cache");
124             m.insert("Content-Security-Policy", "default-src 'self'");
125             m.insert("X-XSS-Protection", "1; "
126                                          "mode=block");
127             m.insert("X-Content-Type-Options", "nosniff");
128         }));
129 
130         // Make a pointer to keep the req alive while we accept it.
131         using Body = boost::beast::http::request<bmcweb::HttpBody>;
132         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
133         Body* ptr = mobile.get();
134         // Perform the websocket upgrade
135         ws.async_accept(*ptr,
136                         std::bind_front(&self_t::acceptDone, this,
137                                         shared_from_this(), std::move(mobile)));
138     }
139 
140     void sendBinary(std::string_view msg) override
141     {
142         ws.binary(true);
143         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
144                                                   boost::asio::buffer(msg)));
145         doWrite();
146     }
147 
148     void sendEx(MessageType type, std::string_view msg,
149                 std::function<void()>&& onDone) override
150     {
151         if (doingWrite)
152         {
153             BMCWEB_LOG_CRITICAL(
154                 "Cannot mix sendEx usage with sendBinary or sendText");
155             onDone();
156             return;
157         }
158         ws.binary(type == MessageType::Binary);
159 
160         ws.async_write(boost::asio::buffer(msg),
161                        [weak(weak_from_this()), onDone{std::move(onDone)}](
162                            const boost::beast::error_code& ec, size_t) {
163             std::shared_ptr<Connection> self = weak.lock();
164             if (!self)
165             {
166                 BMCWEB_LOG_ERROR("Connection went away");
167                 return;
168             }
169 
170             // Call the done handler regardless of whether we
171             // errored, but before we close things out
172             onDone();
173 
174             if (ec)
175             {
176                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
177                 self->close("write error");
178             }
179         });
180     }
181 
182     void sendText(std::string_view msg) override
183     {
184         ws.text(true);
185         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
186                                                   boost::asio::buffer(msg)));
187         doWrite();
188     }
189 
190     void close(std::string_view msg) override
191     {
192         ws.async_close(
193             {boost::beast::websocket::close_code::normal, msg},
194             [self(shared_from_this())](const boost::system::error_code& ec) {
195             if (ec == boost::asio::error::operation_aborted)
196             {
197                 return;
198             }
199             if (ec)
200             {
201                 BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
202                 return;
203             }
204         });
205     }
206 
207     boost::urls::url_view url() override
208     {
209         return uri;
210     }
211 
212     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
213                     const std::unique_ptr<
214                         boost::beast::http::request<bmcweb::HttpBody>>& /*req*/,
215                     const boost::system::error_code& ec)
216     {
217         if (ec)
218         {
219             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
220             return;
221         }
222         BMCWEB_LOG_DEBUG("Websocket accepted connection");
223 
224         if (openHandler)
225         {
226             openHandler(*this);
227         }
228         doRead();
229     }
230 
231     void deferRead() override
232     {
233         readingDefered = true;
234 
235         // If we're not actively reading, we need to take ownership of
236         // ourselves for a small portion of time, do that, and clear when we
237         // resume.
238         selfOwned = shared_from_this();
239     }
240 
241     void resumeRead() override
242     {
243         readingDefered = false;
244         doRead();
245 
246         // No longer need to keep ourselves alive now that read is active.
247         selfOwned.reset();
248     }
249 
250     void doRead()
251     {
252         if (readingDefered)
253         {
254             return;
255         }
256         ws.async_read(inBuffer, [this, self(shared_from_this())](
257                                     const boost::beast::error_code& ec,
258                                     size_t bytesRead) {
259             if (ec)
260             {
261                 if (ec != boost::beast::websocket::error::closed)
262                 {
263                     BMCWEB_LOG_ERROR("doRead error {}", ec);
264                 }
265                 if (closeHandler)
266                 {
267                     std::string reason{ws.reason().reason.c_str()};
268                     closeHandler(*this, reason);
269                 }
270                 return;
271             }
272 
273             handleMessage(bytesRead);
274         });
275     }
276     void doWrite()
277     {
278         // If we're already doing a write, ignore the request, it will be picked
279         // up when the current write is complete
280         if (doingWrite)
281         {
282             return;
283         }
284 
285         if (outBuffer.size() == 0)
286         {
287             // Done for now
288             return;
289         }
290         doingWrite = true;
291         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
292                                              const boost::beast::error_code& ec,
293                                              size_t bytesSent) {
294             doingWrite = false;
295             outBuffer.consume(bytesSent);
296             if (ec == boost::beast::websocket::error::closed)
297             {
298                 // Do nothing here.  doRead handler will call the
299                 // closeHandler.
300                 close("Write error");
301                 return;
302             }
303             if (ec)
304             {
305                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
306                 return;
307             }
308             doWrite();
309         });
310     }
311 
312   private:
313     void handleMessage(size_t bytesRead)
314     {
315         if (messageExHandler)
316         {
317             // Note, because of the interactions with the read buffers,
318             // this message handler overrides the normal message handler
319             messageExHandler(*this, inString, MessageType::Binary,
320                              [this, self(shared_from_this()), bytesRead]() {
321                 if (self == nullptr)
322                 {
323                     return;
324                 }
325 
326                 inBuffer.consume(bytesRead);
327                 inString.clear();
328 
329                 doRead();
330             });
331             return;
332         }
333 
334         if (messageHandler)
335         {
336             messageHandler(*this, inString, ws.got_text());
337         }
338         inBuffer.consume(bytesRead);
339         inString.clear();
340         doRead();
341     }
342 
343     boost::urls::url uri;
344 
345     boost::beast::websocket::stream<Adaptor, false> ws;
346 
347     bool readingDefered = false;
348     std::string inString;
349     boost::asio::dynamic_string_buffer<std::string::value_type,
350                                        std::string::traits_type,
351                                        std::string::allocator_type>
352         inBuffer;
353 
354     boost::beast::multi_buffer outBuffer;
355     bool doingWrite = false;
356 
357     std::function<void(Connection&)> openHandler;
358     std::function<void(Connection&, const std::string&, bool)> messageHandler;
359     std::function<void(crow::websocket::Connection&, std::string_view,
360                        crow::websocket::MessageType type,
361                        std::function<void()>&& whenComplete)>
362         messageExHandler;
363     std::function<void(Connection&, const std::string&)> closeHandler;
364     std::function<void(Connection&)> errorHandler;
365     std::shared_ptr<persistent_data::UserSession> session;
366 
367     std::shared_ptr<Connection> selfOwned;
368 };
369 } // namespace websocket
370 } // namespace crow
371