xref: /openbmc/bmcweb/http/websocket.hpp (revision e551b5fa48a78c06441456106e35d4b2ac8642b6)
1 #pragma once
#include "async_resp.hpp"
#include "http_request.hpp"

#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <deque>
#include <functional>

#ifdef BMCWEB_ENABLE_SSL
#include <boost/beast/websocket/ssl.hpp>
#endif
14 
15 namespace crow
16 {
17 namespace websocket
18 {
19 
20 struct Connection : std::enable_shared_from_this<Connection>
21 {
22   public:
23     explicit Connection(const crow::Request& reqIn) : req(reqIn.req)
24     {}
25 
26     Connection(const Connection&) = delete;
27     Connection(Connection&&) = delete;
28     Connection& operator=(const Connection&) = delete;
29     Connection& operator=(const Connection&&) = delete;
30 
31     virtual void sendBinary(std::string_view msg) = 0;
32     virtual void sendBinary(std::string&& msg) = 0;
33     virtual void sendText(std::string_view msg) = 0;
34     virtual void sendText(std::string&& msg) = 0;
35     virtual void close(std::string_view msg = "quit") = 0;
36     virtual boost::asio::io_context& getIoContext() = 0;
37     virtual ~Connection() = default;
38 
39     boost::beast::http::request<boost::beast::http::string_body> req;
40 };
41 
42 template <typename Adaptor>
43 class ConnectionImpl : public Connection
44 {
45   public:
46     ConnectionImpl(
47         const crow::Request& reqIn, Adaptor adaptorIn,
48         std::function<void(Connection&)> openHandlerIn,
49         std::function<void(Connection&, const std::string&, bool)>
50             messageHandlerIn,
51         std::function<void(Connection&, const std::string&)> closeHandlerIn,
52         std::function<void(Connection&)> errorHandlerIn) :
53         Connection(reqIn),
54         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
55         openHandler(std::move(openHandlerIn)),
56         messageHandler(std::move(messageHandlerIn)),
57         closeHandler(std::move(closeHandlerIn)),
58         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
59     {
60         /* Turn on the timeouts on websocket stream to server role */
61         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
62             boost::beast::role_type::server));
63         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
64     }
65 
66     boost::asio::io_context& getIoContext() override
67     {
68         return static_cast<boost::asio::io_context&>(
69             ws.get_executor().context());
70     }
71 
72     void start()
73     {
74         BMCWEB_LOG_DEBUG << "starting connection " << this;
75 
76         using bf = boost::beast::http::field;
77 
78         std::string_view protocol = req[bf::sec_websocket_protocol];
79 
80         ws.set_option(boost::beast::websocket::stream_base::decorator(
81             [session{session}, protocol{std::string(protocol)}](
82                 boost::beast::websocket::response_type& m) {
83 
84 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
85             if (session != nullptr)
86             {
87                 // use protocol for csrf checking
88                 if (session->cookieAuth &&
89                     !crow::utility::constantTimeStringCompare(
90                         protocol, session->csrfToken))
91                 {
92                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
93                     m.result(boost::beast::http::status::unauthorized);
94                     return;
95                 }
96             }
97 #endif
98             if (!protocol.empty())
99             {
100                 m.insert(bf::sec_websocket_protocol, protocol);
101             }
102 
103             m.insert(bf::strict_transport_security, "max-age=31536000; "
104                                                     "includeSubdomains; "
105                                                     "preload");
106             m.insert(bf::pragma, "no-cache");
107             m.insert(bf::cache_control, "no-Store,no-Cache");
108             m.insert("Content-Security-Policy", "default-src 'self'");
109             m.insert("X-XSS-Protection", "1; "
110                                          "mode=block");
111             m.insert("X-Content-Type-Options", "nosniff");
112         }));
113 
114         // Perform the websocket upgrade
115         ws.async_accept(req, [this, self(shared_from_this())](
116                                  const boost::system::error_code& ec) {
117             if (ec)
118             {
119                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
120                 return;
121             }
122             acceptDone();
123         });
124     }
125 
126     void sendBinary(std::string_view msg) override
127     {
128         ws.binary(true);
129         outBuffer.emplace_back(msg);
130         doWrite();
131     }
132 
133     void sendBinary(std::string&& msg) override
134     {
135         ws.binary(true);
136         outBuffer.emplace_back(std::move(msg));
137         doWrite();
138     }
139 
140     void sendText(std::string_view msg) override
141     {
142         ws.text(true);
143         outBuffer.emplace_back(msg);
144         doWrite();
145     }
146 
147     void sendText(std::string&& msg) override
148     {
149         ws.text(true);
150         outBuffer.emplace_back(std::move(msg));
151         doWrite();
152     }
153 
154     void close(std::string_view msg) override
155     {
156         ws.async_close(
157             {boost::beast::websocket::close_code::normal, msg},
158             [self(shared_from_this())](const boost::system::error_code& ec) {
159             if (ec == boost::asio::error::operation_aborted)
160             {
161                 return;
162             }
163             if (ec)
164             {
165                 BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
166                 return;
167             }
168             });
169     }
170 
171     void acceptDone()
172     {
173         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
174 
175         doRead();
176 
177         if (openHandler)
178         {
179             openHandler(*this);
180         }
181     }
182 
183     void doRead()
184     {
185         ws.async_read(inBuffer,
186                       [this, self(shared_from_this())](
187                           boost::beast::error_code ec, std::size_t bytesRead) {
188             if (ec)
189             {
190                 if (ec != boost::beast::websocket::error::closed)
191                 {
192                     BMCWEB_LOG_ERROR << "doRead error " << ec;
193                 }
194                 if (closeHandler)
195                 {
196                     std::string reason{ws.reason().reason.c_str()};
197                     closeHandler(*this, reason);
198                 }
199                 return;
200             }
201             if (messageHandler)
202             {
203                 messageHandler(*this, inString, ws.got_text());
204             }
205             inBuffer.consume(bytesRead);
206             inString.clear();
207             doRead();
208         });
209     }
210 
211     void doWrite()
212     {
213         // If we're already doing a write, ignore the request, it will be picked
214         // up when the current write is complete
215         if (doingWrite)
216         {
217             return;
218         }
219 
220         if (outBuffer.empty())
221         {
222             // Done for now
223             return;
224         }
225         doingWrite = true;
226         ws.async_write(boost::asio::buffer(outBuffer.front()),
227                        [this, self(shared_from_this())](
228                            boost::beast::error_code ec, std::size_t) {
229             doingWrite = false;
230             outBuffer.erase(outBuffer.begin());
231             if (ec == boost::beast::websocket::error::closed)
232             {
233                 // Do nothing here.  doRead handler will call the
234                 // closeHandler.
235                 close("Write error");
236                 return;
237             }
238             if (ec)
239             {
240                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
241                 return;
242             }
243             doWrite();
244         });
245     }
246 
247   private:
248     boost::beast::websocket::stream<Adaptor, false> ws;
249 
250     std::string inString;
251     boost::asio::dynamic_string_buffer<std::string::value_type,
252                                        std::string::traits_type,
253                                        std::string::allocator_type>
254         inBuffer;
255     std::vector<std::string> outBuffer;
256     bool doingWrite = false;
257 
258     std::function<void(Connection&)> openHandler;
259     std::function<void(Connection&, const std::string&, bool)> messageHandler;
260     std::function<void(Connection&, const std::string&)> closeHandler;
261     std::function<void(Connection&)> errorHandler;
262     std::shared_ptr<persistent_data::UserSession> session;
263 };
264 } // namespace websocket
265 } // namespace crow
266