#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
import logging
from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError


class ClientLoggerAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        return f"[Client {self.extra['address']}] {msg}", kwargs


class AsyncServerConnection(object):
    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # response message will automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        await self.socket.close()

    async def handle_headers(self, headers):
        return {}

    async def process_requests(self):
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.split()
            if client_proto_name != self.proto_name:
                self.logger.debug("Rejecting invalid protocol %s" % (client_proto_name))
                return

            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                tag, value = header.split(":", 1)
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                await self.socket.send("")

            # Handle messages
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        return {"alive": True}
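

# Illustrative sketch (not part of the original module): subclasses are
# expected to provide validate_proto_version() and to extend self.handlers so
# that dispatch_message() can route incoming message dicts by their top-level
# key. The names below (ExampleConnection, "example-proto", "echo") are
# hypothetical and only demonstrate the same pattern used by handle_ping above.
class ExampleConnection(AsyncServerConnection):
    def __init__(self, socket, logger):
        super().__init__(socket, "example-proto", logger)
        # Register an additional handler; the key is the top-level key of the
        # incoming message dict, e.g. {"echo": {"text": "hello"}}
        self.handlers["echo"] = self.handle_echo

    def validate_proto_version(self):
        # Accept any 1.x client, mirroring how process_requests() parses the
        # "<name> <major>.<minor>" protocol line
        return self.proto_version[0] == 1

    async def handle_echo(self, request):
        # Returning a dict sends it back via send_message(); returning
        # self.NO_RESPONSE would suppress the reply entirely
        return {"text": request.get("text", "")}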


class StreamServer(object):
    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        socket = StreamConnection(reader, writer, -1)
        if self.closed:
            await socket.close()
            return

        await self.handler(socket)

    async def stop(self):
        self.closed = True


class TCPStreamServer(StreamServer):
    def __init__(self, host, port, handler, logger, *, reuseport=False):
        super().__init__(handler, logger)
        self.host = host
        self.port = port
        self.reuseport = reuseport

    def start(self, loop):
        self.server = loop.run_until_complete(
            asyncio.start_server(
                self.handle_stream_client,
                self.host,
                self.port,
                reuse_port=self.reuseport,
            )
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))
            # Newer Python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            # Enable keepalives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        pass
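

# Illustrative sketch (not part of the original module): the keepalive tuning
# above means an idle, broken connection is first probed after 30s and dropped
# after 4 failed probes spaced 15s apart, i.e. roughly 90s after the peer
# vanishes. TCP_KEEPIDLE/TCP_KEEPINTVL/TCP_KEEPCNT are Linux-specific options;
# the hypothetical helper below just shows the same setsockopt() calls applied
# to an arbitrary socket.
def _example_enable_keepalive(sock, idle=30, interval=15, count=4):
    """Hypothetical helper mirroring the keepalive setup in TCPStreamServer.start()"""
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count)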


class UnixStreamServer(StreamServer):
    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX
            os.chdir(os.path.dirname(self.path))
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(cwd)

        self.logger.debug("Listening on %r" % self.path)
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        os.unlink(self.path)
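

# Illustrative sketch (not part of the original module): AF_UNIX socket paths
# are limited to roughly 108 bytes on Linux, which is why start() above binds
# to the basename after chdir()ing into the socket's directory. A client can
# use the same trick when connecting to a long path; the function below is a
# hypothetical example of that, not an API provided by this module.
async def _example_connect_unix(path):
    cwd = os.getcwd()
    try:
        os.chdir(os.path.dirname(path))
        reader, writer = await asyncio.open_unix_connection(os.path.basename(path))
    finally:
        os.chdir(cwd)
    return reader, writer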


class WebsocketsServer(object):
    def __init__(self, host, port, handler, logger, *, reuseport=False):
        self.host = host
        self.port = port
        self.handler = handler
        self.logger = logger
        self.reuseport = reuseport

    def start(self, loop):
        import websockets.server

        self.server = loop.run_until_complete(
            websockets.server.serve(
                self.client_handler,
                self.host,
                self.port,
                ping_interval=None,
                reuse_port=self.reuseport,
            )
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))

            # Enable keepalives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (name[0], name[1])
        else:
            self.address = "ws://%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        self.server.close()

    def cleanup(self):
        pass

    async def client_handler(self, websocket):
        socket = WebsocketConnection(websocket, -1)
        await self.handler(socket)
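

# Illustrative sketch (not part of the original module): whichever transport is
# used (TCP, UNIX domain socket, or websocket), process_requests() expects the
# same handshake from the client: a "<proto_name> <major>.<minor>" line, zero
# or more "key: value" header lines, an empty line, then message dicts. The
# stub connection below is hypothetical; it only scripts that sequence in
# memory to show the order of recv()/recv_message() calls made by the server.
class _ExampleStubConnection(object):
    def __init__(self, lines, messages):
        self.address = "stub"
        self._lines = list(lines)
        self._messages = list(messages)
        self.sent = []

    async def recv(self):
        return self._lines.pop(0)

    async def send(self, msg):
        self.sent.append(msg)

    async def recv_message(self):
        return self._messages.pop(0)

    async def send_message(self, msg):
        self.sent.append(msg)

    async def close(self):
        pass


async def _example_handshake():
    # Drives the ExampleConnection sketch above with one ping message; the
    # reply collected in stub.sent should be {"alive": True}, matching
    # handle_ping()
    stub = _ExampleStubConnection(
        lines=["example-proto 1.0", ""], messages=[{"ping": {}}, None]
    )
    conn = ExampleConnection(stub, logging.getLogger("example"))
    await conn.process_requests()
    return stub.sent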


class AsyncServer(object):
    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        self.run_tasks = []

    def start_tcp_server(self, host, port, *, reuseport=False):
        self.server = TCPStreamServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    def start_unix_server(self, path):
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port, reuseport=False):
        self.server = WebsocketsServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    async def _client_handler(self, socket):
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        pass

    async def stop(self):
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        tasks = self.start()
        self._serve_forever(tasks)
        self.loop.close()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the use cases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in
            # this new process.
            self._create_loop()
            try:
                self.address = None
                tasks = self.start()
            finally:
                # Always put the server address on the queue to wake up the
                # parent process
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            self._serve_forever(tasks)

            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block, which ensures it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
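

# Illustrative sketch (not part of the original module): a concrete server
# wires the pieces together by subclassing AsyncServer, returning an
# AsyncServerConnection subclass from accept_client(), and picking one of the
# start_*_server() transports before serving. ExampleConnection is the
# hypothetical handler class sketched earlier in this file.
class ExampleServer(AsyncServer):
    def accept_client(self, socket):
        # Called once per connection from _client_handler(); the returned
        # object's process_requests() drives the session
        return ExampleConnection(socket, self.logger)


def _example_main():
    # Run in the foreground on a UNIX socket; serve_as_process() could be used
    # instead to fork the server into a child process and return its address
    logger = logging.getLogger("example")
    server = ExampleServer(logger)
    server.start_unix_server("./example.sock")
    server.serve_forever()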