xref: /openbmc/bmcweb/http/websocket.hpp (revision 04e438cbad66838724d78ce12f28aff1fb892a63)
1 #pragma once
2 #include "http_request.hpp"
3 
4 #include <async_resp.hpp>
5 #include <boost/algorithm/string/predicate.hpp>
6 #include <boost/asio/buffer.hpp>
7 #include <boost/beast/websocket.hpp>
8 
9 #include <array>
10 #include <functional>
11 
12 #ifdef BMCWEB_ENABLE_SSL
13 #include <boost/beast/websocket/ssl.hpp>
14 #endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     virtual void sendBinary(const std::string_view msg) = 0;
33     virtual void sendBinary(std::string&& msg) = 0;
34     virtual void sendText(const std::string_view msg) = 0;
35     virtual void sendText(std::string&& msg) = 0;
36     virtual void close(const std::string_view msg = "quit") = 0;
37     virtual boost::asio::io_context& getIoContext() = 0;
38     virtual ~Connection() = default;
39 
40     void userdata(void* u)
41     {
42         userdataPtr = u;
43     }
44     void* userdata()
45     {
46         return userdataPtr;
47     }
48 
49     const std::string& getUserName() const
50     {
51         return userName;
52     }
53 
54     boost::beast::http::request<boost::beast::http::string_body> req;
55     crow::Response res;
56 
57   private:
58     std::string userName{};
59     void* userdataPtr;
60 };
61 
62 template <typename Adaptor>
63 class ConnectionImpl : public Connection
64 {
65   public:
66     ConnectionImpl(
67         const crow::Request& reqIn, Adaptor adaptorIn,
68         std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
69             open_handler,
70         std::function<void(Connection&, const std::string&, bool)>
71             message_handler,
72         std::function<void(Connection&, const std::string&)> close_handler,
73         std::function<void(Connection&)> error_handler) :
74         Connection(reqIn, reqIn.session->username),
75         ws(std::move(adaptorIn)), inString(), inBuffer(inString, 131088),
76         openHandler(std::move(open_handler)),
77         messageHandler(std::move(message_handler)),
78         closeHandler(std::move(close_handler)),
79         errorHandler(std::move(error_handler)), session(reqIn.session)
80     {
81         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
82     }
83 
84     boost::asio::io_context& getIoContext() override
85     {
86         return static_cast<boost::asio::io_context&>(
87             ws.get_executor().context());
88     }
89 
90     void start()
91     {
92         BMCWEB_LOG_DEBUG << "starting connection " << this;
93 
94         using bf = boost::beast::http::field;
95 
96         std::string_view protocol = req[bf::sec_websocket_protocol];
97 
98         ws.set_option(boost::beast::websocket::stream_base::decorator(
99             [session{session}, protocol{std::string(protocol)}](
100                 boost::beast::websocket::response_type& m) {
101 
102 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
103                 // use protocol for csrf checking
104                 if (session->cookieAuth &&
105                     !crow::utility::constantTimeStringCompare(
106                         protocol, session->csrfToken))
107                 {
108                     BMCWEB_LOG_ERROR << "Websocket CSRF error";
109                     m.result(boost::beast::http::status::unauthorized);
110                     return;
111                 }
112 #endif
113                 if (!protocol.empty())
114                 {
115                     m.insert(bf::sec_websocket_protocol, protocol);
116                 }
117 
118                 m.insert(bf::strict_transport_security, "max-age=31536000; "
119                                                         "includeSubdomains; "
120                                                         "preload");
121                 m.insert(bf::pragma, "no-cache");
122                 m.insert(bf::cache_control, "no-Store,no-Cache");
123                 m.insert("Content-Security-Policy", "default-src 'self'");
124                 m.insert("X-XSS-Protection", "1; "
125                                              "mode=block");
126                 m.insert("X-Content-Type-Options", "nosniff");
127             }));
128 
129         // Perform the websocket upgrade
130         ws.async_accept(req, [this, self(shared_from_this())](
131                                  boost::system::error_code ec) {
132             if (ec)
133             {
134                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
135                 return;
136             }
137             acceptDone();
138         });
139     }
140 
141     void sendBinary(const std::string_view msg) override
142     {
143         ws.binary(true);
144         outBuffer.emplace_back(msg);
145         doWrite();
146     }
147 
148     void sendBinary(std::string&& msg) override
149     {
150         ws.binary(true);
151         outBuffer.emplace_back(std::move(msg));
152         doWrite();
153     }
154 
155     void sendText(const std::string_view msg) override
156     {
157         ws.text(true);
158         outBuffer.emplace_back(msg);
159         doWrite();
160     }
161 
162     void sendText(std::string&& msg) override
163     {
164         ws.text(true);
165         outBuffer.emplace_back(std::move(msg));
166         doWrite();
167     }
168 
169     void close(const std::string_view msg) override
170     {
171         ws.async_close(
172             {boost::beast::websocket::close_code::normal, msg},
173             [self(shared_from_this())](boost::system::error_code ec) {
174                 if (ec == boost::asio::error::operation_aborted)
175                 {
176                     return;
177                 }
178                 if (ec)
179                 {
180                     BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
181                     return;
182                 }
183             });
184     }
185 
186     void acceptDone()
187     {
188         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
189 
190         auto asyncResp = std::make_shared<bmcweb::AsyncResp>(
191             res, [this, self(shared_from_this())]() { doRead(); });
192 
193         asyncResp->res.result(boost::beast::http::status::ok);
194 
195         if (openHandler)
196         {
197             openHandler(*this, asyncResp);
198         }
199     }
200 
201     void doRead()
202     {
203         ws.async_read(inBuffer,
204                       [this, self(shared_from_this())](
205                           boost::beast::error_code ec, std::size_t bytes_read) {
206                           if (ec)
207                           {
208                               if (ec != boost::beast::websocket::error::closed)
209                               {
210                                   BMCWEB_LOG_ERROR << "doRead error " << ec;
211                               }
212                               if (closeHandler)
213                               {
214                                   std::string_view reason = ws.reason().reason;
215                                   closeHandler(*this, std::string(reason));
216                               }
217                               return;
218                           }
219                           if (messageHandler)
220                           {
221                               messageHandler(*this, inString, ws.got_text());
222                           }
223                           inBuffer.consume(bytes_read);
224                           inString.clear();
225                           doRead();
226                       });
227     }
228 
229     void doWrite()
230     {
231         // If we're already doing a write, ignore the request, it will be picked
232         // up when the current write is complete
233         if (doingWrite)
234         {
235             return;
236         }
237 
238         if (outBuffer.empty())
239         {
240             // Done for now
241             return;
242         }
243         doingWrite = true;
244         ws.async_write(boost::asio::buffer(outBuffer.front()),
245                        [this, self(shared_from_this())](
246                            boost::beast::error_code ec, std::size_t) {
247                            doingWrite = false;
248                            outBuffer.erase(outBuffer.begin());
249                            if (ec == boost::beast::websocket::error::closed)
250                            {
251                                // Do nothing here.  doRead handler will call the
252                                // closeHandler.
253                                close("Write error");
254                                return;
255                            }
256                            if (ec)
257                            {
258                                BMCWEB_LOG_ERROR << "Error in ws.async_write "
259                                                 << ec;
260                                return;
261                            }
262                            doWrite();
263                        });
264     }
265 
266   private:
267     boost::beast::websocket::stream<Adaptor> ws;
268 
269     std::string inString;
270     boost::asio::dynamic_string_buffer<std::string::value_type,
271                                        std::string::traits_type,
272                                        std::string::allocator_type>
273         inBuffer;
274     std::vector<std::string> outBuffer;
275     bool doingWrite = false;
276 
277     std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
278         openHandler;
279     std::function<void(Connection&, const std::string&, bool)> messageHandler;
280     std::function<void(Connection&, const std::string&)> closeHandler;
281     std::function<void(Connection&)> errorHandler;
282     std::shared_ptr<persistent_data::UserSession> session;
283 };
284 } // namespace websocket
285 } // namespace crow
286