1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
class AsyncClient(object):
    """Asyncio client for a line-oriented JSON message protocol.

    Usage: call :meth:`connect_tcp` or :meth:`connect_unix` to choose the
    transport, then :meth:`send_message`. The connection itself is opened
    lazily by :meth:`connect` and re-opened automatically after errors.
    """

    def __init__(self, proto_name, proto_version, logger):
        self.reader = None
        self.writer = None
        # Outgoing JSON is split with chunkify() using this size.
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger

    async def connect_tcp(self, address, port):
        """Select a TCP transport to *address*:*port*.

        No connection is opened here; :meth:`connect` opens it on demand.
        """
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a Unix domain socket transport at *path*.

        No connection is opened here; :meth:`connect` opens it on demand.
        """
        async def connect_sock():
            return await asyncio.open_unix_connection(path)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the ``"<proto_name> <proto_version>"`` header and a blank line."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the connection and send the header, unless already connected."""
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the current connection, if any; the next call reconnects."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run coroutine function *proc* on a live connection, retrying on error.

        *proc* is attempted up to four times. After every failure the
        connection is closed, so the next attempt (or the next caller) starts
        from a fresh stream instead of one left in an unknown state.

        Raises:
            ConnectionError: if the final attempt fails. Errors that are not
                already ``ConnectionError`` are wrapped, with the original
                exception chained as ``__cause__``.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s", e)
                # A failed exchange may leave unread or half-written data on
                # the stream; never reuse it, even when giving up.
                await self.close()
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise e
                count += 1

    async def send_message(self, msg):
        """Send *msg* encoded as JSON and return the decoded JSON reply.

        If the first reply line decodes to a value containing the key
        ``"chunk-stream"``, the following lines up to an empty line are
        concatenated and decoded as the actual reply.
        """
        async def get_line():
            line = await self.reader.readline()
            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            # StreamReader.readline() returns a partial line without the
            # trailing newline when EOF is reached mid-line.
            if not line.endswith("\n"):
                raise ConnectionError("Bad message %r" % (line,))

            return line

        async def proc():
            for chunk in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(chunk.encode("utf-8"))
            await self.writer.drain()

            reply = json.loads(await get_line())
            if reply and "chunk-stream" in reply:
                parts = []
                while True:
                    part = (await get_line()).rstrip("\n")
                    if not part:
                        break
                    parts.append(part)

                reply = json.loads("".join(parts))

            return reply

        return await self._send_wrapper(proc)
105
106
class Client(object):
    """Blocking wrapper that drives an async client on a private event loop.

    Subclasses implement :meth:`_get_async_client`. ``connect_tcp`` and
    ``close`` are installed per instance as blocking versions of the async
    client's coroutine methods of the same name.
    """

    def __init__(self):
        self.client = self._get_async_client()
        # NOTE(review): this loop is never closed by this class; confirm
        # whether callers are expected to close ``self.loop`` themselves.
        self.loop = asyncio.new_event_loop()

        self._add_methods('connect_tcp', 'close')

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the async client instance this wrapper drives.

        The class does not use ``abc.ABCMeta``, so this decorator documents
        intent but is not enforced at instantiation time.
        """
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a function that runs coroutine function *downcall* to completion."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Install blocking wrappers for the named async-client methods on self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the Unix domain socket at *path*, blocking until done.

        AF_UNIX socket addresses have a small maximum path length, so this
        temporarily changes into the socket's directory and connects using
        only the file name. The working directory is always restored. Note
        that ``os.chdir`` affects the whole process.
        """
        dirname, basename = os.path.split(path)
        cwd = os.getcwd()
        try:
            # A bare file name has no directory component; os.chdir('')
            # would raise FileNotFoundError, and we are already in place.
            if dirname:
                os.chdir(dirname)
            self.loop.run_until_complete(self.client.connect_unix(basename))
            # Connect now, while the relative socket path still resolves.
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Chunk size used by the wrapped async client."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value
146