xref: /openbmc/bmcweb/http/websocket.hpp (revision a529a6aa)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_body.hpp"
4 #include "http_request.hpp"
5 
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/core/multi_buffer.hpp>
8 #include <boost/beast/websocket.hpp>
9 #include <boost/beast/websocket/ssl.hpp>
10 
11 #include <array>
12 #include <functional>
13 
14 namespace crow
15 {
16 namespace websocket
17 {
18 
/// Payload kind of a websocket message, selecting a binary or a text frame.
enum class MessageType
{
    Binary = 0, ///< Message is sent/received as a binary frame.
    Text = 1,   ///< Message is sent/received as a text frame.
};
24 
25 struct Connection : std::enable_shared_from_this<Connection>
26 {
27   public:
28     Connection() = default;
29 
30     Connection(const Connection&) = delete;
31     Connection(Connection&&) = delete;
32     Connection& operator=(const Connection&) = delete;
33     Connection& operator=(const Connection&&) = delete;
34 
35     virtual void sendBinary(std::string_view msg) = 0;
36     virtual void sendEx(MessageType type, std::string_view msg,
37                         std::function<void()>&& onDone) = 0;
38     virtual void sendText(std::string_view msg) = 0;
39     virtual void close(std::string_view msg = "quit") = 0;
40     virtual void deferRead() = 0;
41     virtual void resumeRead() = 0;
42     virtual boost::asio::io_context& getIoContext() = 0;
43     virtual ~Connection() = default;
44     virtual boost::urls::url_view url() = 0;
45 };
46 
47 template <typename Adaptor>
48 class ConnectionImpl : public Connection
49 {
50     using self_t = ConnectionImpl<Adaptor>;
51 
52   public:
53     ConnectionImpl(
54         const boost::urls::url_view& urlViewIn,
55         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
56         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
57         std::function<void(Connection&, const std::string&, bool)>
58             messageHandlerIn,
59         std::function<void(crow::websocket::Connection&, std::string_view,
60                            crow::websocket::MessageType type,
61                            std::function<void()>&& whenComplete)>
62             messageExHandlerIn,
63         std::function<void(Connection&, const std::string&)> closeHandlerIn,
64         std::function<void(Connection&)> errorHandlerIn) :
65         uri(urlViewIn),
66         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
67         openHandler(std::move(openHandlerIn)),
68         messageHandler(std::move(messageHandlerIn)),
69         messageExHandler(std::move(messageExHandlerIn)),
70         closeHandler(std::move(closeHandlerIn)),
71         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
72     {
73         /* Turn on the timeouts on websocket stream to server role */
74         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
75             boost::beast::role_type::server));
76         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
77     }
78 
79     boost::asio::io_context& getIoContext() override
80     {
81         return static_cast<boost::asio::io_context&>(
82             ws.get_executor().context());
83     }
84 
85     void start(const crow::Request& req)
86     {
87         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
88 
89         using bf = boost::beast::http::field;
90         std::string protocolHeader{
91             req.getHeaderValue(bf::sec_websocket_protocol)};
92 
93         ws.set_option(boost::beast::websocket::stream_base::decorator(
94             [session{session},
95              protocolHeader](boost::beast::websocket::response_type& m) {
96             if constexpr (!BMCWEB_INSECURE_DISABLE_CSRF)
97             {
98                 if (session != nullptr)
99                 {
100                     // use protocol for csrf checking
101                     if (session->cookieAuth &&
102                         !crow::utility::constantTimeStringCompare(
103                             protocolHeader, session->csrfToken))
104                     {
105                         BMCWEB_LOG_ERROR("Websocket CSRF error");
106                         m.result(boost::beast::http::status::unauthorized);
107                         return;
108                     }
109                 }
110             }
111             if (!protocolHeader.empty())
112             {
113                 m.insert(bf::sec_websocket_protocol, protocolHeader);
114             }
115 
116             m.insert(bf::strict_transport_security, "max-age=31536000; "
117                                                     "includeSubdomains; "
118                                                     "preload");
119             m.insert(bf::pragma, "no-cache");
120             m.insert(bf::cache_control, "no-Store,no-Cache");
121             m.insert("Content-Security-Policy", "default-src 'self'");
122             m.insert("X-XSS-Protection", "1; "
123                                          "mode=block");
124             m.insert("X-Content-Type-Options", "nosniff");
125         }));
126 
127         // Make a pointer to keep the req alive while we accept it.
128         using Body = boost::beast::http::request<bmcweb::HttpBody>;
129         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
130         Body* ptr = mobile.get();
131         // Perform the websocket upgrade
132         ws.async_accept(*ptr,
133                         std::bind_front(&self_t::acceptDone, this,
134                                         shared_from_this(), std::move(mobile)));
135     }
136 
137     void sendBinary(std::string_view msg) override
138     {
139         ws.binary(true);
140         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
141                                                   boost::asio::buffer(msg)));
142         doWrite();
143     }
144 
145     void sendEx(MessageType type, std::string_view msg,
146                 std::function<void()>&& onDone) override
147     {
148         if (doingWrite)
149         {
150             BMCWEB_LOG_CRITICAL(
151                 "Cannot mix sendEx usage with sendBinary or sendText");
152             onDone();
153             return;
154         }
155         ws.binary(type == MessageType::Binary);
156 
157         ws.async_write(boost::asio::buffer(msg),
158                        [weak(weak_from_this()), onDone{std::move(onDone)}](
159                            const boost::beast::error_code& ec, size_t) {
160             std::shared_ptr<Connection> self = weak.lock();
161             if (!self)
162             {
163                 BMCWEB_LOG_ERROR("Connection went away");
164                 return;
165             }
166 
167             // Call the done handler regardless of whether we
168             // errored, but before we close things out
169             onDone();
170 
171             if (ec)
172             {
173                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
174                 self->close("write error");
175             }
176         });
177     }
178 
179     void sendText(std::string_view msg) override
180     {
181         ws.text(true);
182         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
183                                                   boost::asio::buffer(msg)));
184         doWrite();
185     }
186 
187     void close(std::string_view msg) override
188     {
189         ws.async_close(
190             {boost::beast::websocket::close_code::normal, msg},
191             [self(shared_from_this())](const boost::system::error_code& ec) {
192             if (ec == boost::asio::error::operation_aborted)
193             {
194                 return;
195             }
196             if (ec)
197             {
198                 BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
199                 return;
200             }
201         });
202     }
203 
204     boost::urls::url_view url() override
205     {
206         return uri;
207     }
208 
209     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
210                     const std::unique_ptr<
211                         boost::beast::http::request<bmcweb::HttpBody>>& /*req*/,
212                     const boost::system::error_code& ec)
213     {
214         if (ec)
215         {
216             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
217             return;
218         }
219         BMCWEB_LOG_DEBUG("Websocket accepted connection");
220 
221         if (openHandler)
222         {
223             openHandler(*this);
224         }
225         doRead();
226     }
227 
228     void deferRead() override
229     {
230         readingDefered = true;
231 
232         // If we're not actively reading, we need to take ownership of
233         // ourselves for a small portion of time, do that, and clear when we
234         // resume.
235         selfOwned = shared_from_this();
236     }
237 
238     void resumeRead() override
239     {
240         readingDefered = false;
241         doRead();
242 
243         // No longer need to keep ourselves alive now that read is active.
244         selfOwned.reset();
245     }
246 
247     void doRead()
248     {
249         if (readingDefered)
250         {
251             return;
252         }
253         ws.async_read(inBuffer, [this, self(shared_from_this())](
254                                     const boost::beast::error_code& ec,
255                                     size_t bytesRead) {
256             if (ec)
257             {
258                 if (ec != boost::beast::websocket::error::closed)
259                 {
260                     BMCWEB_LOG_ERROR("doRead error {}", ec);
261                 }
262                 if (closeHandler)
263                 {
264                     std::string reason{ws.reason().reason.c_str()};
265                     closeHandler(*this, reason);
266                 }
267                 return;
268             }
269 
270             handleMessage(bytesRead);
271         });
272     }
273     void doWrite()
274     {
275         // If we're already doing a write, ignore the request, it will be picked
276         // up when the current write is complete
277         if (doingWrite)
278         {
279             return;
280         }
281 
282         if (outBuffer.size() == 0)
283         {
284             // Done for now
285             return;
286         }
287         doingWrite = true;
288         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
289                                              const boost::beast::error_code& ec,
290                                              size_t bytesSent) {
291             doingWrite = false;
292             outBuffer.consume(bytesSent);
293             if (ec == boost::beast::websocket::error::closed)
294             {
295                 // Do nothing here.  doRead handler will call the
296                 // closeHandler.
297                 close("Write error");
298                 return;
299             }
300             if (ec)
301             {
302                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
303                 return;
304             }
305             doWrite();
306         });
307     }
308 
309   private:
310     void handleMessage(size_t bytesRead)
311     {
312         if (messageExHandler)
313         {
314             // Note, because of the interactions with the read buffers,
315             // this message handler overrides the normal message handler
316             messageExHandler(*this, inString, MessageType::Binary,
317                              [this, self(shared_from_this()), bytesRead]() {
318                 if (self == nullptr)
319                 {
320                     return;
321                 }
322 
323                 inBuffer.consume(bytesRead);
324                 inString.clear();
325 
326                 doRead();
327             });
328             return;
329         }
330 
331         if (messageHandler)
332         {
333             messageHandler(*this, inString, ws.got_text());
334         }
335         inBuffer.consume(bytesRead);
336         inString.clear();
337         doRead();
338     }
339 
340     boost::urls::url uri;
341 
342     boost::beast::websocket::stream<Adaptor, false> ws;
343 
344     bool readingDefered = false;
345     std::string inString;
346     boost::asio::dynamic_string_buffer<std::string::value_type,
347                                        std::string::traits_type,
348                                        std::string::allocator_type>
349         inBuffer;
350 
351     boost::beast::multi_buffer outBuffer;
352     bool doingWrite = false;
353 
354     std::function<void(Connection&)> openHandler;
355     std::function<void(Connection&, const std::string&, bool)> messageHandler;
356     std::function<void(crow::websocket::Connection&, std::string_view,
357                        crow::websocket::MessageType type,
358                        std::function<void()>&& whenComplete)>
359         messageExHandler;
360     std::function<void(Connection&, const std::string&)> closeHandler;
361     std::function<void(Connection&)> errorHandler;
362     std::shared_ptr<persistent_data::UserSession> session;
363 
364     std::shared_ptr<Connection> selfOwned;
365 };
366 } // namespace websocket
367 } // namespace crow
368