xref: /openbmc/bmcweb/http/websocket.hpp (revision 262d7d4b)
1 #pragma once
#include "http_request.hpp"

#include <async_resp.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/beast/websocket.hpp>

#include <array>
#include <deque>
#include <functional>

#ifdef BMCWEB_ENABLE_SSL
#include <boost/beast/websocket/ssl.hpp>
#endif
15 
16 namespace crow
17 {
18 namespace websocket
19 {
20 
21 struct Connection : std::enable_shared_from_this<Connection>
22 {
23   public:
24     explicit Connection(const crow::Request& reqIn) :
25         req(reqIn.req), userdataPtr(nullptr)
26     {}
27 
28     explicit Connection(const crow::Request& reqIn, std::string user) :
29         req(reqIn.req), userName{std::move(user)}, userdataPtr(nullptr)
30     {}
31 
32     virtual void sendBinary(const std::string_view msg) = 0;
33     virtual void sendBinary(std::string&& msg) = 0;
34     virtual void sendText(const std::string_view msg) = 0;
35     virtual void sendText(std::string&& msg) = 0;
36     virtual void close(const std::string_view msg = "quit") = 0;
37     virtual boost::asio::io_context& getIoContext() = 0;
38     virtual ~Connection() = default;
39 
40     void userdata(void* u)
41     {
42         userdataPtr = u;
43     }
44     void* userdata()
45     {
46         return userdataPtr;
47     }
48 
49     const std::string& getUserName() const
50     {
51         return userName;
52     }
53 
54     boost::beast::http::request<boost::beast::http::string_body> req;
55     crow::Response res;
56 
57   private:
58     std::string userName{};
59     void* userdataPtr;
60 };
61 
62 template <typename Adaptor>
63 class ConnectionImpl : public Connection
64 {
65   public:
66     ConnectionImpl(
67         const crow::Request& reqIn, Adaptor adaptorIn,
68         std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
69             openHandler,
70         std::function<void(Connection&, const std::string&, bool)>
71             messageHandler,
72         std::function<void(Connection&, const std::string&)> closeHandler,
73         std::function<void(Connection&)> errorHandler) :
74         Connection(reqIn, reqIn.session->username),
75         ws(std::move(adaptorIn)), inString(), inBuffer(inString, 131088),
76         openHandler(std::move(openHandler)),
77         messageHandler(std::move(messageHandler)),
78         closeHandler(std::move(closeHandler)),
79         errorHandler(std::move(errorHandler)), session(reqIn.session)
80     {
81         BMCWEB_LOG_DEBUG << "Creating new connection " << this;
82     }
83 
84     boost::asio::io_context& getIoContext() override
85     {
86         return static_cast<boost::asio::io_context&>(
87             ws.get_executor().context());
88     }
89 
90     void start()
91     {
92         BMCWEB_LOG_DEBUG << "starting connection " << this;
93 
94         using bf = boost::beast::http::field;
95 
96         std::string_view protocol = req[bf::sec_websocket_protocol];
97 
98         ws.set_option(boost::beast::websocket::stream_base::decorator(
99             [session{session}, protocol{std::string(protocol)}](
100                 boost::beast::websocket::response_type& m) {
101 
102 #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
103                 if (session != nullptr)
104                 {
105                     // use protocol for csrf checking
106                     if (session->cookieAuth &&
107                         !crow::utility::constantTimeStringCompare(
108                             protocol, session->csrfToken))
109                     {
110                         BMCWEB_LOG_ERROR << "Websocket CSRF error";
111                         m.result(boost::beast::http::status::unauthorized);
112                         return;
113                     }
114                 }
115 #endif
116                 if (!protocol.empty())
117                 {
118                     m.insert(bf::sec_websocket_protocol, protocol);
119                 }
120 
121                 m.insert(bf::strict_transport_security, "max-age=31536000; "
122                                                         "includeSubdomains; "
123                                                         "preload");
124                 m.insert(bf::pragma, "no-cache");
125                 m.insert(bf::cache_control, "no-Store,no-Cache");
126                 m.insert("Content-Security-Policy", "default-src 'self'");
127                 m.insert("X-XSS-Protection", "1; "
128                                              "mode=block");
129                 m.insert("X-Content-Type-Options", "nosniff");
130             }));
131 
132         // Perform the websocket upgrade
133         ws.async_accept(req, [this, self(shared_from_this())](
134                                  boost::system::error_code ec) {
135             if (ec)
136             {
137                 BMCWEB_LOG_ERROR << "Error in ws.async_accept " << ec;
138                 return;
139             }
140             acceptDone();
141         });
142     }
143 
144     void sendBinary(const std::string_view msg) override
145     {
146         ws.binary(true);
147         outBuffer.emplace_back(msg);
148         doWrite();
149     }
150 
151     void sendBinary(std::string&& msg) override
152     {
153         ws.binary(true);
154         outBuffer.emplace_back(std::move(msg));
155         doWrite();
156     }
157 
158     void sendText(const std::string_view msg) override
159     {
160         ws.text(true);
161         outBuffer.emplace_back(msg);
162         doWrite();
163     }
164 
165     void sendText(std::string&& msg) override
166     {
167         ws.text(true);
168         outBuffer.emplace_back(std::move(msg));
169         doWrite();
170     }
171 
172     void close(const std::string_view msg) override
173     {
174         ws.async_close(
175             {boost::beast::websocket::close_code::normal, msg},
176             [self(shared_from_this())](boost::system::error_code ec) {
177                 if (ec == boost::asio::error::operation_aborted)
178                 {
179                     return;
180                 }
181                 if (ec)
182                 {
183                     BMCWEB_LOG_ERROR << "Error closing websocket " << ec;
184                     return;
185                 }
186             });
187     }
188 
189     void acceptDone()
190     {
191         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
192 
193         auto asyncResp = std::make_shared<bmcweb::AsyncResp>(
194             res, [this, self(shared_from_this())]() { doRead(); });
195 
196         asyncResp->res.result(boost::beast::http::status::ok);
197 
198         if (openHandler)
199         {
200             openHandler(*this, asyncResp);
201         }
202     }
203 
204     void doRead()
205     {
206         ws.async_read(inBuffer,
207                       [this, self(shared_from_this())](
208                           boost::beast::error_code ec, std::size_t bytesRead) {
209                           if (ec)
210                           {
211                               if (ec != boost::beast::websocket::error::closed)
212                               {
213                                   BMCWEB_LOG_ERROR << "doRead error " << ec;
214                               }
215                               if (closeHandler)
216                               {
217                                   std::string_view reason = ws.reason().reason;
218                                   closeHandler(*this, std::string(reason));
219                               }
220                               return;
221                           }
222                           if (messageHandler)
223                           {
224                               messageHandler(*this, inString, ws.got_text());
225                           }
226                           inBuffer.consume(bytesRead);
227                           inString.clear();
228                           doRead();
229                       });
230     }
231 
232     void doWrite()
233     {
234         // If we're already doing a write, ignore the request, it will be picked
235         // up when the current write is complete
236         if (doingWrite)
237         {
238             return;
239         }
240 
241         if (outBuffer.empty())
242         {
243             // Done for now
244             return;
245         }
246         doingWrite = true;
247         ws.async_write(boost::asio::buffer(outBuffer.front()),
248                        [this, self(shared_from_this())](
249                            boost::beast::error_code ec, std::size_t) {
250                            doingWrite = false;
251                            outBuffer.erase(outBuffer.begin());
252                            if (ec == boost::beast::websocket::error::closed)
253                            {
254                                // Do nothing here.  doRead handler will call the
255                                // closeHandler.
256                                close("Write error");
257                                return;
258                            }
259                            if (ec)
260                            {
261                                BMCWEB_LOG_ERROR << "Error in ws.async_write "
262                                                 << ec;
263                                return;
264                            }
265                            doWrite();
266                        });
267     }
268 
269   private:
270     boost::beast::websocket::stream<Adaptor> ws;
271 
272     std::string inString;
273     boost::asio::dynamic_string_buffer<std::string::value_type,
274                                        std::string::traits_type,
275                                        std::string::allocator_type>
276         inBuffer;
277     std::vector<std::string> outBuffer;
278     bool doingWrite = false;
279 
280     std::function<void(Connection&, std::shared_ptr<bmcweb::AsyncResp>)>
281         openHandler;
282     std::function<void(Connection&, const std::string&, bool)> messageHandler;
283     std::function<void(Connection&, const std::string&)> closeHandler;
284     std::function<void(Connection&)> errorHandler;
285     std::shared_ptr<persistent_data::UserSession> session;
286 };
287 } // namespace websocket
288 } // namespace crow
289