1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
class AsyncClient(object):
    """Asynchronous client for a newline-delimited JSON request/reply protocol.

    A request is ``json.dumps(msg)`` written to the stream. It is split into
    pieces by ``chunkify`` with ``self.max_chunk``, presumably to bound the
    size of each piece; confirm against ``chunkify``.

    A reply is one newline-terminated JSON line. When that line decodes to
    something containing a ``"chunk-stream"`` key, the real reply follows as
    a series of lines terminated by an empty line; those lines are
    concatenated and decoded as JSON.

    ``connect_tcp`` / ``connect_unix`` only record how to connect. The
    stream is opened lazily by ``connect()`` and re-opened by
    ``_send_wrapper`` after a communication error.
    """

    def __init__(self, proto_name, proto_version, logger, timeout=30):
        self.reader = None
        self.writer = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        # Sent as the "<name> <version>" handshake line in setup_connection().
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        # Seconds to wait for each reply line before treating the
        # connection as dead.
        self.timeout = timeout

    async def connect_tcp(self, address, port):
        """Record a TCP endpoint as the connection target (does not connect)."""
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Record the UNIX-domain socket at *path* as the connection target.

        This does not connect. AF_UNIX socket addresses have a small length
        limit, so each connection attempt temporarily changes into the
        socket's directory and connects to the bare file name. Because this
        happens inside the stored connector, reconnects done by
        ``_send_wrapper`` use the same workaround as the first connect.
        """
        async def connect_sock():
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                cwd = os.getcwd()
                try:
                    # dirname() is "" for a bare file name, and os.chdir("")
                    # raises, so stay in the current directory in that case.
                    os.chdir(os.path.dirname(path) or ".")
                    # Connect synchronously: there is no await between the
                    # chdir and its undo, so no other coroutine on this loop
                    # can observe the changed working directory.
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
            except BaseException:
                # Do not leak the descriptor when the connect fails.
                sock.close()
                raise
            return await asyncio.open_unix_connection(sock=sock)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the protocol handshake: ``"<name> <version>\\n\\n"``."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the stream and send the handshake, unless already connected.

        Requires a prior ``connect_tcp`` or ``connect_unix`` call; otherwise
        ``self._connect_sock`` does not exist.
        """
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the current stream so the next ``connect()`` reopens it."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run ``proc()`` on a connected stream, reconnecting on errors.

        OS/connection errors and undecodable replies are logged, the stream
        is closed, and the attempt is retried. Up to four attempts are made
        in total.

        Raises:
            ConnectionError: after the final failed attempt. A non-connection
                error is wrapped, with the original kept as ``__cause__``.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise e
                await self.close()
                count += 1

    async def send_message(self, msg):
        """Send *msg* as JSON and return the decoded JSON reply.

        Reconnection and retries are handled by ``_send_wrapper``.
        """
        async def get_line():
            # Read one reply line, giving up after self.timeout seconds.
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for server")

            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            # readline() returns a partial line at EOF; treat that as broken.
            if not line.endswith("\n"):
                raise ConnectionError("Bad message %r" % (line))

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
            await self.writer.drain()

            l = await get_line()

            m = json.loads(l)
            if m and "chunk-stream" in m:
                # The real reply arrives as several lines ended by an empty
                # line; join them and decode as one JSON document.
                lines = []
                while True:
                    l = (await get_line()).rstrip("\n")
                    if not l:
                        break
                    lines.append(l)

                m = json.loads("".join(lines))

            return m

        return await self._send_wrapper(proc)

    async def ping(self):
        """Send a ``{'ping': {}}`` request and return the server's reply."""
        return await self.send_message(
            {'ping': {}}
        )
115
116
class Client(object):
    """Blocking wrapper around an asynchronous client object.

    Subclasses supply the asynchronous client via ``_get_async_client``
    (presumably an ``AsyncClient`` subclass). The coroutine methods named in
    ``__init__`` are exposed on this object as ordinary blocking methods that
    run on a private event loop.
    """

    def __init__(self):
        self.client = self._get_async_client()
        # Private loop; every blocking wrapper runs its coroutine to
        # completion on it.
        self.loop = asyncio.new_event_loop()

        self._add_methods('connect_tcp', 'close', 'ping')

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the asynchronous client instance this wrapper drives.

        Must be overridden. ``Client`` does not use ``abc.ABCMeta``, so the
        decorator is not enforced at instantiation time. Raising here turns
        a missing override into a clear error, instead of an
        ``AttributeError`` on ``None`` inside ``_add_methods``.
        """
        raise NotImplementedError(
            "%s must implement _get_async_client()" % type(self).__name__
        )

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs coroutine *downcall* on self.loop."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named coroutine method of self.client as a blocking method."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the UNIX-domain socket at *path*.

        AF_UNIX has path length issues, so the process temporarily changes
        into the socket's directory and passes only the file name to the
        async client. The connection is opened here, while the directory is
        still changed, and the original working directory is always
        restored.

        NOTE(review): the async client is given only the bare file name. A
        later automatic reconnect may therefore resolve it relative to
        whatever the working directory is at that time — confirm this is
        acceptable.
        """
        cwd = os.getcwd()
        try:
            # dirname() is "" for a bare file name, and os.chdir("") raises,
            # so stay in the current directory in that case.
            os.chdir(os.path.dirname(path) or ".")
            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Chunk size used by the underlying client when sending messages."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value
156