1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

# Address type tags returned as the first element by parse_address()
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2

def parse_address(addr):
    """Split a server address string into (address type, connect arguments).

    Accepted forms:
      * ``unix://PATH``              -> (ADDR_TYPE_UNIX, (PATH,))
      * ``ws://...`` / ``wss://...`` -> (ADDR_TYPE_WS, (addr,))
      * ``[HOST]:PORT``              -> (ADDR_TYPE_TCP, (HOST, PORT))
        (brackets allow a HOST containing colons, e.g. an IPv6 literal)
      * ``HOST:PORT``                -> (ADDR_TYPE_TCP, (HOST, PORT))

    PORT is returned as an int.

    :raises ValueError: if a TCP address matches none of the forms above or
        its port is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if m is not None:
        host = m.group("host")
        port = m.group("port")
    else:
        parts = addr.split(":")
        if len(parts) != 2:
            # Without this check a missing port or an unbracketed IPv6
            # address fails with an opaque tuple-unpacking error.
            raise ValueError(
                "Invalid address %r: expected HOST:PORT, [HOST]:PORT, "
                "%sPATH, %s... or %s..." % (addr, UNIX_PREFIX, WS_PREFIX, WSS_PREFIX)
            )
        host, port = parts

    try:
        port = int(port)
    except ValueError:
        raise ValueError("Invalid port %r in address %r" % (port, addr)) from None

    return (ADDR_TYPE_TCP, (host, port))
41
class AsyncClient(object):
    """Asynchronous client for the asyncrpc protocol.

    Call one of connect_tcp(), connect_unix() or connect_websocket() first.
    These only record *how* to open the connection. The connection itself is
    opened lazily by connect(), which also performs the header handshake in
    setup_connection(). invoke() sends a message and returns the reply,
    reconnecting and retrying on transport errors.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        """Create an unconnected client.

        :param proto_name: protocol name sent on the first handshake line
        :param proto_version: protocol version sent after proto_name
        :param logger: logger used to report errors talking to the server
        :param timeout: passed to the connection objects created on connect
        :param server_headers: if true, ask the server for headers and store
            them (lower-cased names) in self.server_headers
        :param headers: mapping of extra ``name: value`` headers sent during
            the handshake; None (the default) means no extra headers
        """
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Give each client its own dict. A shared ``{}`` default would make a
        # mutation of one client's headers visible to every other client.
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Configure the client to connect over TCP to *address*:*port*."""

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure the client to connect to the AF_UNIX socket at *path*."""

        async def connect_sock():
            dirname = os.path.dirname(path)
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                # AF_UNIX has path length issues, so chdir to the socket's
                # directory and connect with the basename. The socket must
                # be connected synchronously so that CWD doesn't get changed
                # out from underneath us; the connected socket is then handed
                # to asyncio.
                cwd = os.getcwd()
                try:
                    # A bare file name has an empty dirname and
                    # os.chdir("") would fail, so stay in the current
                    # directory in that case.
                    if dirname:
                        os.chdir(dirname)
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                reader, writer = await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Don't leak the file descriptor when connecting fails
                sock.close()
                raise
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Configure the client to connect to the websocket server at *uri*."""
        import websockets

        async def connect_sock():
            websocket = await websockets.connect(uri, ping_interval=None)
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the header handshake on a freshly opened self.socket.

        Sends the protocol line, a ``needs-headers`` line and every entry of
        self.headers, terminated by an empty line. If server headers were
        requested, reads ``name: value`` lines up to an empty line into
        self.server_headers, keyed by lower-cased name.
        """
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return server header *tag* (case-insensitive), or *default* if absent.

        Connects first if necessary so that the handshake has run.
        """
        await self.connect()
        # setup_connection() stores header names lower-cased
        return self.server_headers.get(tag.lower(), default)

    async def connect(self):
        """Open the connection and run the handshake unless already connected."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            try:
                await self.setup_connection()
            except BaseException:
                # Drop a socket whose handshake did not complete. Otherwise a
                # later connect() would reuse it without a handshake.
                sock, self.socket = self.socket, None
                await sock.close()
                raise

    async def disconnect(self):
        """Close the current connection, if any; a later call reconnects."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        """Close the client (same as disconnect())."""
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """Run coroutine function *proc* on a connected socket, retrying on errors.

        Transport and decoding errors are logged, and the connection is closed
        so the next attempt reconnects. At most 4 attempts are made. After
        that, the error is re-raised as a ConnectionError, chained to the
        original exception.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s", e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if *msg* is a dict carrying an ``invoke-error``."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send *msg* and return the server's reply.

        :raises InvokeError: if the reply reports an invoke error
        :raises ConnectionError: if communication keeps failing
        """

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Send a ``ping`` message and return the reply."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
176
177
class Client(object):
    """Blocking wrapper around an AsyncClient.

    Subclasses implement _get_async_client() to return the AsyncClient to
    wrap. Each Client owns a private event loop and runs the wrapped
    coroutines to completion on it, so callers need no asyncio code.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7.  The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the AsyncClient instance this object wraps."""
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking callable that runs coroutine function *downcall* on self.loop."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named coroutine method of self.client as a blocking method on self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Configure the wrapped client for the AF_UNIX socket *path* and connect."""
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        """The wrapped client's max_chunk setting."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the wrapped client's connection; the event loop stays usable."""
        self.loop.run_until_complete(self.client.close())

    def close(self):
        """Close the wrapped client and release this object's event loop.

        Safe to call more than once. The loop is closed (and self.loop reset
        to None) even if closing the client raises; that exception is then
        propagated.
        """
        loop, self.loop = self.loop, None
        if loop:
            try:
                loop.run_until_complete(self.client.close())
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
239
240
class ClientPool(object):
    """Run batches of coroutine tasks against a bounded pool of clients.

    Subclasses implement _new_client(). run_tasks() runs every task
    concurrently on a private event loop. Each task borrows a client from the
    pool, and at most max_clients clients are ever created.
    """

    def __init__(self, max_clients):
        self.avail_clients = []  # idle clients ready to be borrowed
        self.num_clients = 0  # clients created or currently being created
        self.max_clients = max_clients
        self.loop = None
        # Created lazily in the worker thread by __thread_main()
        self.client_condition = None

    @abc.abstractmethod
    async def _new_client(self):
        """Create and return a new client; implemented by subclasses."""
        raise NotImplementedError("Must be implemented in derived class")

    def close(self):
        """Close all idle clients and the private event loop.

        The loop is closed (and self.loop reset to None) even if closing a
        client raises; that exception is then propagated.
        """
        self.client_condition = None

        loop, self.loop = self.loop, None
        if loop:
            try:
                loop.run_until_complete(self.__close_clients())
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

    def run_tasks(self, tasks):
        """Run each coroutine function in *tasks* with a pooled client; block until done.

        Every task is called as ``await task(client)``. The event loop runs in
        a separate thread, presumably so this works even when the calling
        thread already has a loop installed or running (NOTE(review):
        confirm). Exceptions raised in that thread are reported by
        threading.excepthook and are NOT re-raised here.
        """
        if not self.loop:
            self.loop = asyncio.new_event_loop()

        thread = Thread(target=self.__thread_main, args=(tasks,))
        thread.start()
        thread.join()

    @contextlib.asynccontextmanager
    async def get_client(self):
        """Borrow a client for the duration of an ``async with`` block.

        Reuses an idle client when one is available. Otherwise it creates one
        while fewer than max_clients exist, or waits until a client is
        returned. The client is always returned to the pool on exit. Requires
        self.client_condition to have been created.
        """
        async with self.client_condition:
            # The num_clients check matters only after a failed creation,
            # which frees a slot without returning a client.
            await self.client_condition.wait_for(
                lambda: self.avail_clients or self.num_clients < self.max_clients
            )
            if self.avail_clients:
                client = self.avail_clients.pop()
            else:
                # Reserve the slot before awaiting creation
                self.num_clients += 1
                try:
                    client = await self._new_client()
                except BaseException:
                    # Release the reserved slot. Otherwise a failed creation
                    # would shrink the pool permanently; with max_clients=1,
                    # every later get_client() would wait forever.
                    self.num_clients -= 1
                    self.client_condition.notify()
                    raise

        try:
            yield client
        finally:
            async with self.client_condition:
                self.avail_clients.append(client)
                self.client_condition.notify()

    def __thread_main(self, tasks):
        """Worker-thread body for run_tasks(): run all *tasks* on self.loop."""

        async def process_task(task):
            async with self.get_client() as client:
                await task(client)

        asyncio.set_event_loop(self.loop)
        # Created after the loop is installed in this thread, presumably so
        # that Python versions which bind asyncio primitives to the current
        # loop at construction associate it with self.loop.
        if self.client_condition is None:
            self.client_condition = asyncio.Condition()
        coros = [process_task(t) for t in tasks]
        self.loop.run_until_complete(asyncio.gather(*coros))

    async def __close_clients(self):
        """Close every idle client and reset the pool bookkeeping."""
        # Reset first so the pool is not left holding half-closed clients
        # if one of the close() calls raises.
        clients, self.avail_clients = self.avail_clients, []
        self.num_clients = 0
        for c in clients:
            await c.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
314