xref: /openbmc/bmcweb/http/websocket.hpp (revision 141d9431)
1 #pragma once
#include "http_request.hpp"

#include <async_resp.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     Connection(const Connection&) = delete;
33     Connection(Connection&&) = delete;
34     Connection& operator=(const Connection&) = delete;
35     Connection& operator=(const Connection&&) = delete;
36 
37     virtual void sendBinary(std::string_view msg) = 0;
38     virtual void sendBinary(std::string&& msg) = 0;
39     virtual void sendText(std::string_view msg) = 0;
40     virtual void sendText(std::string&& msg) = 0;
41     virtual void close(std::string_view msg = "quit") = 0;
42     virtual boost::asio::io_context& getIoContext() = 0;
43     virtual ~Connection() = default;
44 
45     void userdata(void* u)
46     {
47         userdataPtr = u;
48     }
49     void* userdata()
50     {
51         return userdataPtr;
52     }
53 
54     const std::string& getUserName() const
55     {
56         return userName;
57     }
58 
59     boost::beast::http::request<boost::beast::http::string_body> req;
60     crow::Response res;
61 
62   private:
63     std::string userName{};
64     void* userdataPtr;
65 };
66 
67 template <typename Adaptor>
68 class ConnectionImpl : public Connection
69 {
70   public:
71     ConnectionImpl(
72         const crow::Request& reqIn, Adaptor adaptorIn,
73         std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
74             openHandler,
75         std::function<void(Connection&, const std::string&, bool)>
76             messageHandler,
77         std::function<void(Connection&, const std::string&)> closeHandler,
78         std::function<void(Connection&)> errorHandler) :
79         Connection(reqIn, reqIn.session->username),
80         ws(std::move(adaptorIn)), inString(), inBuffer(inString, 131088),
81         openHandler(std::move(openHandler)),
82         messageHandler(std::move(messageHandler)),
83         closeHandler(std::move(closeHandler)),
84         errorHandler(std::move(errorHandler)), session(reqIn.session)
85     {
86         /* Turn on the timeouts on websocket stream to server role */
87         ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
88             boost::beast::role_type::server));
89         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
90     }
91 
92     boost::asio::io_context& getIoContext() override
93     {
94         return static_cast<boost::asio::io_context&>(
95             ws.get_executor().context());
96     }
97 
98     void start()
99     {
100         BMCWEB_LOG_DEBUG << "starting connection " << this;
101 
102         using bf = boost::beast::http::field;
103 
104         std::string_view protocol = req[bf::sec_websocket_protocol];
105 
106         ws.set_option(boost::beast::websocket::stream_base::decorator(
107             [session{session}, protocol{std::string(protocol)}](
108                 boost::beast::websocket::response_type& m) {
109 
110 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
111                 if (session != nullptr)
112                 {
113                     // use protocol for csrf checking
114                     if (session->cookieAuth &&
115                         !crow::utility::constantTimeStringCompare(
116                             protocol, session->csrfToken))
117                     {
118                         BMCWEB_LOG_ERROR << "Websocket CSRF error";
119                         m.result(boost::beast::http::status::unauthorized);
120                         return;
121                     }
122                 }
123 #endif
124                 if (!protocol.empty())
125                 {
126                     m.insert(bf::sec_websocket_protocol, protocol);
127                 }
128 
129                 m.insert(bf::strict_transport_security, "max-age=31536000; "
130                                                         "includeSubdomains; "
131                                                         "preload");
132                 m.insert(bf::pragma, "no-cache");
133                 m.insert(bf::cache_control, "no-Store,no-Cache");
134                 m.insert("Content-Security-Policy", "default-src 'self'");
135                 m.insert("X-XSS-Protection", "1; "
136                                              "mode=block");
137                 m.insert("X-Content-Type-Options", "nosniff");
138             }));
139 
140         // Perform the websocket upgrade
141         ws.async_accept(req, [this, self(shared_from_this())](
142                                  boost::system::error_code ec) {
143             if (ec)
144             {
145                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
146                 return;
147             }
148             acceptDone();
149         });
150     }
151 
152     void sendBinary(const std::string_view msg) override
153     {
154         ws.binary(true);
155         outBuffer.emplace_back(msg);
156         doWrite();
157     }
158 
159     void sendBinary(std::string&& msg) override
160     {
161         ws.binary(true);
162         outBuffer.emplace_back(std::move(msg));
163         doWrite();
164     }
165 
166     void sendText(const std::string_view msg) override
167     {
168         ws.text(true);
169         outBuffer.emplace_back(msg);
170         doWrite();
171     }
172 
173     void sendText(std::string&& msg) override
174     {
175         ws.text(true);
176         outBuffer.emplace_back(std::move(msg));
177         doWrite();
178     }
179 
180     void close(const std::string_view msg) override
181     {
182         ws.async_close(
183             {boost::beast::websocket::close_code::normal, msg},
184             [self(shared_from_this())](boost::system::error_code ec) {
185                 if (ec == boost::asio::error::operation_aborted)
186                 {
187                     return;
188                 }
189                 if (ec)
190                 {
191                     BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
192                     return;
193                 }
194             });
195     }
196 
197     void acceptDone()
198     {
199         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
200 
201         auto asyncResp = std::make_shared<bmcweb::AsyncResp>(
202             res, [this, self(shared_from_this())]() { doRead(); });
203 
204         asyncResp->res.result(boost::beast::http::status::ok);
205 
206         if (openHandler)
207         {
208             openHandler(*this, asyncResp);
209         }
210     }
211 
212     void doRead()
213     {
214         ws.async_read(inBuffer,
215                       [this, self(shared_from_this())](
216                           boost::beast::error_code ec, std::size_t bytesRead) {
217                           if (ec)
218                           {
219                               if (ec != boost::beast::websocket::error::closed)
220                               {
221                                   BMCWEB_LOG_ERROR << "doRead error " << ec;
222                               }
223                               if (closeHandler)
224                               {
225                                   std::string_view reason = ws.reason().reason;
226                                   closeHandler(*this, std::string(reason));
227                               }
228                               return;
229                           }
230                           if (messageHandler)
231                           {
232                               messageHandler(*this, inString, ws.got_text());
233                           }
234                           inBuffer.consume(bytesRead);
235                           inString.clear();
236                           doRead();
237                       });
238     }
239 
240     void doWrite()
241     {
242         // If we're already doing a write, ignore the request, it will be picked
243         // up when the current write is complete
244         if (doingWrite)
245         {
246             return;
247         }
248 
249         if (outBuffer.empty())
250         {
251             // Done for now
252             return;
253         }
254         doingWrite = true;
255         ws.async_write(boost::asio::buffer(outBuffer.front()),
256                        [this, self(shared_from_this())](
257                            boost::beast::error_code ec, std::size_t) {
258                            doingWrite = false;
259                            outBuffer.erase(outBuffer.begin());
260                            if (ec == boost::beast::websocket::error::closed)
261                            {
262                                // Do nothing here.  doRead handler will call the
263                                // closeHandler.
264                                close("Write error");
265                                return;
266                            }
267                            if (ec)
268                            {
269                                BMCWEB_LOG_ERROR << "Error in ws.async_write "
270                                                 << ec;
271                                return;
272                            }
273                            doWrite();
274                        });
275     }
276 
277   private:
278     boost::beast::websocket::stream<Adaptor, false> ws;
279 
280     std::string inString;
281     boost::asio::dynamic_string_buffer<std::string::value_type,
282                                        std::string::traits_type,
283                                        std::string::allocator_type>
284         inBuffer;
285     std::vector<std::string> outBuffer;
286     bool doingWrite = false;
287 
288     std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
289         openHandler;
290     std::function<void(Connection&, const std::string&, bool)> messageHandler;
291     std::function<void(Connection&, const std::string&)> closeHandler;
292     std::function<void(Connection&)> errorHandler;
293     std::shared_ptr<persistent_data::UserSession> session;
294 };
295 } // namespace websocket
296 } // namespace crow
297