#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import socket
from . import chunkify, DEFAULT_MAX_CHUNK


class AsyncClient(object):
    """Asyncio client for a newline-delimited JSON request/response protocol.

    ``connect_tcp`` / ``connect_unix`` only record how to open the stream;
    the actual connection is opened lazily by ``connect``, which is invoked
    before every request (and again after a failed request is retried).
    """

    def __init__(self, proto_name, proto_version, logger, timeout=30):
        self.reader = None
        self.writer = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        # Seconds to wait for each individual reply line from the server
        self.timeout = timeout

    async def connect_tcp(self, address, port):
        """Configure the client to (re)connect to ``address``:``port`` over TCP."""
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure the client to (re)connect to the UNIX socket at ``path``."""
        async def connect_sock():
            # AF_UNIX socket paths have a short length limit, so connect using
            # a name relative to the socket's directory. Doing this inside
            # connect_sock means every reconnect (not just the first) resolves
            # the name from the right directory. The connect is performed
            # synchronously so the cwd cannot be changed by other code while
            # it is temporarily switched.
            cwd = os.getcwd()
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                try:
                    # dirname() is "" for a bare name; chdir("") would fail
                    os.chdir(os.path.dirname(path) or ".")
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                return await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Don't leak the socket if connecting or wrapping it failed
                sock.close()
                raise

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the protocol handshake: name and version, then a blank line."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the stream and perform the handshake if not already connected."""
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the current stream; the next request will reconnect."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run ``proc`` on a connected stream, reconnecting on failure.

        Transport and decode errors are logged, the connection is closed and
        the request is retried up to three times. After the final failure the
        error is raised as a ``ConnectionError`` (wrapping it if needed).
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e))
                    raise e
                await self.close()
                count += 1

    async def send_message(self, msg):
        """Send ``msg`` as JSON and return the decoded JSON reply.

        If the reply contains a ``"chunk-stream"`` key, the real payload
        follows as a sequence of lines terminated by an empty line; those
        lines are joined and decoded as the reply instead.
        """
        async def get_line():
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for server")

            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            # A line without its terminator means the stream ended mid-message
            if not line.endswith("\n"):
                raise ConnectionError("Bad message %r" % (line))

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
            await self.writer.drain()

            l = await get_line()

            m = json.loads(l)
            if m and "chunk-stream" in m:
                lines = []
                while True:
                    l = (await get_line()).rstrip("\n")
                    if not l:
                        break
                    lines.append(l)

                m = json.loads("".join(lines))

            return m

        return await self._send_wrapper(proc)

    async def ping(self):
        """Send a ``ping`` request and return the server's reply."""
        return await self.send_message(
            {'ping': {}}
        )


class Client(object):
    """Synchronous facade over an ``AsyncClient`` using a private event loop.

    Subclasses provide the async client via ``_get_async_client``; selected
    async methods are exposed as blocking methods through ``_add_methods``.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        self._add_methods('connect_tcp', 'close', 'ping')

    # NOTE: Client does not use abc.ABCMeta, so this decorator is not
    # enforced at instantiation time; subclasses must still override it.
    @abc.abstractmethod
    def _get_async_client(self):
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking wrapper that runs coroutine ``downcall`` on self.loop."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named async-client method as a blocking method on self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the UNIX socket at ``path`` immediately.

        The AF_UNIX path-length workaround is handled by the async client on
        every (re)connect, so the full path is passed through here.
        """
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        """Maximum size of each chunk a request is split into when sent."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value