xref: /openbmc/bmcweb/http/websocket.hpp (revision ed76121b)
1 #pragma once
#include "http_request.hpp"

#include <async_resp.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#ifdef BMCWEB_ENABLE_SSL
#include <boost/beast/websocket/ssl.hpp>
#endif
14 
15 namespace crow
16 {
17 namespace websocket
18 {
19 
20 struct Connection : std::enable_shared_from_this<Connection>
21 {
22   public:
23     explicit Connection(const crow::Request& reqIn) :
24         req(reqIn.req), userdataPtr(nullptr)
25     {}
26 
27     explicit Connection(const crow::Request& reqIn, std::string user) :
28         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
29     {}
30 
31     Connection(const Connection&) = delete;
32     Connection(Connection&&) = delete;
33     Connection& operator=(const Connection&) = delete;
34     Connection& operator=(const Connection&&) = delete;
35 
36     virtual void sendBinary(std::string_view msg) = 0;
37     virtual void sendBinary(std::string&& msg) = 0;
38     virtual void sendText(std::string_view msg) = 0;
39     virtual void sendText(std::string&& msg) = 0;
40     virtual void close(std::string_view msg = "quit") = 0;
41     virtual boost::asio::io_context& getIoContext() = 0;
42     virtual ~Connection() = default;
43 
44     void userdata(void* u)
45     {
46         userdataPtr = u;
47     }
48     void* userdata()
49     {
50         return userdataPtr;
51     }
52 
53     const std::string& getUserName() const
54     {
55         return userName;
56     }
57 
58     boost::beast::http::request<boost::beast::http::string_body> req;
59     crow::Response res;
60 
61   private:
62     std::string userName{};
63     void* userdataPtr;
64 };
65 
66 template <typename Adaptor>
67 class ConnectionImpl : public Connection
68 {
69   public:
70     ConnectionImpl(
71         const crow::Request& reqIn, Adaptor adaptorIn,
72         std::function<void(Connection&)> openHandlerIn,
73         std::function<void(Connection&, const std::string&, bool)>
74             messageHandlerIn,
75         std::function<void(Connection&, const std::string&)> closeHandlerIn,
76         std::function<void(Connection&)> errorHandlerIn) :
77         Connection(reqIn, reqIn.session == nullptr ? std::string{}
78                                                    : reqIn.session->username),
79         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
80         openHandler(std::move(openHandlerIn)),
81         messageHandler(std::move(messageHandlerIn)),
82         closeHandler(std::move(closeHandlerIn)),
83         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
84     {
85         /* Turn on the timeouts on websocket stream to server role */
86         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
87             boost::beast::role_type::server));
88         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
89     }
90 
91     boost::asio::io_context& getIoContext() override
92     {
93         return static_cast<boost::asio::io_context&>(
94             ws.get_executor().context());
95     }
96 
97     void start()
98     {
99         BMCWEB_LOG_DEBUG << "starting connection " << this;
100 
101         using bf = boost::beast::http::field;
102 
103         std::string_view protocol = req[bf::sec_websocket_protocol];
104 
105         ws.set_option(boost::beast::websocket::stream_base::decorator(
106             [session{session}, protocol{std::string(protocol)}](
107                 boost::beast::websocket::response_type& m) {
108 
109 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
110             if (session != nullptr)
111             {
112                 // use protocol for csrf checking
113                 if (session->cookieAuth &&
114                     !crow::utility::constantTimeStringCompare(
115                         protocol, session->csrfToken))
116                 {
117                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
118                     m.result(boost::beast::http::status::unauthorized);
119                     return;
120                 }
121             }
122 #endif
123             if (!protocol.empty())
124             {
125                 m.insert(bf::sec_websocket_protocol, protocol);
126             }
127 
128             m.insert(bf::strict_transport_security, "max-age=31536000; "
129                                                     "includeSubdomains; "
130                                                     "preload");
131             m.insert(bf::pragma, "no-cache");
132             m.insert(bf::cache_control, "no-Store,no-Cache");
133             m.insert("Content-Security-Policy", "default-src 'self'");
134             m.insert("X-XSS-Protection", "1; "
135                                          "mode=block");
136             m.insert("X-Content-Type-Options", "nosniff");
137         }));
138 
139         // Perform the websocket upgrade
140         ws.async_accept(req, [this, self(shared_from_this())](
141                                  boost::system::error_code ec) {
142             if (ec)
143             {
144                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
145                 return;
146             }
147             acceptDone();
148         });
149     }
150 
151     void sendBinary(const std::string_view msg) override
152     {
153         ws.binary(true);
154         outBuffer.emplace_back(msg);
155         doWrite();
156     }
157 
158     void sendBinary(std::string&& msg) override
159     {
160         ws.binary(true);
161         outBuffer.emplace_back(std::move(msg));
162         doWrite();
163     }
164 
165     void sendText(const std::string_view msg) override
166     {
167         ws.text(true);
168         outBuffer.emplace_back(msg);
169         doWrite();
170     }
171 
172     void sendText(std::string&& msg) override
173     {
174         ws.text(true);
175         outBuffer.emplace_back(std::move(msg));
176         doWrite();
177     }
178 
179     void close(const std::string_view msg) override
180     {
181         ws.async_close(
182             {boost::beast::websocket::close_code::normal, msg},
183             [self(shared_from_this())](boost::system::error_code ec) {
184             if (ec == boost::asio::error::operation_aborted)
185             {
186                 return;
187             }
188             if (ec)
189             {
190                 BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
191                 return;
192             }
193             });
194     }
195 
196     void acceptDone()
197     {
198         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
199 
200         doRead();
201 
202         if (openHandler)
203         {
204             openHandler(*this);
205         }
206     }
207 
208     void doRead()
209     {
210         ws.async_read(inBuffer,
211                       [this, self(shared_from_this())](
212                           boost::beast::error_code ec, std::size_t bytesRead) {
213             if (ec)
214             {
215                 if (ec != boost::beast::websocket::error::closed)
216                 {
217                     BMCWEB_LOG_ERROR << "doRead error " << ec;
218                 }
219                 if (closeHandler)
220                 {
221                     std::string_view reason = ws.reason().reason;
222                     closeHandler(*this, std::string(reason));
223                 }
224                 return;
225             }
226             if (messageHandler)
227             {
228                 messageHandler(*this, inString, ws.got_text());
229             }
230             inBuffer.consume(bytesRead);
231             inString.clear();
232             doRead();
233         });
234     }
235 
236     void doWrite()
237     {
238         // If we're already doing a write, ignore the request, it will be picked
239         // up when the current write is complete
240         if (doingWrite)
241         {
242             return;
243         }
244 
245         if (outBuffer.empty())
246         {
247             // Done for now
248             return;
249         }
250         doingWrite = true;
251         ws.async_write(boost::asio::buffer(outBuffer.front()),
252                        [this, self(shared_from_this())](
253                            boost::beast::error_code ec, std::size_t) {
254             doingWrite = false;
255             outBuffer.erase(outBuffer.begin());
256             if (ec == boost::beast::websocket::error::closed)
257             {
258                 // Do nothing here.  doRead handler will call the
259                 // closeHandler.
260                 close("Write error");
261                 return;
262             }
263             if (ec)
264             {
265                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
266                 return;
267             }
268             doWrite();
269         });
270     }
271 
272   private:
273     boost::beast::websocket::stream<Adaptor, false> ws;
274 
275     std::string inString;
276     boost::asio::dynamic_string_buffer<std::string::value_type,
277                                        std::string::traits_type,
278                                        std::string::allocator_type>
279         inBuffer;
280     std::vector<std::string> outBuffer;
281     bool doingWrite = false;
282 
283     std::function<void(Connection&)> openHandler;
284     std::function<void(Connection&, const std::string&, bool)> messageHandler;
285     std::function<void(Connection&, const std::string&)> closeHandler;
286     std::function<void(Connection&)> errorHandler;
287     std::shared_ptr<persistent_data::UserSession> session;
288 };
289 } // namespace websocket
290 } // namespace crow
291