#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    """
    Raised when a client violates the protocol.

    AsyncServerConnection.process_requests() catches it, logs the message
    and closes the connection.
    """
    pass


class ServerError(Exception):
    """
    Server-side error type. Not raised within this module; presumably used
    by callers/subclasses.
    """
    pass


class AsyncServerConnection(object):
    """
    Serve one client connection of a line-oriented JSON protocol.

    Handshake: the client sends "<proto_name> <version>\\n", then header
    lines terminated by an empty line (no headers are currently
    interpreted). After that, every request is one line of JSON whose
    top-level key selects a handler from ``self.handlers``.

    Subclasses must provide ``validate_proto_version()`` (called with
    ``self.proto_version`` already set) and may register additional
    handlers in ``self.handlers``.
    """

    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Perform the handshake, then dispatch requests until EOF.

        ClientError is logged and ends the connection. Any other exception
        propagates to the caller. The writer is always closed on exit.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. StreamReader.readline() returns
            # b'' (never None) at EOF, so test for an empty result.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except (UnicodeDecodeError, ValueError):
                raise ClientError('Invalid protocol line %r' % (client_protocol,))

            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                raise ClientError('Invalid protocol version %r' % (client_proto_version,))

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the end of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Call the first registered handler whose key is present in *msg*.

        The handler receives ``msg[key]``.

        Raises:
            ClientError: *msg* is not a JSON object, or no handler key
                matches it.
        """
        # Non-object JSON (null, numbers, lists, strings) cannot carry a
        # command key; reject it as a protocol error rather than letting
        # the membership test / indexing below raise TypeError.
        if isinstance(msg, dict):
            for k, handler in self.handlers.items():
                if k in msg:
                    self.logger.debug('Handling %s' % k)
                    await handler(msg[k])
                    return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """
        Serialize *msg* as JSON and write it to the client.

        Serialization goes through ``chunkify`` with ``self.max_chunk``,
        which presumably splits large messages into chunk-stream framing.
        This only buffers the data; callers drain the writer.
        """
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read one newline-terminated JSON request from the client.

        Returns:
            The decoded JSON value, or None on EOF or on a final line
            without a trailing newline.

        Raises:
            UnicodeDecodeError, json.JSONDecodeError: the line is invalid.
                It is logged first.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')
        except UnicodeDecodeError:
            # `message` is unbound here, so log the raw bytes instead
            self.logger.error('Bad message from client: %r' % l)
            raise

        # A partial line means the client disconnected mid-message
        if not message.endswith('\n'):
            return None

        try:
            return json.loads(message)
        except json.JSONDecodeError:
            self.logger.error('Bad message from client: %r' % message)
            raise

    async def handle_chunk(self, request):
        """
        Reassemble a chunk-stream request and dispatch the resulting message.

        Lines are read and concatenated until an empty line (or EOF), and
        the joined text is parsed as JSON. *request* is unused.

        Raises:
            ClientError: the reassembled message is itself a chunk-stream.
            UnicodeDecodeError, json.JSONDecodeError: the chunk data is
                invalid. It is logged first.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if isinstance(msg, dict) and 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply to a liveness check with ``{'alive': True}``."""
        response = {'alive': True}
        self.write_message(response)


class AsyncServer(object):
    """
    Base class for an asyncio TCP or Unix-socket server.

    Call start_tcp_server() or start_unix_server() to choose the transport,
    then serve_forever() (current process) or serve_as_process() (child
    process). Subclasses implement accept_client() to create the
    per-connection handler.
    """

    def __init__(self, logger):
        self._cleanup_socket = None
        self.logger = logger
        # Set by start_tcp_server()/start_unix_server() to a callable that
        # binds the listening socket once an event loop exists
        self.start = None
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        """
        Configure the server to listen on TCP *host*:*port*.

        The socket is bound later, when ``self.start()`` runs inside the
        serving loop. ``self.address`` then becomes "host:port", or
        "[host]:port" for IPv6.
        """
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                # TCP_QUICKACK is Linux-specific; skip it where unavailable
                if hasattr(socket, 'TCP_QUICKACK'):
                    s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                # The keepalive tuning constants are not defined on every
                # platform; apply those that exist
                for opt_name, value in (('TCP_KEEPIDLE', 30),
                                        ('TCP_KEEPINTVL', 15),
                                        ('TCP_KEEPCNT', 4)):
                    opt = getattr(socket, opt_name, None)
                    if opt is not None:
                        s.setsockopt(socket.IPPROTO_TCP, opt, value)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Configure the server to listen on the Unix socket at *path*.

        The socket is bound later, when ``self.start()`` runs. The socket
        file is unlinked when serving ends. ``self.address`` becomes
        "unix://<absolute path>".
        """
        def cleanup():
            # Best-effort removal; the file may already be gone
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

        def start_unix():
            cwd = os.getcwd()
            # A bare filename has no directory part; os.chdir('') would fail
            dirname = os.path.dirname(path)
            try:
                # Work around path length limits in AF_UNIX
                if dirname:
                    os.chdir(dirname)
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    # NOTE: the class does not use abc.ABCMeta, so this decorator is
    # documentation only and is not enforced at instantiation time.
    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """Return an object whose ``process_requests()`` serves this connection."""
        pass

    async def handle_client(self, reader, writer):
        """
        Per-connection callback passed to asyncio.start_server.

        Exceptions from the client handler are logged (with traceback) and
        do not stop the server. The writer is always closed.
        """
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
        finally:
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until stopped; Ctrl-C ends it quietly."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop so serving can shut down."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Serve until SIGTERM or Ctrl-C, then close the server.

        The socket cleanup callback, if any, runs even on error.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # serve_as_process() blocks SIGTERM before forking; unblock it
            # only now that the handler is installed
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        *prefunc*, if given, is called in the child as
        ``prefunc(self, *args)`` after the socket is bound and before
        serving starts. Blocks until the child reports its listening
        address, stores it in ``self.address``, and returns the
        multiprocessing.Process.
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process.  Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase.  The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back (None if start() failed) so the
                # parent's queue.get() does not block forever
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)