#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

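"""Shared asyncio RPC plumbing: a line-oriented protocol exchanging
newline-delimited JSON messages over TCP or UNIX domain sockets, with
chunking for messages larger than DEFAULT_MAX_CHUNK."""
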
import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    pass


class ServerError(Exception):
    pass


class AsyncServerConnection(object):
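    """
    Server-side handler for a single client connection.

    The wire protocol, as implemented below: the client first sends a
    handshake line of the form "<proto_name> <major>.<minor>\n", then
    header lines terminated by an empty line (no headers are currently
    defined). After that, each message is a single line of UTF-8 encoded
    JSON. For example, a ping exchange looks like:

        C: {"ping": {}}
        S: {"alive": true}
    """
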
    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read the protocol name and version. At EOF, readline()
            # returns an empty bytes object, not None
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        # Queue the message on the writer; callers are responsible for
        # awaiting writer.drain() (process_requests drains after each
        # dispatched message)
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw line: 'message' is unbound if decoding failed
            self.logger.error('Bad message from client: %r' % l)
            raise

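    # Oversized messages arrive as a 'chunk-stream' marker message
    # followed by the JSON payload split across newline-terminated pieces
    # and ended with an empty line (see chunkify). Reassemble the pieces
    # and dispatch them as a single message.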
    async def handle_chunk(self, request):
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        response = {'alive': True}
        self.write_message(response)


class AsyncServer(abc.ABC):
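    """
    Base class for asyncio servers. Subclasses implement accept_client()
    to wrap each (reader, writer) pair in an AsyncServerConnection, then
    call start_tcp_server() or start_unix_server() followed by
    serve_forever() or serve_as_process().

    A minimal sketch (FooServer and FooConnection are hypothetical names):

        class FooConnection(AsyncServerConnection):
            def validate_proto_version(self):
                return self.proto_version == (1, 0)

        class FooServer(AsyncServer):
            def accept_client(self, reader, writer):
                return FooConnection(reader, writer, 'FOO-PROTOCOL', self.logger)

        server = FooServer(logger)
        server.start_unix_server('/tmp/foo.sock')
        server.serve_forever()
    """
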
    def __init__(self, logger):
        self._cleanup_socket = None
        self.logger = logger
        self.start = None
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
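                # Note: TCP_QUICKACK and the TCP_KEEP* constants below
                # are Linux-specific and may not exist on other platforms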
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX
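                # (sun_path is limited to ~108 bytes on Linux) by binding
                # to a path relative to the socket's directory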
                os.chdir(os.path.dirname(path))
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    @abc.abstractmethod
    def accept_client(self, reader, writer):
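        """
        Called once per connection. Subclasses must return an object with
        a process_requests() coroutine, typically an AsyncServerConnection
        subclass, to handle the session.
        """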
        pass

    async def handle_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            # exc_info=True logs the full traceback along with the message
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create a fresh loop, overriding any loop that may have existed
        # in a parent process. It is possible that the use cases of
        # serve_forever are constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process
        """
        def run(queue):
            # Create a fresh loop, overriding any loop that may have
            # existed in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv unit
            # tests will hang when running the second test. This happens
            # since get_event_loop in the spawned server process for the
            # second testcase ends up with the loop from the hashserv
            # client created in the unit test process when running the
            # first testcase. The problem is somewhat more general,
            # though, as any potential use of asyncio in Cooker could
            # create a loop that needs to be replaced in this new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Report the address even if start() failed so the parent
                # is not left blocking on queue.get()
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

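        # The child reports the address it bound to through this queue;
        # queue.get() below blocks until the server is actually listening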
        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)