xref: /openbmc/openbmc/poky/bitbake/lib/bb/asyncrpc/client.py (revision 96e4b4e121e0e2da1535d7d537d6a982a6ff5bc0)
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

# Address kinds returned as the first element of parse_address()'s result.
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2

WEBSOCKETS_MIN_VERSION = (9, 1)
# Need websockets 10 with python 3.10+
if sys.version_info >= (3, 10, 0):
    WEBSOCKETS_MIN_VERSION = (10, 0)


def parse_address(addr):
    """Parse a server address string into ``(addr_type, args)``.

    Accepted forms:

    * ``unix://PATH``         -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``ws://...``/``wss://...`` -> ``(ADDR_TYPE_WS, (addr,))``
    * ``[HOST]:PORT``         -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``
      (brackets allow hosts containing ``:``, such as IPv6 literals)
    * ``HOST:PORT``           -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``

    The ``args`` tuple lines up with the parameters of the matching
    ``AsyncClient.connect_unix`` / ``connect_websocket`` / ``connect_tcp``.

    :raises ValueError: if ``addr`` is not one of the forms above or the
        port is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if m is not None:
        host = m.group("host")
        port = m.group("port")
    else:
        host, sep, port = addr.partition(":")
        # Exactly one ':' is allowed in an unbracketed address; anything else
        # (missing port, bare IPv6 literal, ...) is ambiguous.
        if not sep or ":" in port:
            raise ValueError(
                f"Invalid address '{addr}': expected 'host:port', "
                f"'[host]:port', '{UNIX_PREFIX}path', or a "
                f"'{WS_PREFIX}'/'{WSS_PREFIX}' URI"
            )

    try:
        port_num = int(port)
    except ValueError:
        raise ValueError(
            f"Invalid address '{addr}': port '{port}' is not an integer"
        ) from None

    return (ADDR_TYPE_TCP, (host, port_num))
47
48
class AsyncClient(object):
    """Asynchronous RPC client.

    A transport is selected by calling one of :meth:`connect_tcp`,
    :meth:`connect_unix` or :meth:`connect_websocket`; those only record a
    factory for the connection. The connection itself is opened lazily by
    :meth:`connect` (called automatically by :meth:`invoke` and
    :meth:`get_header`), which then performs the header handshake in
    :meth:`setup_connection`.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        """
        :param proto_name: protocol name sent in the first handshake line.
        :param proto_version: protocol version sent alongside ``proto_name``.
        :param logger: logger used to report retried communication errors.
        :param timeout: timeout value handed to the connection objects (and
            used as the websocket open timeout).
        :param server_headers: if true, ask the server to send its headers
            during the handshake and store them in ``self.server_headers``.
        :param headers: optional mapping of extra ``name: value`` headers to
            send to the server. Defaults to no extra headers.
        """
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Fresh dict per instance: a shared mutable default would leak headers
        # added to one client into every other default-constructed client.
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Configure the client to connect over TCP to ``address``:``port``.

        No connection is made here; it is opened lazily by :meth:`connect`.
        """

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure the client to connect to the AF_UNIX socket at ``path``.

        No connection is made here; it is opened lazily by :meth:`connect`.
        """

        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            # The socket must be opened synchronously so that CWD doesn't get
            # changed out from underneath us so we pass as a sock into asyncio
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                try:
                    # dirname() is "" for a bare file name; chdir("") fails,
                    # so stay in the current directory in that case.
                    os.chdir(os.path.dirname(path) or ".")
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                reader, writer = await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Don't leak the file descriptor if the connection failed.
                sock.close()
                raise
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Configure the client to connect to the websocket at ``uri``.

        No connection is made here; it is opened lazily by :meth:`connect`.

        :raises ImportError: if the ``websockets`` package is missing, its
            version cannot be parsed, or it is older than
            ``WEBSOCKETS_MIN_VERSION``.
        """
        import websockets

        try:
            version = tuple(
                int(v)
                for v in websockets.__version__.split(".")[
                    0 : len(WEBSOCKETS_MIN_VERSION)
                ]
            )
        except ValueError:
            raise ImportError(
                f"Unable to parse websockets version '{websockets.__version__}'"
            )

        if version < WEBSOCKETS_MIN_VERSION:
            min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
            raise ImportError(
                f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
            )

        async def connect_sock():
            try:
                websocket = await websockets.connect(
                    uri,
                    ping_interval=None,
                    open_timeout=self.timeout,
                )
            # asyncio.TimeoutError is the same class as
            # asyncio.exceptions.TimeoutError but, unlike the latter, also
            # exists on Python 3.7.
            except asyncio.TimeoutError as exc:
                raise ConnectionError("Timeout while connecting to websocket") from exc
            except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
                raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the header handshake on a freshly opened connection.

        Sends ``"<proto_name> <proto_version>"``, a ``needs-headers`` line,
        each entry of ``self.headers`` as ``"name: value"``, and an empty
        line. If server headers were requested, reads ``"tag: value"`` lines
        until an empty one and stores them with lowercased tags.
        """
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return server header ``tag`` (case-insensitive), or ``default``.

        Connects first if necessary so the handshake has been performed.
        """
        await self.connect()
        # Stored tags are lowercased in setup_connection(), so normalize the
        # lookup key the same way.
        return self.server_headers.get(tag.lower(), default)

    async def connect(self):
        """Open the connection and perform the handshake if not connected.

        If the handshake fails the half-open connection is discarded, so a
        later call starts from scratch instead of skipping the handshake.
        """
        if self.socket is None:
            self.socket = await self._connect_sock()
            try:
                await self.setup_connection()
            except BaseException:
                await self.disconnect()
                raise

    async def disconnect(self):
        """Close the current connection, if any.

        ``self.socket`` is cleared before closing so the client is left
        disconnected even if ``close()`` raises.
        """
        sock, self.socket = self.socket, None
        if sock is not None:
            await sock.close()

    async def close(self):
        """Alias for :meth:`disconnect`."""
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """Run coroutine function ``proc`` on a connected socket with retries.

        On a communication error the error is logged, the connection is
        closed and the call is retried. After 3 retries (4 attempts in total)
        the error is raised, wrapped in ConnectionError if it is not one
        already.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if ``msg`` is a dict carrying an ``invoke-error``."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send ``msg`` and return the server's reply.

        :raises InvokeError: if the reply reports an ``invoke-error``.
        :raises ConnectionError: if communication keeps failing.
        """

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Send a ``ping`` request and return the server's reply."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
210
211
class Client(object):
    """Blocking facade over an :class:`AsyncClient`.

    Subclasses implement :meth:`_get_async_client`. The returned async client
    is driven on an event loop created and owned by this object. Method names
    registered with :meth:`_add_methods` become blocking wrappers around the
    async client's coroutine methods of the same name.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7.  The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the async client instance this wrapper drives.

        Must be overridden. Client itself does not derive from abc.ABC, so
        the decorator is not enforced at instantiation; raise explicitly
        rather than silently returning None.
        """
        raise NotImplementedError(
            "%s must implement _get_async_client()" % type(self).__name__
        )

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function running ``downcall`` on ``self.loop``."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named async-client coroutine method as a blocking one."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Configure a Unix-socket transport for ``path`` and connect now."""
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        """Maximum chunk size of the underlying async client."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the async client's connection; the event loop stays open."""
        self.loop.run_until_complete(self.client.close())

    def close(self):
        """Close the connection and shut down the event loop.

        Safe to call more than once. The loop is closed and ``self.loop``
        cleared even if closing the async client raises.
        """
        loop, self.loop = self.loop, None
        if loop is None:
            return
        try:
            loop.run_until_complete(self.client.close())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
272