xref: /openbmc/bmcweb/http/websocket.hpp (revision 9d424669)
1 #pragma once
#include "http_request.hpp"

#include <async_resp.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <deque>
#include <functional>

#ifdef BMCWEB_ENABLE_SSL
#include <boost/beast/websocket/ssl.hpp>
#endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     virtual void sendBinary(const std::string_view msg) = 0;
33     virtual void sendBinary(std::string&& msg) = 0;
34     virtual void sendText(const std::string_view msg) = 0;
35     virtual void sendText(std::string&& msg) = 0;
36     virtual void close(const std::string_view msg = "quit") = 0;
37     virtual boost::asio::io_context& getIoContext() = 0;
38     virtual ~Connection() = default;
39 
40     void userdata(void* u)
41     {
42         userdataPtr = u;
43     }
44     void* userdata()
45     {
46         return userdataPtr;
47     }
48 
49     const std::string& getUserName() const
50     {
51         return userName;
52     }
53 
54     boost::beast::http::request<boost::beast::http::string_body> req;
55     crow::Response res;
56 
57   private:
58     std::string userName{};
59     void* userdataPtr;
60 };
61 
62 template <typename Adaptor>
63 class ConnectionImpl : public Connection
64 {
65   public:
66     ConnectionImpl(
67         const crow::Request& reqIn, Adaptor adaptorIn,
68         std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
69             openHandler,
70         std::function<void(Connection&, const std::string&, bool)>
71             messageHandler,
72         std::function<void(Connection&, const std::string&)> closeHandler,
73         std::function<void(Connection&)> errorHandler) :
74         Connection(reqIn, reqIn.session->username),
75         ws(std::move(adaptorIn)), inString(), inBuffer(inString, 131088),
76         openHandler(std::move(openHandler)),
77         messageHandler(std::move(messageHandler)),
78         closeHandler(std::move(closeHandler)),
79         errorHandler(std::move(errorHandler)), session(reqIn.session)
80     {
81         /* Turn on the timeouts on websocket stream to server role */
82         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
83             boost::beast::role_type::server));
84         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
85     }
86 
87     boost::asio::io_context& getIoContext() override
88     {
89         return static_cast<boost::asio::io_context&>(
90             ws.get_executor().context());
91     }
92 
93     void start()
94     {
95         BMCWEB_LOG_DEBUG << "starting connection " << this;
96 
97         using bf = boost::beast::http::field;
98 
99         std::string_view protocol = req[bf::sec_websocket_protocol];
100 
101         ws.set_option(boost::beast::websocket::stream_base::decorator(
102             [session{session}, protocol{std::string(protocol)}](
103                 boost::beast::websocket::response_type& m) {
104 
105 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
106                 if (session != nullptr)
107                 {
108                     // use protocol for csrf checking
109                     if (session->cookieAuth &&
110                         !crow::utility::constantTimeStringCompare(
111                             protocol, session->csrfToken))
112                     {
113                         BMCWEB_LOG_ERROR << "Websocket CSRF error";
114                         m.result(boost::beast::http::status::unauthorized);
115                         return;
116                     }
117                 }
118 #endif
119                 if (!protocol.empty())
120                 {
121                     m.insert(bf::sec_websocket_protocol, protocol);
122                 }
123 
124                 m.insert(bf::strict_transport_security, "max-age=31536000; "
125                                                         "includeSubdomains; "
126                                                         "preload");
127                 m.insert(bf::pragma, "no-cache");
128                 m.insert(bf::cache_control, "no-Store,no-Cache");
129                 m.insert("Content-Security-Policy", "default-src 'self'");
130                 m.insert("X-XSS-Protection", "1; "
131                                              "mode=block");
132                 m.insert("X-Content-Type-Options", "nosniff");
133             }));
134 
135         // Perform the websocket upgrade
136         ws.async_accept(req, [this, self(shared_from_this())](
137                                  boost::system::error_code ec) {
138             if (ec)
139             {
140                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
141                 return;
142             }
143             acceptDone();
144         });
145     }
146 
147     void sendBinary(const std::string_view msg) override
148     {
149         ws.binary(true);
150         outBuffer.emplace_back(msg);
151         doWrite();
152     }
153 
154     void sendBinary(std::string&& msg) override
155     {
156         ws.binary(true);
157         outBuffer.emplace_back(std::move(msg));
158         doWrite();
159     }
160 
161     void sendText(const std::string_view msg) override
162     {
163         ws.text(true);
164         outBuffer.emplace_back(msg);
165         doWrite();
166     }
167 
168     void sendText(std::string&& msg) override
169     {
170         ws.text(true);
171         outBuffer.emplace_back(std::move(msg));
172         doWrite();
173     }
174 
175     void close(const std::string_view msg) override
176     {
177         ws.async_close(
178             {boost::beast::websocket::close_code::normal, msg},
179             [self(shared_from_this())](boost::system::error_code ec) {
180                 if (ec == boost::asio::error::operation_aborted)
181                 {
182                     return;
183                 }
184                 if (ec)
185                 {
186                     BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
187                     return;
188                 }
189             });
190     }
191 
192     void acceptDone()
193     {
194         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
195 
196         auto asyncResp = std::make_shared<bmcweb::AsyncResp>(
197             res, [this, self(shared_from_this())]() { doRead(); });
198 
199         asyncResp->res.result(boost::beast::http::status::ok);
200 
201         if (openHandler)
202         {
203             openHandler(*this, asyncResp);
204         }
205     }
206 
207     void doRead()
208     {
209         ws.async_read(inBuffer,
210                       [this, self(shared_from_this())](
211                           boost::beast::error_code ec, std::size_t bytesRead) {
212                           if (ec)
213                           {
214                               if (ec != boost::beast::websocket::error::closed)
215                               {
216                                   BMCWEB_LOG_ERROR << "doRead error " << ec;
217                               }
218                               if (closeHandler)
219                               {
220                                   std::string_view reason = ws.reason().reason;
221                                   closeHandler(*this, std::string(reason));
222                               }
223                               return;
224                           }
225                           if (messageHandler)
226                           {
227                               messageHandler(*this, inString, ws.got_text());
228                           }
229                           inBuffer.consume(bytesRead);
230                           inString.clear();
231                           doRead();
232                       });
233     }
234 
235     void doWrite()
236     {
237         // If we're already doing a write, ignore the request, it will be picked
238         // up when the current write is complete
239         if (doingWrite)
240         {
241             return;
242         }
243 
244         if (outBuffer.empty())
245         {
246             // Done for now
247             return;
248         }
249         doingWrite = true;
250         ws.async_write(boost::asio::buffer(outBuffer.front()),
251                        [this, self(shared_from_this())](
252                            boost::beast::error_code ec, std::size_t) {
253                            doingWrite = false;
254                            outBuffer.erase(outBuffer.begin());
255                            if (ec == boost::beast::websocket::error::closed)
256                            {
257                                // Do nothing here.  doRead handler will call the
258                                // closeHandler.
259                                close("Write error");
260                                return;
261                            }
262                            if (ec)
263                            {
264                                BMCWEB_LOG_ERROR << "Error in ws.async_write "
265                                                 << ec;
266                                return;
267                            }
268                            doWrite();
269                        });
270     }
271 
272   private:
273     boost::beast::websocket::stream<Adaptor, false> ws;
274 
275     std::string inString;
276     boost::asio::dynamic_string_buffer<std::string::value_type,
277                                        std::string::traits_type,
278                                        std::string::allocator_type>
279         inBuffer;
280     std::vector<std::string> outBuffer;
281     bool doingWrite = false;
282 
283     std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
284         openHandler;
285     std::function<void(Connection&, const std::string&, bool)> messageHandler;
286     std::function<void(Connection&, const std::string&)> closeHandler;
287     std::function<void(Connection&)> errorHandler;
288     std::shared_ptr<persistent_data::UserSession> session;
289 };
290 } // namespace websocket
291 } // namespace crow
292