xref: /openbmc/openbmc/poky/bitbake/lib/bb/asyncrpc/serv.py (revision edff49234e31f23dc79f823473c9e286a21596c1)
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
class ClientLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the client address.

    The address is taken from the ``"address"`` entry of the adapter's
    ``extra`` mapping.
    """

    def process(self, msg, kwargs):
        address = self.extra["address"]
        tagged = "[Client {}] {}".format(address, msg)
        return tagged, kwargs
23
24
class AsyncServerConnection(object):
    """Server side of a single asyncrpc client connection.

    Runs the connection handshake (a "<name> <version>" protocol line, then
    "Tag: value" header lines ended by an empty line) and then dispatches each
    incoming message to the matching entry in ``self.handlers``. Subclasses
    add handlers and are expected to provide ``validate_proto_version()``,
    which is called once ``self.proto_version`` has been parsed.
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        """
        :param socket: connection object exposing ``address`` and async
            ``recv``/``send``/``recv_message``/``send_message``/``close``.
        :param proto_name: protocol name the client must announce.
        :param logger: logger; messages are prefixed with the client address.
        """
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        """Close the underlying client socket."""
        await self.socket.close()

    async def handle_headers(self, headers):
        """Return extra headers to send back when the client asks for them.

        Only called when the client sent ``needs-headers: true``. The base
        implementation sends none.
        """
        return {}

    async def process_requests(self):
        """Serve this client until it disconnects or is rejected.

        Malformed or unsupported protocol lines are rejected quietly (logged
        at debug level). A malformed header line raises ``ClientError``, which
        is logged. The socket is always closed on exit.
        """
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            # Untrusted input: anything other than exactly two tokens would
            # otherwise escape as a bare ValueError from tuple unpacking.
            parts = client_protocol.split()
            if len(parts) != 2:
                self.logger.debug(
                    "Rejecting malformed protocol line %r" % (client_protocol,)
                )
                return
            client_proto_name, client_proto_version = parts

            if client_proto_name != self.proto_name:
                # Report what the client sent, not our own protocol name.
                self.logger.debug("Rejecting invalid protocol %s" % (client_proto_name,))
                return

            try:
                self.proto_version = tuple(
                    int(v) for v in client_proto_version.split(".")
                )
            except ValueError:
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version,)
                )
                return

            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version,)
                )
                return

            # Read headers until an empty line
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                tag, sep, value = header.partition(":")
                if not sep:
                    raise ClientError("Malformed header %r" % (header,))
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                # Empty line terminates the reply headers
                await self.socket.send("")

            # Handle messages until the client closes the stream
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    # Report the failure to the client, then drop the connection
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """Call the first registered handler whose name is a key of ``msg``.

        :param msg: decoded client message, expected to be a dict keyed by
            command name.
        :returns: whatever the handler returns.
        :raises ClientError: if ``msg`` is not a dict or names no known command.
        """
        # Guard against non-dict payloads: for a str, `k in msg` would be a
        # substring test and msg[k] would then fail with TypeError.
        if isinstance(msg, dict):
            for k in self.handlers.keys():
                if k in msg:
                    self.logger.debug("Handling %s" % k)
                    return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % (msg,))

    async def handle_ping(self, request):
        """Liveness check; the request payload is ignored."""
        return {"alive": True}
119
120
class StreamServer(object):
    """Common base for servers built on asyncio streams.

    Each accepted (reader, writer) pair is wrapped in a StreamConnection and
    passed to ``handler``. Once :meth:`stop` has been called, newly accepted
    connections are closed immediately instead.
    """

    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        # Flipped by stop(); later clients are refused.
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        # NOTE: write buffering was at one point disabled here with
        # writer.transport.set_write_buffer_limits(0)
        conn = StreamConnection(reader, writer, -1)
        if not self.closed:
            await self.handler(conn)
            return
        await conn.close()

    async def stop(self):
        self.closed = True
