1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
20class ClientLoggerAdapter(logging.LoggerAdapter):
21    def process(self, msg, kwargs):
22        return f"[Client {self.extra['address']}] {msg}", kwargs
23
24
25class AsyncServerConnection(object):
26    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
27    # return message will be automatically be sent back to the client
28    NO_RESPONSE = object()
29
30    def __init__(self, socket, proto_name, logger):
31        self.socket = socket
32        self.proto_name = proto_name
33        self.handlers = {
34            "ping": self.handle_ping,
35        }
36        self.logger = ClientLoggerAdapter(
37            logger,
38            {
39                "address": socket.address,
40            },
41        )
42        self.client_headers = {}
43
44    async def close(self):
45        await self.socket.close()
46
47    async def handle_headers(self, headers):
48        return {}
49
50    async def process_requests(self):
51        try:
52            self.logger.info("Client %r connected" % (self.socket.address,))
53
54            # Read protocol and version
55            client_protocol = await self.socket.recv()
56            if not client_protocol:
57                return
58
59            (client_proto_name, client_proto_version) = client_protocol.split()
60            if client_proto_name != self.proto_name:
61                self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
62                return
63
64            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
65            if not self.validate_proto_version():
66                self.logger.debug(
67                    "Rejecting invalid protocol version %s" % (client_proto_version)
68                )
69                return
70
71            # Read headers
72            self.client_headers = {}
73            while True:
74                header = await self.socket.recv()
75                if not header:
76                    # Empty line. End of headers
77                    break
78                tag, value = header.split(":", 1)
79                self.client_headers[tag.lower()] = value.strip()
80
81            if self.client_headers.get("needs-headers", "false") == "true":
82                for k, v in (await self.handle_headers(self.client_headers)).items():
83                    await self.socket.send("%s: %s" % (k, v))
84                await self.socket.send("")
85
86            # Handle messages
87            while True:
88                d = await self.socket.recv_message()
89                if d is None:
90                    break
91                try:
92                    response = await self.dispatch_message(d)
93                except InvokeError as e:
94                    await self.socket.send_message(
95                        {"invoke-error": {"message": str(e)}}
96                    )
97                    break
98
99                if response is not self.NO_RESPONSE:
100                    await self.socket.send_message(response)
101
102        except ConnectionClosedError as e:
103            self.logger.info(str(e))
104        except (ClientError, ConnectionError) as e:
105            self.logger.error(str(e))
106        finally:
107            await self.close()
108
109    async def dispatch_message(self, msg):
110        for k in self.handlers.keys():
111            if k in msg:
112                self.logger.debug("Handling %s" % k)
113                return await self.handlers[k](msg[k])
114
115        raise ClientError("Unrecognized command %r" % msg)
116
117    async def handle_ping(self, request):
118        return {"alive": True}
119
120
121class StreamServer(object):
122    def __init__(self, handler, logger):
123        self.handler = handler
124        self.logger = logger
125        self.closed = False
126
127    async def handle_stream_client(self, reader, writer):
128        # writer.transport.set_write_buffer_limits(0)
129        socket = StreamConnection(reader, writer, -1)
130        if self.closed:
131            await socket.close()
132            return
133
134        await self.handler(socket)
135
136    async def stop(self):
137        self.closed = True
138
139
140class TCPStreamServer(StreamServer):
141    def __init__(self, host, port, handler, logger):
142        super().__init__(handler, logger)
143        self.host = host
144        self.port = port
145
146    def start(self, loop):
147        self.server = loop.run_until_complete(
148            asyncio.start_server(self.handle_stream_client, self.host, self.port)
149        )
150
151        for s in self.server.sockets:
152            self.logger.debug("Listening on %r" % (s.getsockname(),))
153            # Newer python does this automatically. Do it manually here for
154            # maximum compatibility
155            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
156            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
157
158            # Enable keep alives. This prevents broken client connections
159            # from persisting on the server for long periods of time.
160            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
161            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
162            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
163            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
164
165        name = self.server.sockets[0].getsockname()
166        if self.server.sockets[0].family == socket.AF_INET6:
167            self.address = "[%s]:%d" % (name[0], name[1])
168        else:
169            self.address = "%s:%d" % (name[0], name[1])
170
171        return [self.server.wait_closed()]
172
173    async def stop(self):
174        await super().stop()
175        self.server.close()
176
177    def cleanup(self):
178        pass
179
180
181class UnixStreamServer(StreamServer):
182    def __init__(self, path, handler, logger):
183        super().__init__(handler, logger)
184        self.path = path
185
186    def start(self, loop):
187        cwd = os.getcwd()
188        try:
189            # Work around path length limits in AF_UNIX
190            os.chdir(os.path.dirname(self.path))
191            self.server = loop.run_until_complete(
192                asyncio.start_unix_server(
193                    self.handle_stream_client, os.path.basename(self.path)
194                )
195            )
196        finally:
197            os.chdir(cwd)
198
199        self.logger.debug("Listening on %r" % self.path)
200        self.address = "unix://%s" % os.path.abspath(self.path)
201        return [self.server.wait_closed()]
202
203    async def stop(self):
204        await super().stop()
205        self.server.close()
206
207    def cleanup(self):
208        os.unlink(self.path)
209
210
211class WebsocketsServer(object):
212    def __init__(self, host, port, handler, logger):
213        self.host = host
214        self.port = port
215        self.handler = handler
216        self.logger = logger
217
218    def start(self, loop):
219        import websockets.server
220
221        self.server = loop.run_until_complete(
222            websockets.server.serve(
223                self.client_handler,
224                self.host,
225                self.port,
226                ping_interval=None,
227            )
228        )
229
230        for s in self.server.sockets:
231            self.logger.debug("Listening on %r" % (s.getsockname(),))
232
233            # Enable keep alives. This prevents broken client connections
234            # from persisting on the server for long periods of time.
235            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
236            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
237            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
238            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
239
240        name = self.server.sockets[0].getsockname()
241        if self.server.sockets[0].family == socket.AF_INET6:
242            self.address = "ws://[%s]:%d" % (name[0], name[1])
243        else:
244            self.address = "ws://%s:%d" % (name[0], name[1])
245
246        return [self.server.wait_closed()]
247
248    async def stop(self):
249        self.server.close()
250
251    def cleanup(self):
252        pass
253
254    async def client_handler(self, websocket):
255        socket = WebsocketConnection(websocket, -1)
256        await self.handler(socket)
257
258
259class AsyncServer(object):
260    def __init__(self, logger):
261        self.logger = logger
262        self.loop = None
263        self.run_tasks = []
264
265    def start_tcp_server(self, host, port):
266        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
267
268    def start_unix_server(self, path):
269        self.server = UnixStreamServer(path, self._client_handler, self.logger)
270
271    def start_websocket_server(self, host, port):
272        self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
273
274    async def _client_handler(self, socket):
275        address = socket.address
276        try:
277            client = self.accept_client(socket)
278            await client.process_requests()
279        except Exception as e:
280            import traceback
281
282            self.logger.error(
283                "Error from client %s: %s" % (address, str(e)), exc_info=True
284            )
285            traceback.print_exc()
286        finally:
287            self.logger.debug("Client %s disconnected", address)
288            await socket.close()
289
290    @abc.abstractmethod
291    def accept_client(self, socket):
292        pass
293
294    async def stop(self):
295        self.logger.debug("Stopping server")
296        await self.server.stop()
297
298    def start(self):
299        tasks = self.server.start(self.loop)
300        self.address = self.server.address
301        return tasks
302
303    def signal_handler(self):
304        self.logger.debug("Got exit signal")
305        self.loop.create_task(self.stop())
306
307    def _serve_forever(self, tasks):
308        try:
309            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
310            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
311            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
312            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
313
314            self.loop.run_until_complete(asyncio.gather(*tasks))
315
316            self.logger.debug("Server shutting down")
317        finally:
318            self.server.cleanup()
319
320    def serve_forever(self):
321        """
322        Serve requests in the current process
323        """
324        self._create_loop()
325        tasks = self.start()
326        self._serve_forever(tasks)
327        self.loop.close()
328
329    def _create_loop(self):
330        # Create loop and override any loop that may have existed in
331        # a parent process.  It is possible that the usecases of
332        # serve_forever might be constrained enough to allow using
333        # get_event_loop here, but better safe than sorry for now.
334        self.loop = asyncio.new_event_loop()
335        asyncio.set_event_loop(self.loop)
336
337    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
338        """
339        Serve requests in a child process
340        """
341
342        def run(queue):
343            # Create loop and override any loop that may have existed
344            # in a parent process.  Without doing this and instead
345            # using get_event_loop, at the very minimum the hashserv
346            # unit tests will hang when running the second test.
347            # This happens since get_event_loop in the spawned server
348            # process for the second testcase ends up with the loop
349            # from the hashserv client created in the unit test process
350            # when running the first testcase.  The problem is somewhat
351            # more general, though, as any potential use of asyncio in
352            # Cooker could create a loop that needs to replaced in this
353            # new process.
354            self._create_loop()
355            try:
356                self.address = None
357                tasks = self.start()
358            finally:
359                # Always put the server address to wake up the parent task
360                queue.put(self.address)
361                queue.close()
362
363            if prefunc is not None:
364                prefunc(self, *args)
365
366            if log_level is not None:
367                self.logger.setLevel(log_level)
368
369            self._serve_forever(tasks)
370
371            if sys.version_info >= (3, 6):
372                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
373            self.loop.close()
374
375        queue = multiprocessing.Queue()
376
377        # Temporarily block SIGTERM. The server process will inherit this
378        # block which will ensure it doesn't receive the SIGTERM until the
379        # handler is ready for it
380        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
381        try:
382            self.process = multiprocessing.Process(target=run, args=(queue,))
383            self.process.start()
384
385            self.address = queue.get()
386            queue.close()
387            queue.join_thread()
388
389            return self.process
390        finally:
391            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
392