#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import socket
import sys
import re
import contextlib
from threading import Thread
from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError, InvokeError

UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2


def parse_address(addr):
    """Classify a server address string and split it into connection arguments.

    Returns an ``(addr_type, args)`` tuple where ``addr_type`` is one of the
    ``ADDR_TYPE_*`` constants:

    * ``unix://PATH``               -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``ws://...`` or ``wss://...`` -> ``(ADDR_TYPE_WS, (addr,))``
    * ``[HOST]:PORT``               -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``
    * ``HOST:PORT``                 -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``

    Raises ValueError if a non-bracketed TCP address does not contain exactly
    one ``:`` or if its port is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
    elif addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))
    else:
        # Bracketed form allows hosts that themselves contain ':' (IPv6).
        m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
        if m is not None:
            host = m.group("host")
            port = m.group("port")
        else:
            host, port = addr.split(":")

        return (ADDR_TYPE_TCP, (host, int(port)))


class AsyncClient(object):
    """Asynchronous RPC client.

    The ``connect_tcp``/``connect_unix``/``connect_websocket`` methods only
    record how to open a connection. The connection itself is opened lazily
    by ``connect`` and is re-opened by ``_send_wrapper`` after errors.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        # When True, setup_connection() asks the server for its headers and
        # reads them into self.server_headers.
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Use a fresh dict per instance instead of a shared mutable default.
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Configure this client to connect to ``address``:``port`` over TCP."""

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure this client to connect to the AF_UNIX socket at ``path``."""

        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround.
            # os.path.dirname() is "" for a bare file name and os.chdir("")
            # fails, so fall back to the current directory in that case.
            sock_dir = os.path.dirname(path) or "."
            cwd = os.getcwd()
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                try:
                    os.chdir(sock_dir)
                    # The socket must be opened synchronously so that CWD
                    # doesn't get changed out from underneath us so we pass
                    # as a sock into asyncio
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                reader, writer = await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Don't leak the descriptor if connecting fails.
                sock.close()
                raise
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Configure this client to connect to the websocket server at ``uri``."""
        import websockets

        async def connect_sock():
            websocket = await websockets.connect(uri, ping_interval=None)
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the protocol handshake on a freshly opened connection.

        Sends the protocol name/version line, a ``needs-headers`` line, one
        ``key: value`` line per entry in ``self.headers``, and an empty line.
        If server headers were requested, reads ``tag: value`` lines until an
        empty line and stores them with lower-cased tags.
        """
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return server header ``tag``, or ``default`` if absent.

        Connects first if needed. Stored tags are lower-cased, so ``tag``
        should be lower-case to match.
        """
        await self.connect()
        return self.server_headers.get(tag, default)

    async def connect(self):
        """Open the connection and run the handshake, unless already connected."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            await self.setup_connection()

    async def disconnect(self):
        """Close the current connection, if any."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        """Alias for ``disconnect``."""
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """Run coroutine function ``proc`` with automatic (re)connection.

        On one of the listed transport/decoding errors, the error is logged,
        the connection is closed and the call is retried. After the fourth
        failed attempt the error is raised. Non-``ConnectionError`` errors are
        converted to ``ConnectionError``, with the original chained as the
        cause.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if ``msg`` is a dict carrying an ``invoke-error``."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send ``msg`` and return the reply, raising InvokeError on a server error reply."""

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Send a ``ping`` request and return the server's reply."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()


class Client(object):
    """Blocking wrapper around an ``AsyncClient``.

    Subclasses implement ``_get_async_client``. Async methods of that client
    are exposed as blocking methods via ``_add_methods``, each call running on
    an event loop owned by this object.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7.  The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs coroutine function ``downcall`` on self.loop."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named async method of self.client as a blocking method on self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Configure a Unix-socket connection to ``path`` and connect immediately."""
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the underlying connection but keep the event loop usable."""
        self.loop.run_until_complete(self.client.close())

    def close(self):
        """Close the connection and the event loop. Safe to call repeatedly."""
        if self.loop:
            try:
                self.loop.run_until_complete(self.client.close())
                if sys.version_info >= (3, 6):
                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                # Always release the loop, even if closing the client failed.
                self.loop.close()
                self.loop = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False


class ClientPool(object):
    """Pool of at most ``max_clients`` async clients shared by concurrent tasks.

    Subclasses implement ``_new_client``. ``run_tasks`` runs a batch of tasks
    on an event loop owned by the pool, driven from a helper thread.
    """

    def __init__(self, max_clients):
        self.avail_clients = []
        # Number of clients created so far (both idle and in use).
        self.num_clients = 0
        self.max_clients = max_clients
        self.loop = None
        # Created lazily inside the worker thread, after the loop is set.
        self.client_condition = None

    @abc.abstractmethod
    async def _new_client(self):
        raise NotImplementedError("Must be implemented in derived class")

    def close(self):
        """Close all idle clients and the pool's event loop."""
        self.client_condition = None

        if self.loop:
            try:
                self.loop.run_until_complete(self.__close_clients())
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                # Always release the loop, even if closing a client failed.
                self.loop.close()
                self.loop = None

    def run_tasks(self, tasks):
        """Run ``tasks`` concurrently and block until all have finished.

        Each task is an async callable that receives a client from the pool.
        NOTE: the tasks run in a separate thread, and an exception raised by a
        task is not re-raised here; it ends that thread instead.
        """
        if not self.loop:
            self.loop = asyncio.new_event_loop()

        thread = Thread(target=self.__thread_main, args=(tasks,))
        thread.start()
        thread.join()

    @contextlib.asynccontextmanager
    async def get_client(self):
        """Borrow a client from the pool, creating one if below the limit.

        Waits for a client to be returned when the pool is at capacity. The
        client is put back in the pool when the context exits. Requires
        self.client_condition to have been created (done by run_tasks).
        """
        async with self.client_condition:
            if self.avail_clients:
                client = self.avail_clients.pop()
            elif self.num_clients < self.max_clients:
                self.num_clients += 1
                try:
                    client = await self._new_client()
                except BaseException:
                    # Release the reserved slot. Otherwise a failed creation
                    # permanently reduces capacity and can leave later callers
                    # waiting forever for a client that does not exist.
                    self.num_clients -= 1
                    raise
            else:
                while not self.avail_clients:
                    await self.client_condition.wait()
                client = self.avail_clients.pop()

        try:
            yield client
        finally:
            async with self.client_condition:
                self.avail_clients.append(client)
                self.client_condition.notify()

    def __thread_main(self, tasks):
        """Worker-thread entry point: run all tasks on the pool's event loop."""

        async def process_task(task):
            async with self.get_client() as client:
                await task(client)

        asyncio.set_event_loop(self.loop)
        if not self.client_condition:
            self.client_condition = asyncio.Condition()
        tasks = [process_task(t) for t in tasks]
        self.loop.run_until_complete(asyncio.gather(*tasks))

    async def __close_clients(self):
        """Close every idle client and reset the pool's counters."""
        for c in self.avail_clients:
            await c.close()
        self.avail_clients = []
        self.num_clients = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False