#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
import logging
from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError


class ClientLoggerAdapter(logging.LoggerAdapter):
    """
    Logger adapter that prefixes every message with the client address
    stored under the "address" key of the adapter's extra dict.
    """

    def process(self, msg, kwargs):
        return f"[Client {self.extra['address']}] {msg}", kwargs


class AsyncServerConnection(object):
    """
    Server side of a single client connection.

    Performs the protocol/version/header handshake and then dispatches
    incoming messages to the callables registered in self.handlers, keyed
    by command name.

    process_requests() calls self.validate_proto_version(), which is not
    defined in this class; subclasses are expected to provide it.
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        """Close the underlying connection."""
        await self.socket.close()

    async def handle_headers(self, headers):
        """
        Return a dict of headers to send back to the client.

        Only called when the client sent "needs-headers: true". The base
        implementation returns no headers.
        """
        return {}

    async def process_requests(self):
        """
        Run the handshake and then serve messages until the client goes away.

        The connection is always closed on exit. ConnectionClosedError is
        logged at info level; ClientError and ConnectionError are logged at
        error level. Any other exception propagates to the caller.
        """
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.split()
            except ValueError as e:
                # Not exactly two whitespace separated fields
                raise ClientError(
                    "Malformed protocol line %r" % (client_protocol,)
                ) from e

            if client_proto_name != self.proto_name:
                # Report the name the client asked for, not our own
                self.logger.debug(
                    "Rejecting invalid protocol %s" % (client_proto_name)
                )
                return

            try:
                self.proto_version = tuple(
                    int(v) for v in client_proto_version.split(".")
                )
            except ValueError as e:
                raise ClientError(
                    "Malformed protocol version %r" % (client_proto_version,)
                ) from e

            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                try:
                    tag, value = header.split(":", 1)
                except ValueError as e:
                    # No ":" separator in the header line
                    raise ClientError("Malformed header %r" % (header,)) from e
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                # Empty line terminates the server's header block
                await self.socket.send("")

            # Handle messages
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    # Report the failure to the client, then drop the
                    # connection
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """
        Invoke the first registered handler whose command name is a key of
        msg, passing it msg[name], and return the handler's result.

        Raises ClientError if no registered command is present in msg.
        """
        for k in self.handlers:
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        """Handler for the "ping" command."""
        return {"alive": True}


class StreamServer(object):
    """
    Common base for the asyncio stream based servers.

    handler is an async callable that is awaited with the connection
    object created for each new client.
    """

    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        # NOTE(review): the meaning of the -1 argument is defined by
        # StreamConnection, which is not visible here
        conn = StreamConnection(reader, writer, -1)
        if self.closed:
            # Server is stopping; refuse clients that connect late
            await conn.close()
            return

        await self.handler(conn)

    async def stop(self):
        self.closed = True


class TCPStreamServer(StreamServer):
    """Stream server listening on a TCP host/port."""

    def __init__(self, host, port, handler, logger):
        super().__init__(handler, logger)
        self.host = host
        self.port = port

    def start(self, loop):
        """
        Start listening using loop and set self.address to "host:port"
        (or "[host]:port" for IPv6) of the first listening socket.

        Returns a list of awaitables that complete when the server closes.
        """
        self.server = loop.run_until_complete(
            asyncio.start_server(self.handle_stream_client, self.host, self.port)
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        pass


class UnixStreamServer(StreamServer):
    """Stream server listening on a Unix domain socket at path."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """
        Start listening using loop and set self.address to
        "unix://<absolute path>".

        Returns a list of awaitables that complete when the server closes.
        """
        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX by binding relative
            # to the socket's directory. dirname() is empty for a bare file
            # name, and chdir("") fails, so fall back to the current
            # directory in that case.
            os.chdir(os.path.dirname(self.path) or ".")
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(cwd)

        self.logger.debug("Listening on %r" % self.path)
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        # Remove the socket file
        os.unlink(self.path)


class WebsocketsServer(object):
    """Server accepting clients over websockets on a host/port."""

    def __init__(self, host, port, handler, logger):
        self.host = host
        self.port = port
        self.handler = handler
        self.logger = logger

    def start(self, loop):
        """
        Start listening using loop and set self.address to
        "ws://host:port" (or "ws://[host]:port" for IPv6) of the first
        listening socket.

        Returns a list of awaitables that complete when the server closes.
        """
        # Imported here so the websockets module is only needed when a
        # websocket server is actually started
        import websockets.server

        self.server = loop.run_until_complete(
            websockets.server.serve(
                self.client_handler,
                self.host,
                self.port,
                ping_interval=None,
            )
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (name[0], name[1])
        else:
            self.address = "ws://%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        self.server.close()

    def cleanup(self):
        pass

    async def client_handler(self, websocket):
        # NOTE(review): the meaning of the -1 argument is defined by
        # WebsocketConnection, which is not visible here
        conn = WebsocketConnection(websocket, -1)
        await self.handler(conn)


class AsyncServer(object):
    """
    Base class for servers. Subclasses implement accept_client() to create
    the per-client connection object, select a transport with one of the
    start_*_server() methods, then call serve_forever() or
    serve_as_process().

    NOTE: accept_client is marked with abc.abstractmethod, but this class
    does not use the ABCMeta metaclass, so instantiation without an
    override is not prevented.
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        self.run_tasks = []

    def start_tcp_server(self, host, port):
        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)

    def start_unix_server(self, path):
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port):
        self.server = WebsocketsServer(host, port, self._client_handler, self.logger)

    async def _client_handler(self, socket):
        """
        Serve one client connection. Any exception is logged and printed
        rather than propagated, and the connection is always closed.
        """
        # Capture the address up front so it is available for logging
        # after the connection has been closed
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """
        Return the connection object for a new client. The returned object
        must provide an awaitable process_requests() method.
        """
        pass

    async def stop(self):
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """
        Start the selected transport on self.loop, record its address in
        self.address and return the awaitables to wait on.
        """
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        """
        Install the exit signal handlers and run the loop until all tasks
        complete. The transport's cleanup() is always called on exit.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # serve_as_process() blocks SIGTERM before the process is
            # created; now that a handler is installed it can be delivered
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        tasks = self.start()
        self._serve_forever(tasks)
        self.loop.close()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process

        prefunc, if given, is called in the child as prefunc(self, *args)
        after the server has started and before it begins serving.
        log_level, if given, is applied to the logger in the child.

        Sets self.address in the parent to the address reported by the
        child (None if the child failed to start the server) and returns
        the multiprocessing.Process object.
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process.  Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase.  The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in this
            # new process.
            self._create_loop()
            try:
                self.address = None
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            self._serve_forever(tasks)

            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)