#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import asyncio
import itertools
import json
from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError


# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024


def chunkify(msg, max_chunk):
    """Yield newline-terminated lines that carry msg over a stream.

    A message shorter than ``max_chunk - 1`` characters is yielded as a
    single line. Anything longer is sent as a chunk stream: a JSON header
    line ``{"chunk-stream": null}``, the message cut into pieces of at most
    ``max_chunk - 1`` characters (one per line), and finally an empty line
    marking the end of the stream.

    Args:
        msg: The already serialized message text.
        max_chunk: Maximum length of a yielded line, newline included.

    Raises:
        ValueError: If max_chunk is smaller than 2, which leaves no room for
            any payload next to the newline.
    """
    payload_len = max_chunk - 1
    if payload_len < 1:
        # With no room for payload the chunk stream would be emitted empty
        # and the message silently lost, so refuse instead.
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < payload_len:
        yield "".join((msg, "\n"))
        return

    yield "".join((json.dumps({"chunk-stream": None}), "\n"))

    # Slicing produces the same pieces as grouping the characters one by one,
    # without building a tuple of single characters for every chunk.
    for start in range(0, len(msg), payload_len):
        yield "".join((msg[start:start + payload_len], "\n"))

    yield "\n"


def json_serialize(obj):
    """``default`` hook for json.dumps(): encode datetime as ISO 8601 text.

    Raises:
        TypeError: For any object that is not a datetime instance.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError("Type %s not serializeable" % type(obj))


class StreamConnection(object):
    """Line based JSON message connection over an asyncio reader/writer pair.

    Messages are JSON documents terminated by a newline. Messages too long
    for a single line are transferred as a chunk stream (see chunkify()).
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        """Store the stream pair and the transfer settings.

        Args:
            reader: Object providing an awaitable ``readline()``.
            writer: Object providing ``write()``, an awaitable ``drain()``,
                ``close()`` and ``get_extra_info()``.
            timeout: Seconds to wait for a line in recv(); a negative value
                disables the timeout.
            max_chunk: Maximum line length used when sending messages.
        """
        self.reader = reader
        self.writer = writer
        self.timeout = timeout
        self.max_chunk = max_chunk

    @property
    def address(self):
        """The "peername" extra info reported by the writer."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize msg as JSON and write it, chunked if necessary."""
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive one JSON message and return the decoded object.

        If the first line decodes to a dict containing the "chunk-stream"
        key, the following lines up to the first empty line are concatenated
        and decoded as the actual message.

        Raises:
            ConnectionError: On timeout or on a line without a newline.
            ConnectionClosedError: If the peer closed the connection.
        """
        l = await self.recv()

        m = json.loads(l)
        if not m:
            return m

        # Only a JSON object can be the chunk stream header. Checking the
        # type avoids a TypeError on numbers and accidental matches on
        # strings or lists containing "chunk-stream".
        if isinstance(m, dict) and "chunk-stream" in m:
            lines = []
            while True:
                # Strip the line terminator only. A chunk boundary can fall
                # after whitespace inside a JSON string; stripping all
                # trailing whitespace would corrupt the payload, and a chunk
                # made only of whitespace would be mistaken for the end
                # marker. Text produced by json.dumps() contains no raw CR
                # or LF, so removing those is lossless.
                l = (await self._recv_line()).rstrip("\r\n")
                if not l:
                    break
                lines.append(l)

            m = json.loads("".join(lines))

        return m

    async def send(self, msg):
        """Write msg followed by a newline and drain the writer."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def _recv_line(self):
        """Read one line and return it decoded, newline still attached.

        Raises:
            ConnectionError: On timeout or if the data read does not end
                with a newline.
            ConnectionClosedError: If readline() returned no data.
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line

    async def recv(self):
        """Receive one line with all trailing whitespace removed.

        Raises:
            ConnectionError: On timeout or if the data read does not end
                with a newline.
            ConnectionClosedError: If the peer closed the connection.
        """
        line = await self._recv_line()
        return line.rstrip()

    async def close(self):
        """Drop the reader and close the writer if it is still set."""
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None


class WebsocketConnection(object):
    """JSON message connection over a websocket; one message per frame."""

    def __init__(self, socket, timeout):
        """Store the socket and the receive timeout.

        Args:
            socket: Websocket object providing awaitable ``send()``,
                ``recv()`` and ``close()`` plus ``remote_address``.
            timeout: Seconds to wait in recv(); a negative value disables
                the timeout.
        """
        self.socket = socket
        self.timeout = timeout

    @property
    def address(self):
        """The elements of the socket's remote_address joined with ":"."""
        return ":".join(str(s) for s in self.socket.remote_address)

    async def send_message(self, msg):
        """Serialize msg as JSON and send it."""
        await self.send(json.dumps(msg, default=json_serialize))

    async def recv_message(self):
        """Receive one message and return the decoded JSON object."""
        m = await self.recv()
        return json.loads(m)

    async def send(self, msg):
        """Send msg over the socket.

        Raises:
            ConnectionClosedError: If the websocket has been closed.
        """
        # Imported at call time; presumably so the websockets package is only
        # needed when a websocket connection is actually used.
        import websockets.exceptions

        try:
            await self.socket.send(msg)
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one message from the socket.

        Raises:
            ConnectionError: If no data arrived within the timeout.
            ConnectionClosedError: If the websocket has been closed.
        """
        import websockets.exceptions

        try:
            if self.timeout < 0:
                return await self.socket.recv()

            try:
                return await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        """Close the socket if it is still set and drop the reference."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None