138
139
class TCPStreamServer(StreamServer):
    """StreamServer that listens on a TCP ``host``/``port``."""

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        super().__init__(handler, logger)
        self.host = host
        self.port = port
        self.reuseport = reuseport

    @staticmethod
    def _tune_socket(sock):
        # Newer python does this automatically. Do it manually here for
        # maximum compatibility
        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        # Enable keep alives. This prevents broken client connections
        # from persisting on the server for long periods of time.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

    def start(self, loop):
        """Start listening on ``loop`` and tune every listening socket.

        Sets ``self.server`` and ``self.address`` (``host:port``, with the
        host in brackets for IPv6) and returns a one-element list holding
        ``self.server.wait_closed()``.
        """
        listening = asyncio.start_server(
            self.handle_stream_client,
            self.host,
            self.port,
            reuse_port=self.reuseport,
        )
        self.server = loop.run_until_complete(listening)

        for listener in self.server.sockets:
            self.logger.debug("Listening on %r" % (listener.getsockname(),))
            self._tune_socket(listener)

        primary = self.server.sockets[0]
        bound_host, bound_port = primary.getsockname()[:2]
        if primary.family == socket.AF_INET6:
            self.address = "[%s]:%d" % (bound_host, bound_port)
        else:
            self.address = "%s:%d" % (bound_host, bound_port)

        return [self.server.wait_closed()]

    async def stop(self):
        """Refuse new clients and close the listening server."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """Nothing to remove for TCP listeners."""
185
186
class UnixStreamServer(StreamServer):
    """StreamServer that listens on an AF_UNIX socket at ``path``."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """Bind the UNIX socket and start accepting clients on ``loop``.

        Sets ``self.server`` and ``self.address`` (``unix://<absolute path>``)
        and returns a one-element list holding ``self.server.wait_closed()``.
        """
        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX by binding to the bare
            # file name from inside its directory. A path without a directory
            # component has dirname "", which os.chdir() rejects, so fall back
            # to the current directory in that case.
            os.chdir(os.path.dirname(self.path) or ".")
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(cwd)

        self.logger.debug("Listening on %r" % self.path)
        # Resolved after restoring cwd, so relative paths stay correct.
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        """Refuse new clients and close the listening server."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """Remove the socket file, tolerating it already being gone."""
        # Runs from a `finally` during shutdown; raising here for a missing
        # file would mask whatever ended the server.
        try:
            os.unlink(self.path)
        except FileNotFoundError:
            pass
215
216
class WebsocketsServer(object):
    """Server that accepts asyncrpc clients over websockets.

    Every websocket is wrapped in a WebsocketConnection and passed to
    ``handler``. The ``websockets`` package is imported lazily in
    :meth:`start`.
    """

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        self.host, self.port = host, port
        self.handler = handler
        self.logger = logger
        self.reuseport = reuseport

    def start(self, loop):
        """Start serving on ``loop`` and enable TCP keepalive on listeners.

        Sets ``self.server`` and ``self.address`` (``ws://host:port``, with the
        host in brackets for IPv6) and returns a one-element list holding
        ``self.server.wait_closed()``.
        """
        import websockets.server

        serving = websockets.server.serve(
            self.client_handler,
            self.host,
            self.port,
            # Disable the websockets library's own ping/pong keepalive.
            ping_interval=None,
            reuse_port=self.reuseport,
        )
        self.server = loop.run_until_complete(serving)

        for listener in self.server.sockets:
            self.logger.debug("Listening on %r" % (listener.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            for option, value in (
                (socket.TCP_KEEPIDLE, 30),
                (socket.TCP_KEEPINTVL, 15),
                (socket.TCP_KEEPCNT, 4),
            ):
                listener.setsockopt(socket.IPPROTO_TCP, option, value)

        primary = self.server.sockets[0]
        bound_host, bound_port = primary.getsockname()[:2]
        if primary.family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (bound_host, bound_port)
        else:
            self.address = "ws://%s:%d" % (bound_host, bound_port)

        return [self.server.wait_closed()]

    async def stop(self):
        """Close the listening server."""
        self.server.close()

    def cleanup(self):
        """Nothing to remove for websocket listeners."""

    async def client_handler(self, websocket):
        """Wrap ``websocket`` and hand it to the configured handler."""
        await self.handler(WebsocketConnection(websocket, -1))
265
266
class AsyncServer(object):
    """Base class for asyncrpc servers.

    Subclasses implement :meth:`accept_client`, select a transport with one of
    the ``start_*_server`` methods, and then run either in the current process
    (:meth:`serve_forever`) or in a child process (:meth:`serve_as_process`).
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        self.run_tasks = []
        # Strong references to tasks spawned from signal handlers. asyncio
        # only keeps weak references to tasks, so an unreferenced task can be
        # garbage collected before it finishes.
        self._signal_tasks = set()

    def start_tcp_server(self, host, port, *, reuseport=False):
        """Configure the server to listen on TCP ``host``:``port``."""
        self.server = TCPStreamServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    def start_unix_server(self, path):
        """Configure the server to listen on the UNIX socket at ``path``."""
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port, reuseport=False):
        """Configure the server to accept websocket clients on ``host``:``port``."""
        self.server = WebsocketsServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    async def _client_handler(self, socket):
        """Serve one client connection, logging instead of propagating errors.

        The socket is always closed afterwards.
        """
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """Return the connection object that will serve ``socket``.

        Must be overridden; the returned object needs an async
        ``process_requests()`` method. ``abc.abstractmethod`` is not enforced
        here because this class does not use ``abc.ABCMeta``, so fail loudly
        instead of silently returning None.
        """
        raise NotImplementedError(
            "%s must implement accept_client()" % type(self).__name__
        )

    async def stop(self):
        """Stop the configured transport."""
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """Start the configured transport on ``self.loop``.

        Records the transport's address in ``self.address`` and returns the
        awaitables to run until shutdown.
        """
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        """Begin an orderly shutdown when an exit signal arrives."""
        self.logger.debug("Got exit signal")
        task = self.loop.create_task(self.stop())
        # Keep the task alive until it completes (see __init__).
        self._signal_tasks.add(task)
        task.add_done_callback(self._signal_tasks.discard)

    def _serve_forever(self, tasks):
        """Run ``tasks`` to completion with exit-signal handling installed.

        The transport's ``cleanup()`` is always called afterwards.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # serve_as_process blocks SIGTERM before forking the child; it is
            # safe to deliver it only now that the handler is installed.
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        try:
            tasks = self.start()
            self._serve_forever(tasks)
            # Match the child-process path in serve_as_process: finalize any
            # async generators before the loop goes away.
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            # Close the loop even if start() or serving raised.
            self.loop.close()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process

        :param prefunc: optional callable run in the child as
            ``prefunc(self, *args)`` after the server has started.
        :param args: extra positional arguments for ``prefunc``.
        :param log_level: optional level applied to ``self.logger`` in the child.
        :returns: the started ``multiprocessing.Process``. ``self.address`` is
            set to the child's listening address, or None if startup failed.
        """

        def run(queue):
            try:
                self.address = None
                # Create loop and override any loop that may have existed
                # in a parent process.  Without doing this and instead
                # using get_event_loop, at the very minimum the hashserv
                # unit tests will hang when running the second test.
                # This happens since get_event_loop in the spawned server
                # process for the second testcase ends up with the loop
                # from the hashserv client created in the unit test process
                # when running the first testcase.  The problem is somewhat
                # more general, though, as any potential use of asyncio in
                # Cooker could create a loop that needs to replaced in this
                # new process.
                # Done inside the try so that a failure here still wakes
                # the parent blocked in queue.get().
                self._create_loop()
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            self._serve_forever(tasks)

            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
411