xref: /openbmc/openbmc/poky/bitbake/lib/bb/asyncrpc/client.py (revision edff49234e31f23dc79f823473c9e286a21596c1)
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
# URI scheme prefixes recognised by parse_address()
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

# Address type tags returned as the first element of parse_address()'s result
ADDR_TYPE_UNIX, ADDR_TYPE_TCP, ADDR_TYPE_WS = 0, 1, 2

# Oldest acceptable websockets release as a (major, minor) tuple.
# websockets 10 is needed with python 3.10+, 9.1 is enough before that.
WEBSOCKETS_MIN_VERSION = (10, 0) if sys.version_info >= (3, 10, 0) else (9, 1)
31
32
def parse_address(addr):
    """Classify a server address string and split it into connect arguments.

    Recognised forms and their results:

    * ``unix://PATH``           -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``ws://...``, ``wss://...`` -> ``(ADDR_TYPE_WS, (addr,))`` (the URI is
      passed through unmodified, prefix included)
    * ``[HOST]:PORT``           -> ``(ADDR_TYPE_TCP, (HOST, PORT))``; the
      brackets allow HOST itself to contain ``:``
    * ``HOST:PORT``             -> ``(ADDR_TYPE_TCP, (HOST, PORT))``

    PORT is converted to ``int``.

    Raises:
        ValueError: if a TCP address is neither ``[HOST]:PORT`` nor a string
            with exactly one ``:``, or if PORT is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))

    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if m is not None:
        host = m.group("host")
        port = m.group("port")
    else:
        parts = addr.split(":")
        if len(parts) != 2:
            # Report the offending address instead of an opaque
            # "values to unpack" error from tuple unpacking
            raise ValueError(
                "Invalid address '%s': expected 'HOST:PORT' or '[HOST]:PORT'" % addr
            )
        host, port = parts

    return (ADDR_TYPE_TCP, (host, int(port)))
47
48
class AsyncClient(object):
    """Asynchronous client for a line/message based RPC server.

    One of connect_tcp(), connect_unix() or connect_websocket() must be
    called first to select the transport. None of them opens a connection;
    they only install the ``_connect_sock`` factory that connect() uses to
    (re)establish the connection on demand.

    Instances can be used as an async context manager; the connection is
    closed on exit.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        """
        proto_name, proto_version: sent as the first handshake line
        logger: receives a warning for every failed attempt in _send_wrapper()
        timeout: passed to the connection objects (and used as the websocket
            open timeout)
        server_headers: if true, ask the server for its headers during the
            handshake and store them in self.server_headers
        headers: mapping of extra header lines to send during the handshake
        """
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Create the dict per instance; a literal {} default would be one
        # object shared (and mutated) across every client
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Select a TCP transport to address:port for later connect() calls."""

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a unix domain socket transport for later connect() calls."""

        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                try:
                    # A bare file name has an empty dirname and chdir("")
                    # fails, so stay in the current directory in that case
                    os.chdir(os.path.dirname(path) or ".")
                    # The socket must be opened synchronously so that CWD doesn't get
                    # changed out from underneath us so we pass as a sock into asyncio
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                reader, writer = await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Don't leak the descriptor if connecting or wrapping fails
                sock.close()
                raise
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Select a websocket transport to uri for later connect() calls.

        Raises ImportError if the websockets module is missing, its version
        cannot be parsed, or it is older than WEBSOCKETS_MIN_VERSION.
        """
        import websockets

        try:
            # Only compare as many components as the minimum version has
            version = tuple(
                int(v)
                for v in websockets.__version__.split(".")[
                    0 : len(WEBSOCKETS_MIN_VERSION)
                ]
            )
        except ValueError as err:
            raise ImportError(
                f"Unable to parse websockets version '{websockets.__version__}'"
            ) from err

        if version < WEBSOCKETS_MIN_VERSION:
            min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
            raise ImportError(
                f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
            )

        async def connect_sock():
            websocket = await websockets.connect(
                uri,
                ping_interval=None,
                open_timeout=self.timeout,
            )
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the protocol handshake on the freshly opened self.socket.

        Sends the protocol line, the needs-headers flag, every entry of
        self.headers and a terminating empty line. If server headers were
        requested, reads "tag: value" lines until an empty line and stores
        them in self.server_headers keyed by the lower-cased tag.
        """
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return server header `tag` (or `default`), connecting if needed.

        Header tags are stored lower-cased, so `tag` must be given in lower
        case to match.
        """
        await self.connect()
        return self.server_headers.get(tag, default)

    async def connect(self):
        """Open the connection and run the handshake if not already connected."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            try:
                await self.setup_connection()
            except Exception:
                # Drop the half-negotiated connection; otherwise the next
                # connect() would see a socket and skip the handshake
                await self.disconnect()
                raise

    async def disconnect(self):
        """Close the current connection, if any."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        """Close the client; currently the same as disconnect()."""
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """Run coroutine function `proc` with automatic reconnect and retry.

        On a connection or decoding error the connection is closed and the
        call is retried from a fresh connection, up to three retries. The
        failure of the final attempt is raised as ConnectionError (errors
        that are not already a ConnectionError are wrapped in one).
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        # Keep the original error attached as the cause
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if `msg` is an "invoke-error" reply."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send `msg`, wait for the reply and return it.

        Raises InvokeError if the reply reports an invoke error and
        ConnectionError if the server cannot be reached after retrying.
        """

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Send a ping request and return the server's reply."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
205
206
class Client(object):
    """Synchronous facade over an AsyncClient.

    Owns a private event loop and exposes blocking methods that run the
    corresponding coroutine of the wrapped async client to completion.
    Subclasses provide the async client through _get_async_client() and can
    publish more of its methods with _add_methods().

    Can be used as a context manager; close() is called on exit.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7.  The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the async client instance this object wraps."""
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs coroutine function `downcall`."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose the named async client methods as blocking methods of self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Select the unix socket at `path` and connect to it immediately."""
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        """Maximum chunk size, stored on the wrapped async client."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the async client's connection but keep the event loop."""
        self.loop.run_until_complete(self.client.close())

    def close(self):
        """Close the async client and release the event loop.

        Safe to call more than once. The event loop is closed even if
        shutting down the client raises; the exception still propagates.
        """
        loop = self.loop
        if loop:
            try:
                loop.run_until_complete(self.client.close())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                # Always release the loop, otherwise a failing client
                # shutdown would leak it
                loop.close()
                self.loop = None
        self.loop = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
267