#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import socket
from . import chunkify, DEFAULT_MAX_CHUNK


class AsyncClient(object):
    """Asyncio client for a newline-delimited JSON request/reply protocol.

    A transport is chosen with connect_tcp() or connect_unix(); those only
    record how to connect. The connection itself is opened lazily by
    connect(), and is re-opened by the retry logic in _send_wrapper() after
    an error.
    """

    def __init__(self, proto_name, proto_version, logger):
        self.reader = None
        self.writer = None
        # Size limit passed to chunkify() when splitting outgoing messages.
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger

    async def connect_tcp(self, address, port):
        """Select a TCP transport to *address*:*port* (does not connect yet)."""
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a Unix domain socket transport at *path* (does not connect yet)."""
        async def connect_sock():
            return await asyncio.open_unix_connection(path)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the handshake: "<proto_name> <proto_version>" then a blank line."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the selected transport and send the handshake, unless connected."""
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the current connection, if any; a later connect() re-opens it."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run the coroutine function *proc* on a connected client, with retries.

        Each attempt first ensures a connection, then awaits ``proc()`` and
        returns its result. On OSError/ConnectionError/JSONDecodeError/
        UnicodeDecodeError the connection is closed and the attempt retried,
        for at most four attempts in total.

        Raises:
            ConnectionError: after the final failed attempt. Errors that are
                not already ConnectionError are wrapped, chained to the cause.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        # Keep the original error as __cause__ for debugging.
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    async def send_message(self, msg):
        """JSON-encode and send *msg*, then return the decoded JSON reply.

        The reply is normally a single JSON line. If that value is truthy and
        contains "chunk-stream" (presumably a JSON object with that key), the
        real reply follows as a series of lines terminated by an empty line;
        those lines are concatenated and decoded instead.

        Raises:
            ConnectionError: on connection loss or malformed replies, once the
                retries in _send_wrapper() are exhausted.
        """
        async def get_line():
            line = await self.reader.readline()
            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            if not line.endswith("\n"):
                # Report the malformed data actually received, not the request.
                raise ConnectionError("Bad message %r" % line)

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
            await self.writer.drain()

            m = json.loads(await get_line())
            if m and "chunk-stream" in m:
                lines = []
                while True:
                    line = (await get_line()).rstrip("\n")
                    if not line:
                        break
                    lines.append(line)

                m = json.loads("".join(lines))

            return m

        return await self._send_wrapper(proc)


class Client(object):
    """Synchronous wrapper that drives an AsyncClient on a private event loop.

    Selected AsyncClient coroutine methods are exposed as blocking methods
    of the same name via _add_methods().
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # NOTE(review): close() only closes the connection; self.loop itself
        # is never closed here. Left as-is so a closed Client can be reused.
        self._add_methods('connect_tcp', 'close')

    @abc.abstractmethod
    def _get_async_client(self):
        # Subclasses return the AsyncClient instance to wrap.
        # NOTE(review): Client does not use abc.ABCMeta, so this decorator is
        # not enforced when a subclass fails to override it.
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs *downcall* on self.loop."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named AsyncClient coroutine as a blocking attribute."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the Unix domain socket at *path*, eagerly.

        AF_UNIX has path length issues, so the connection is made from the
        socket's directory using only its basename, then the previous working
        directory is restored.
        NOTE(review): os.chdir() changes the process-wide cwd.
        """
        cwd = os.getcwd()
        dirname, basename = os.path.split(path)
        try:
            # A bare filename has no directory part; os.chdir('') would fail.
            if dirname:
                os.chdir(dirname)
            self.loop.run_until_complete(self.client.connect_unix(basename))
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Chunk size used by the wrapped AsyncClient when sending."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value