xref: /openbmc/bmcweb/http/websocket.hpp (revision 81d523a7)
1 #pragma once
2 #include "http_request.hpp"
3 
4 #include <async_resp.hpp>
5 #include <boost/algorithm/string/predicate.hpp>
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendText(std::string_view msg) = 0;
40     virtual void sendText(std::string&& msg) = 0;
41     virtual void close(std::string_view msg = "quit") = 0;
42     virtual boost::asio::io_context& getIoContext() = 0;
43     virtual ~Connection() = default;
44 
45     void userdata(void* u)
46     {
47         userdataPtr = u;
48     }
49     void* userdata()
50     {
51         return userdataPtr;
52     }
53 
54     const std::string& getUserName() const
55     {
56         return userName;
57     }
58 
59     boost::beast::http::request<boost::beast::http::string_body> req;
60     crow::Response res;
61 
62   private:
63     std::string userName{};
64     void* userdataPtr;
65 };
66 
67 template <typename Adaptor>
68 class ConnectionImpl : public Connection
69 {
70   public:
71     ConnectionImpl(
72         const crow::Request& reqIn, Adaptor adaptorIn,
73         std::function<void(Connection&)> openHandler,
74         std::function<void(Connection&, const std::string&, bool)>
75             messageHandler,
76         std::function<void(Connection&, const std::string&)> closeHandler,
77         std::function<void(Connection&)> errorHandler) :
78         Connection(reqIn, reqIn.session == nullptr ? std::string{}
79                                                    : reqIn.session->username),
80         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
81         openHandler(std::move(openHandler)),
82         messageHandler(std::move(messageHandler)),
83         closeHandler(std::move(closeHandler)),
84         errorHandler(std::move(errorHandler)), session(reqIn.session)
85     {
86         /* Turn on the timeouts on websocket stream to server role */
87         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
88             boost::beast::role_type::server));
89         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
90     }
91 
92     boost::asio::io_context& getIoContext() override
93     {
94         return static_cast<boost::asio::io_context&>(
95             ws.get_executor().context());
96     }
97 
98     void start()
99     {
100         BMCWEB_LOG_DEBUG << "starting connection " << this;
101 
102         using bf = boost::beast::http::field;
103 
104         std::string_view protocol = req[bf::sec_websocket_protocol];
105 
106         ws.set_option(boost::beast::websocket::stream_base::decorator(
107             [session{session}, protocol{std::string(protocol)}](
108                 boost::beast::websocket::response_type& m) {
109 
110 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
111             if (session != nullptr)
112             {
113                 // use protocol for csrf checking
114                 if (session->cookieAuth &&
115                     !crow::utility::constantTimeStringCompare(
116                         protocol, session->csrfToken))
117                 {
118                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
119                     m.result(boost::beast::http::status::unauthorized);
120                     return;
121                 }
122             }
123 #endif
124             if (!protocol.empty())
125             {
126                 m.insert(bf::sec_websocket_protocol, protocol);
127             }
128 
129             m.insert(bf::strict_transport_security, "max-age=31536000; "
130                                                     "includeSubdomains; "
131                                                     "preload");
132             m.insert(bf::pragma, "no-cache");
133             m.insert(bf::cache_control, "no-Store,no-Cache");
134             m.insert("Content-Security-Policy", "default-src 'self'");
135             m.insert("X-XSS-Protection", "1; "
136                                          "mode=block");
137             m.insert("X-Content-Type-Options", "nosniff");
138         }));
139 
140         // Perform the websocket upgrade
141         ws.async_accept(req, [this, self(shared_from_this())](
142                                  boost::system::error_code ec) {
143             if (ec)
144             {
145                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
146                 return;
147             }
148             acceptDone();
149         });
150     }
151 
152     void sendBinary(const std::string_view msg) override
153     {
154         ws.binary(true);
155         outBuffer.emplace_back(msg);
156         doWrite();
157     }
158 
159     void sendBinary(std::string&& msg) override
160     {
161         ws.binary(true);
162         outBuffer.emplace_back(std::move(msg));
163         doWrite();
164     }
165 
166     void sendText(const std::string_view msg) override
167     {
168         ws.text(true);
169         outBuffer.emplace_back(msg);
170         doWrite();
171     }
172 
173     void sendText(std::string&& msg) override
174     {
175         ws.text(true);
176         outBuffer.emplace_back(std::move(msg));
177         doWrite();
178     }
179 
180     void close(const std::string_view msg) override
181     {
182         ws.async_close(
183             {boost::beast::websocket::close_code::normal, msg},
184             [self(shared_from_this())](boost::system::error_code ec) {
185             if (ec == boost::asio::error::operation_aborted)
186             {
187                 return;
188             }
189             if (ec)
190             {
191                 BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
192                 return;
193             }
194             });
195     }
196 
197     void acceptDone()
198     {
199         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
200 
201         doRead();
202 
203         if (openHandler)
204         {
205             openHandler(*this);
206         }
207     }
208 
209     void doRead()
210     {
211         ws.async_read(inBuffer,
212                       [this, self(shared_from_this())](
213                           boost::beast::error_code ec, std::size_t bytesRead) {
214             if (ec)
215             {
216                 if (ec != boost::beast::websocket::error::closed)
217                 {
218                     BMCWEB_LOG_ERROR << "doRead error " << ec;
219                 }
220                 if (closeHandler)
221                 {
222                     std::string_view reason = ws.reason().reason;
223                     closeHandler(*this, std::string(reason));
224                 }
225                 return;
226             }
227             if (messageHandler)
228             {
229                 messageHandler(*this, inString, ws.got_text());
230             }
231             inBuffer.consume(bytesRead);
232             inString.clear();
233             doRead();
234         });
235     }
236 
237     void doWrite()
238     {
239         // If we're already doing a write, ignore the request, it will be picked
240         // up when the current write is complete
241         if (doingWrite)
242         {
243             return;
244         }
245 
246         if (outBuffer.empty())
247         {
248             // Done for now
249             return;
250         }
251         doingWrite = true;
252         ws.async_write(boost::asio::buffer(outBuffer.front()),
253                        [this, self(shared_from_this())](
254                            boost::beast::error_code ec, std::size_t) {
255             doingWrite = false;
256             outBuffer.erase(outBuffer.begin());
257             if (ec == boost::beast::websocket::error::closed)
258             {
259                 // Do nothing here.  doRead handler will call the
260                 // closeHandler.
261                 close("Write error");
262                 return;
263             }
264             if (ec)
265             {
266                 BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
267                 return;
268             }
269             doWrite();
270         });
271     }
272 
273   private:
274     boost::beast::websocket::stream<Adaptor, false> ws;
275 
276     std::string inString;
277     boost::asio::dynamic_string_buffer<std::string::value_type,
278                                        std::string::traits_type,
279                                        std::string::allocator_type>
280         inBuffer;
281     std::vector<std::string> outBuffer;
282     bool doingWrite = false;
283 
284     std::function<void(Connection&)> openHandler;
285     std::function<void(Connection&, const std::string&, bool)> messageHandler;
286     std::function<void(Connection&, const std::string&)> closeHandler;
287     std::function<void(Connection&)> errorHandler;
288     std::shared_ptr<persistent_data::UserSession> session;
289 };
290 } // namespace websocket
291 } // namespace crow
292