xref: /openbmc/bmcweb/http/websocket.hpp (revision 6c58a03e)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_body.hpp"
4 #include "http_request.hpp"
5 
6 #include <boost/asio/buffer.hpp>
7 #include <boost/asio/ssl/error.hpp>
8 #include <boost/beast/core/multi_buffer.hpp>
9 #include <boost/beast/websocket.hpp>
10 #include <boost/beast/websocket/ssl.hpp>
11 
12 #include <array>
13 #include <functional>
14 
15 namespace crow
16 {
17 namespace websocket
18 {
19 
/// Kind of payload carried by a websocket message.
enum class MessageType
{
    Binary = 0, ///< Binary message payload.
    Text = 1,   ///< Text message payload.
};
25 
26 struct Connection : std::enable_shared_from_this<Connection>
27 {
28   public:
29     Connection() = default;
30 
31     Connection(const Connection&) = delete;
32     Connection(Connection&&) = delete;
33     Connection& operator=(const Connection&) = delete;
34     Connection& operator=(const Connection&&) = delete;
35 
36     virtual void sendBinary(std::string_view msg) = 0;
37     virtual void sendEx(MessageType type, std::string_view msg,
38                         std::function<void()>&& onDone) = 0;
39     virtual void sendText(std::string_view msg) = 0;
40     virtual void close(std::string_view msg = "quit") = 0;
41     virtual void deferRead() = 0;
42     virtual void resumeRead() = 0;
43     virtual boost::asio::io_context& getIoContext() = 0;
44     virtual ~Connection() = default;
45     virtual boost::urls::url_view url() = 0;
46 };
47 
48 template <typename Adaptor>
49 class ConnectionImpl : public Connection
50 {
51     using self_t = ConnectionImpl<Adaptor>;
52 
53   public:
54     ConnectionImpl(
55         const boost::urls::url_view& urlViewIn,
56         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
57         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
58         std::function<void(Connection&, const std::string&, bool)>
59             messageHandlerIn,
60         std::function<void(crow::websocket::Connection&, std::string_view,
61                            crow::websocket::MessageType type,
62                            std::function<void()>&& whenComplete)>
63             messageExHandlerIn,
64         std::function<void(Connection&, const std::string&)> closeHandlerIn,
65         std::function<void(Connection&)> errorHandlerIn) :
66         uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088),
67         openHandler(std::move(openHandlerIn)),
68         messageHandler(std::move(messageHandlerIn)),
69         messageExHandler(std::move(messageExHandlerIn)),
70         closeHandler(std::move(closeHandlerIn)),
71         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
72     {
73         /* Turn on the timeouts on websocket stream to server role */
74         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
75             boost::beast::role_type::server));
76         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
77     }
78 
79     boost::asio::io_context& getIoContext() override
80     {
81         return static_cast<boost::asio::io_context&>(
82             ws.get_executor().context());
83     }
84 
85     void start(const crow::Request& req)
86     {
87         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
88 
89         using bf = boost::beast::http::field;
90         std::string protocolHeader{
91             req.getHeaderValue(bf::sec_websocket_protocol)};
92 
93         ws.set_option(boost::beast::websocket::stream_base::decorator(
94             [session{session},
95              protocolHeader](boost::beast::websocket::response_type& m) {
96                 if constexpr (!BMCWEB_INSECURE_DISABLE_CSRF)
97                 {
98                     if (session != nullptr)
99                     {
100                         // use protocol for csrf checking
101                         if (session->cookieAuth &&
102                             !bmcweb::constantTimeStringCompare(
103                                 protocolHeader, session->csrfToken))
104                         {
105                             BMCWEB_LOG_ERROR("Websocket CSRF error");
106                             m.result(boost::beast::http::status::unauthorized);
107                             return;
108                         }
109                     }
110                 }
111                 if (!protocolHeader.empty())
112                 {
113                     m.insert(bf::sec_websocket_protocol, protocolHeader);
114                 }
115 
116                 m.insert(bf::strict_transport_security,
117                          "max-age=31536000; "
118                          "includeSubdomains; "
119                          "preload");
120                 m.insert(bf::pragma, "no-cache");
121                 m.insert(bf::cache_control, "no-Store,no-Cache");
122                 m.insert("Content-Security-Policy", "default-src 'self'");
123                 m.insert("X-XSS-Protection", "1; "
124                                              "mode=block");
125                 m.insert("X-Content-Type-Options", "nosniff");
126             }));
127 
128         // Make a pointer to keep the req alive while we accept it.
129         using Body = boost::beast::http::request<bmcweb::HttpBody>;
130         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
131         Body* ptr = mobile.get();
132         // Perform the websocket upgrade
133         ws.async_accept(*ptr,
134                         std::bind_front(&self_t::acceptDone, this,
135                                         shared_from_this(), std::move(mobile)));
136     }
137 
138     void sendBinary(std::string_view msg) override
139     {
140         ws.binary(true);
141         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
142                                                   boost::asio::buffer(msg)));
143         doWrite();
144     }
145 
146     void sendEx(MessageType type, std::string_view msg,
147                 std::function<void()>&& onDone) override
148     {
149         if (doingWrite)
150         {
151             BMCWEB_LOG_CRITICAL(
152                 "Cannot mix sendEx usage with sendBinary or sendText");
153             onDone();
154             return;
155         }
156         ws.binary(type == MessageType::Binary);
157 
158         ws.async_write(boost::asio::buffer(msg),
159                        [weak(weak_from_this()), onDone{std::move(onDone)}](
160                            const boost::beast::error_code& ec, size_t) {
161                            std::shared_ptr<Connection> self = weak.lock();
162                            if (!self)
163                            {
164                                BMCWEB_LOG_ERROR("Connection went away");
165                                return;
166                            }
167 
168                            // Call the done handler regardless of whether we
169                            // errored, but before we close things out
170                            onDone();
171 
172                            if (ec)
173                            {
174                                BMCWEB_LOG_ERROR("Error in ws.async_write {}",
175                                                 ec);
176                                self->close("write error");
177                            }
178                        });
179     }
180 
181     void sendText(std::string_view msg) override
182     {
183         ws.text(true);
184         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
185                                                   boost::asio::buffer(msg)));
186         doWrite();
187     }
188 
189     void close(std::string_view msg) override
190     {
191         ws.async_close(
192             {boost::beast::websocket::close_code::normal, msg},
193             [self(shared_from_this())](const boost::system::error_code& ec) {
194                 if (ec == boost::asio::error::operation_aborted)
195                 {
196                     return;
197                 }
198                 if (ec)
199                 {
200                     BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
201                     return;
202                 }
203             });
204     }
205 
206     boost::urls::url_view url() override
207     {
208         return uri;
209     }
210 
211     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
212                     const std::unique_ptr<
213                         boost::beast::http::request<bmcweb::HttpBody>>& /*req*/,
214                     const boost::system::error_code& ec)
215     {
216         if (ec)
217         {
218             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
219             return;
220         }
221         BMCWEB_LOG_DEBUG("Websocket accepted connection");
222 
223         if (openHandler)
224         {
225             openHandler(*this);
226         }
227         doRead();
228     }
229 
230     void deferRead() override
231     {
232         readingDefered = true;
233 
234         // If we're not actively reading, we need to take ownership of
235         // ourselves for a small portion of time, do that, and clear when we
236         // resume.
237         selfOwned = shared_from_this();
238     }
239 
240     void resumeRead() override
241     {
242         readingDefered = false;
243         doRead();
244 
245         // No longer need to keep ourselves alive now that read is active.
246         selfOwned.reset();
247     }
248 
249     void doRead()
250     {
251         if (readingDefered)
252         {
253             return;
254         }
255         ws.async_read(inBuffer, [this, self(shared_from_this())](
256                                     const boost::beast::error_code& ec,
257                                     size_t bytesRead) {
258             if (ec)
259             {
260                 if (ec != boost::beast::websocket::error::closed &&
261                     ec != boost::asio::error::eof &&
262                     ec != boost::asio::ssl::error::stream_truncated)
263                 {
264                     BMCWEB_LOG_ERROR("doRead error {}", ec);
265                 }
266                 if (closeHandler)
267                 {
268                     std::string reason{ws.reason().reason.c_str()};
269                     closeHandler(*this, reason);
270                 }
271                 return;
272             }
273 
274             handleMessage(bytesRead);
275         });
276     }
277     void doWrite()
278     {
279         // If we're already doing a write, ignore the request, it will be picked
280         // up when the current write is complete
281         if (doingWrite)
282         {
283             return;
284         }
285 
286         if (outBuffer.size() == 0)
287         {
288             // Done for now
289             return;
290         }
291         doingWrite = true;
292         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
293                                              const boost::beast::error_code& ec,
294                                              size_t bytesSent) {
295             doingWrite = false;
296             outBuffer.consume(bytesSent);
297             if (ec == boost::beast::websocket::error::closed)
298             {
299                 // Do nothing here.  doRead handler will call the
300                 // closeHandler.
301                 close("Write error");
302                 return;
303             }
304             if (ec)
305             {
306                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
307                 return;
308             }
309             doWrite();
310         });
311     }
312 
313   private:
314     void handleMessage(size_t bytesRead)
315     {
316         if (messageExHandler)
317         {
318             // Note, because of the interactions with the read buffers,
319             // this message handler overrides the normal message handler
320             messageExHandler(*this, inString, MessageType::Binary,
321                              [this, self(shared_from_this()), bytesRead]() {
322                                  if (self == nullptr)
323                                  {
324                                      return;
325                                  }
326 
327                                  inBuffer.consume(bytesRead);
328                                  inString.clear();
329 
330                                  doRead();
331                              });
332             return;
333         }
334 
335         if (messageHandler)
336         {
337             messageHandler(*this, inString, ws.got_text());
338         }
339         inBuffer.consume(bytesRead);
340         inString.clear();
341         doRead();
342     }
343 
344     boost::urls::url uri;
345 
346     boost::beast::websocket::stream<Adaptor, false> ws;
347 
348     bool readingDefered = false;
349     std::string inString;
350     boost::asio::dynamic_string_buffer<std::string::value_type,
351                                        std::string::traits_type,
352                                        std::string::allocator_type>
353         inBuffer;
354 
355     boost::beast::multi_buffer outBuffer;
356     bool doingWrite = false;
357 
358     std::function<void(Connection&)> openHandler;
359     std::function<void(Connection&, const std::string&, bool)> messageHandler;
360     std::function<void(crow::websocket::Connection&, std::string_view,
361                        crow::websocket::MessageType type,
362                        std::function<void()>&& whenComplete)>
363         messageExHandler;
364     std::function<void(Connection&, const std::string&)> closeHandler;
365     std::function<void(Connection&)> errorHandler;
366     std::shared_ptr<persistent_data::UserSession> session;
367 
368     std::shared_ptr<Connection> selfOwned;
369 };
370 } // namespace websocket
371 } // namespace crow
372