xref: /openbmc/bmcweb/http/websocket.hpp (revision c6178aba)
1 #pragma once
2 #include "async_resp.hpp"
3 #include "http_body.hpp"
4 #include "http_request.hpp"
5 
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/core/multi_buffer.hpp>
8 #include <boost/beast/websocket.hpp>
9 #include <boost/beast/websocket/ssl.hpp>
10 
11 #include <array>
12 #include <functional>
13 
14 namespace crow
15 {
16 namespace websocket
17 {
18 
/// Kind of websocket message being sent or received.
enum class MessageType
{
    Binary = 0, ///< Binary message (opcode "binary").
    Text = 1,   ///< Text message (opcode "text").
};
24 
25 struct Connection : std::enable_shared_from_this<Connection>
26 {
27   public:
28     Connection() = default;
29 
30     Connection(const Connection&) = delete;
31     Connection(Connection&&) = delete;
32     Connection& operator=(const Connection&) = delete;
33     Connection& operator=(const Connection&&) = delete;
34 
35     virtual void sendBinary(std::string_view msg) = 0;
36     virtual void sendEx(MessageType type, std::string_view msg,
37                         std::function<void()>&& onDone) = 0;
38     virtual void sendText(std::string_view msg) = 0;
39     virtual void close(std::string_view msg = "quit") = 0;
40     virtual void deferRead() = 0;
41     virtual void resumeRead() = 0;
42     virtual boost::asio::io_context& getIoContext() = 0;
43     virtual ~Connection() = default;
44     virtual boost::urls::url_view url() = 0;
45 };
46 
47 template <typename Adaptor>
48 class ConnectionImpl : public Connection
49 {
50     using self_t = ConnectionImpl<Adaptor>;
51 
52   public:
53     ConnectionImpl(
54         const boost::urls::url_view& urlViewIn,
55         const std::shared_ptr<persistent_data::UserSession>& sessionIn,
56         Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
57         std::function<void(Connection&, const std::string&, bool)>
58             messageHandlerIn,
59         std::function<void(crow::websocket::Connection&, std::string_view,
60                            crow::websocket::MessageType type,
61                            std::function<void()>&& whenComplete)>
62             messageExHandlerIn,
63         std::function<void(Connection&, const std::string&)> closeHandlerIn,
64         std::function<void(Connection&)> errorHandlerIn) :
65         uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088),
66         openHandler(std::move(openHandlerIn)),
67         messageHandler(std::move(messageHandlerIn)),
68         messageExHandler(std::move(messageExHandlerIn)),
69         closeHandler(std::move(closeHandlerIn)),
70         errorHandler(std::move(errorHandlerIn)), session(sessionIn)
71     {
72         /* Turn on the timeouts on websocket stream to server role */
73         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
74             boost::beast::role_type::server));
75         BMCWEB_LOG_DEBUG("Creating new connection {}", logPtr(this));
76     }
77 
78     boost::asio::io_context& getIoContext() override
79     {
80         return static_cast<boost::asio::io_context&>(
81             ws.get_executor().context());
82     }
83 
84     void start(const crow::Request& req)
85     {
86         BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
87 
88         using bf = boost::beast::http::field;
89         std::string protocolHeader{
90             req.getHeaderValue(bf::sec_websocket_protocol)};
91 
92         ws.set_option(boost::beast::websocket::stream_base::decorator(
93             [session{session},
94              protocolHeader](boost::beast::websocket::response_type& m) {
95                 if constexpr (!BMCWEB_INSECURE_DISABLE_CSRF)
96                 {
97                     if (session != nullptr)
98                     {
99                         // use protocol for csrf checking
100                         if (session->cookieAuth &&
101                             !bmcweb::constantTimeStringCompare(
102                                 protocolHeader, session->csrfToken))
103                         {
104                             BMCWEB_LOG_ERROR("Websocket CSRF error");
105                             m.result(boost::beast::http::status::unauthorized);
106                             return;
107                         }
108                     }
109                 }
110                 if (!protocolHeader.empty())
111                 {
112                     m.insert(bf::sec_websocket_protocol, protocolHeader);
113                 }
114 
115                 m.insert(bf::strict_transport_security,
116                          "max-age=31536000; "
117                          "includeSubdomains; "
118                          "preload");
119                 m.insert(bf::pragma, "no-cache");
120                 m.insert(bf::cache_control, "no-Store,no-Cache");
121                 m.insert("Content-Security-Policy", "default-src 'self'");
122                 m.insert("X-XSS-Protection", "1; "
123                                              "mode=block");
124                 m.insert("X-Content-Type-Options", "nosniff");
125             }));
126 
127         // Make a pointer to keep the req alive while we accept it.
128         using Body = boost::beast::http::request<bmcweb::HttpBody>;
129         std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
130         Body* ptr = mobile.get();
131         // Perform the websocket upgrade
132         ws.async_accept(*ptr,
133                         std::bind_front(&self_t::acceptDone, this,
134                                         shared_from_this(), std::move(mobile)));
135     }
136 
137     void sendBinary(std::string_view msg) override
138     {
139         ws.binary(true);
140         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
141                                                   boost::asio::buffer(msg)));
142         doWrite();
143     }
144 
145     void sendEx(MessageType type, std::string_view msg,
146                 std::function<void()>&& onDone) override
147     {
148         if (doingWrite)
149         {
150             BMCWEB_LOG_CRITICAL(
151                 "Cannot mix sendEx usage with sendBinary or sendText");
152             onDone();
153             return;
154         }
155         ws.binary(type == MessageType::Binary);
156 
157         ws.async_write(boost::asio::buffer(msg),
158                        [weak(weak_from_this()), onDone{std::move(onDone)}](
159                            const boost::beast::error_code& ec, size_t) {
160                            std::shared_ptr<Connection> self = weak.lock();
161                            if (!self)
162                            {
163                                BMCWEB_LOG_ERROR("Connection went away");
164                                return;
165                            }
166 
167                            // Call the done handler regardless of whether we
168                            // errored, but before we close things out
169                            onDone();
170 
171                            if (ec)
172                            {
173                                BMCWEB_LOG_ERROR("Error in ws.async_write {}",
174                                                 ec);
175                                self->close("write error");
176                            }
177                        });
178     }
179 
180     void sendText(std::string_view msg) override
181     {
182         ws.text(true);
183         outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
184                                                   boost::asio::buffer(msg)));
185         doWrite();
186     }
187 
188     void close(std::string_view msg) override
189     {
190         ws.async_close(
191             {boost::beast::websocket::close_code::normal, msg},
192             [self(shared_from_this())](const boost::system::error_code& ec) {
193                 if (ec == boost::asio::error::operation_aborted)
194                 {
195                     return;
196                 }
197                 if (ec)
198                 {
199                     BMCWEB_LOG_ERROR("Error closing websocket {}", ec);
200                     return;
201                 }
202             });
203     }
204 
205     boost::urls::url_view url() override
206     {
207         return uri;
208     }
209 
210     void acceptDone(const std::shared_ptr<Connection>& /*self*/,
211                     const std::unique_ptr<
212                         boost::beast::http::request<bmcweb::HttpBody>>& /*req*/,
213                     const boost::system::error_code& ec)
214     {
215         if (ec)
216         {
217             BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
218             return;
219         }
220         BMCWEB_LOG_DEBUG("Websocket accepted connection");
221 
222         if (openHandler)
223         {
224             openHandler(*this);
225         }
226         doRead();
227     }
228 
229     void deferRead() override
230     {
231         readingDefered = true;
232 
233         // If we're not actively reading, we need to take ownership of
234         // ourselves for a small portion of time, do that, and clear when we
235         // resume.
236         selfOwned = shared_from_this();
237     }
238 
239     void resumeRead() override
240     {
241         readingDefered = false;
242         doRead();
243 
244         // No longer need to keep ourselves alive now that read is active.
245         selfOwned.reset();
246     }
247 
248     void doRead()
249     {
250         if (readingDefered)
251         {
252             return;
253         }
254         ws.async_read(inBuffer, [this, self(shared_from_this())](
255                                     const boost::beast::error_code& ec,
256                                     size_t bytesRead) {
257             if (ec)
258             {
259                 if (ec != boost::beast::websocket::error::closed)
260                 {
261                     BMCWEB_LOG_ERROR("doRead error {}", ec);
262                 }
263                 if (closeHandler)
264                 {
265                     std::string reason{ws.reason().reason.c_str()};
266                     closeHandler(*this, reason);
267                 }
268                 return;
269             }
270 
271             handleMessage(bytesRead);
272         });
273     }
274     void doWrite()
275     {
276         // If we're already doing a write, ignore the request, it will be picked
277         // up when the current write is complete
278         if (doingWrite)
279         {
280             return;
281         }
282 
283         if (outBuffer.size() == 0)
284         {
285             // Done for now
286             return;
287         }
288         doingWrite = true;
289         ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
290                                              const boost::beast::error_code& ec,
291                                              size_t bytesSent) {
292             doingWrite = false;
293             outBuffer.consume(bytesSent);
294             if (ec == boost::beast::websocket::error::closed)
295             {
296                 // Do nothing here.  doRead handler will call the
297                 // closeHandler.
298                 close("Write error");
299                 return;
300             }
301             if (ec)
302             {
303                 BMCWEB_LOG_ERROR("Error in ws.async_write {}", ec);
304                 return;
305             }
306             doWrite();
307         });
308     }
309 
310   private:
311     void handleMessage(size_t bytesRead)
312     {
313         if (messageExHandler)
314         {
315             // Note, because of the interactions with the read buffers,
316             // this message handler overrides the normal message handler
317             messageExHandler(*this, inString, MessageType::Binary,
318                              [this, self(shared_from_this()), bytesRead]() {
319                                  if (self == nullptr)
320                                  {
321                                      return;
322                                  }
323 
324                                  inBuffer.consume(bytesRead);
325                                  inString.clear();
326 
327                                  doRead();
328                              });
329             return;
330         }
331 
332         if (messageHandler)
333         {
334             messageHandler(*this, inString, ws.got_text());
335         }
336         inBuffer.consume(bytesRead);
337         inString.clear();
338         doRead();
339     }
340 
341     boost::urls::url uri;
342 
343     boost::beast::websocket::stream<Adaptor, false> ws;
344 
345     bool readingDefered = false;
346     std::string inString;
347     boost::asio::dynamic_string_buffer<std::string::value_type,
348                                        std::string::traits_type,
349                                        std::string::allocator_type>
350         inBuffer;
351 
352     boost::beast::multi_buffer outBuffer;
353     bool doingWrite = false;
354 
355     std::function<void(Connection&)> openHandler;
356     std::function<void(Connection&, const std::string&, bool)> messageHandler;
357     std::function<void(crow::websocket::Connection&, std::string_view,
358                        crow::websocket::MessageType type,
359                        std::function<void()>&& whenComplete)>
360         messageExHandler;
361     std::function<void(Connection&, const std::string&)> closeHandler;
362     std::function<void(Connection&)> errorHandler;
363     std::shared_ptr<persistent_data::UserSession> session;
364 
365     std::shared_ptr<Connection> selfOwned;
366 };
367 } // namespace websocket
368 } // namespace crow
369