1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from datetime import datetime
11from .exceptions import ClientError, ConnectionClosedError
12
13
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay, so I'd rather not do that unless
# it is necessary.
#
# Default maximum length (including the trailing newline) of each line that a
# StreamConnection writes; longer messages are split into a chunk stream by
# chunkify(). 32 KiB leaves headroom below the 64K receive buffer.
DEFAULT_MAX_CHUNK = 32 * 1024
20
21
def chunkify(msg, max_chunk):
    """Split a serialized message into newline-terminated lines for sending.

    A message shorter than ``max_chunk - 1`` characters is yielded as a single
    line. Anything longer is sent as a chunk stream:

    * a ``{"chunk-stream": null}`` header line,
    * the message split into pieces of at most ``max_chunk - 1`` characters,
      each followed by a newline,
    * a final empty line marking the end of the stream.

    :param msg: the message text; it must not contain a raw newline, since
        the receiving side delimits lines on ``"\\n"``
    :param max_chunk: maximum length of each emitted line, including its
        newline; must be at least 2
    :raises ValueError: if ``max_chunk`` is less than 2. No payload character
        would fit on a chunk line, so the message would otherwise be lost.
    """
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield msg + "\n"
        return

    yield json.dumps({"chunk-stream": None}) + "\n"

    # Leave room for the "\n" terminator on every chunk line
    step = max_chunk - 1
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"

    # An empty line terminates the chunk stream
    yield "\n"
32
33
def json_serialize(obj):
    """``default=`` hook for ``json.dumps`` handling types JSON lacks.

    :param obj: an object the JSON encoder could not serialize natively
    :returns: the ISO 8601 string form of a ``datetime``
    :raises TypeError: for any other type, as ``json.dumps`` expects from a
        ``default`` hook
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))
38
39
class StreamConnection(object):
    """JSON message transport over an asyncio stream reader/writer pair.

    Each message is one newline-terminated JSON line. Messages that are too
    long for one line are sent by ``chunkify`` as a ``chunk-stream`` header
    line, followed by the payload split over several lines and an empty
    terminating line. ``recv_message`` reassembles them.
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        """
        :param reader: stream whose ``readline()`` is awaited for input
        :param writer: stream used via ``write``/``drain``/``close``
        :param timeout: seconds to wait for each line; negative waits forever
        :param max_chunk: maximum line length handed to ``chunkify``
        """
        self.reader = reader
        self.writer = writer
        self.timeout = timeout
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Peer address as reported by the writer's ``peername`` extra info."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize *msg* to JSON and send it, chunked if it is too long."""
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive one JSON message, reassembling it if it was chunked.

        :returns: the decoded JSON value
        :raises ConnectionClosedError: if the peer closes the connection
        :raises ConnectionError: on timeout or an unterminated line
        """
        m = json.loads(await self.recv())

        # Only a dict header can announce a chunk stream; ``in`` on a string
        # would do a substring match and on a number would raise TypeError.
        if not isinstance(m, dict) or "chunk-stream" not in m:
            return m

        lines = []
        while True:
            # Strip only the "\n" terminator. Chunk boundaries fall at
            # arbitrary positions, so trailing spaces can be part of the JSON
            # payload (e.g. inside a string value) and must be kept.
            chunk = (await self._readline())[:-1]
            if not chunk:
                break
            lines.append(chunk)

        return json.loads("".join(lines))

    async def send(self, msg):
        """Send *msg* verbatim as a single newline-terminated line."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def _readline(self):
        """Read one line and return it decoded, including its final "\\n".

        :raises ConnectionError: on timeout or if the line is unterminated
        :raises ConnectionClosedError: if the stream is at EOF
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        # At EOF readline() returns a partial line that lacks the newline
        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line

    async def recv(self):
        """Read one line and return it with trailing whitespace removed.

        :raises ConnectionError: on timeout or if the line is unterminated
        :raises ConnectionClosedError: if the stream is at EOF
        """
        return (await self._readline()).rstrip()

    async def close(self):
        """Drop the reader and close the writer; safe to call repeatedly."""
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None
103
104
class WebsocketConnection(object):
    """JSON message transport over a websocket.

    Every JSON message is carried as exactly one websocket message.
    """

    def __init__(self, socket, timeout):
        """
        :param socket: connected websocket (``send``/``recv``/``close`` awaited)
        :param timeout: seconds to wait for each message; negative waits forever
        """
        self.socket = socket
        self.timeout = timeout

    @property
    def address(self):
        """Remote address components, stringified and joined with ``:``."""
        return ":".join(map(str, self.socket.remote_address))

    async def send_message(self, msg):
        """Encode *msg* as JSON and send it as one websocket message."""
        payload = json.dumps(msg, default=json_serialize)
        await self.send(payload)

    async def recv_message(self):
        """Receive one websocket message and decode it as JSON."""
        return json.loads(await self.recv())

    async def send(self, msg):
        """Send *msg* unchanged.

        :raises ConnectionClosedError: if the websocket has been closed
        """
        from websockets.exceptions import ConnectionClosed

        try:
            await self.socket.send(msg)
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one raw websocket message.

        :raises ConnectionError: if no message arrives within the timeout
        :raises ConnectionClosedError: if the websocket has been closed
        """
        from websockets.exceptions import ConnectionClosed

        try:
            if self.timeout < 0:
                # A negative timeout means block until a message arrives
                return await self.socket.recv()
            try:
                return await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        """Close the websocket once; later calls do nothing."""
        sock = self.socket
        if sock is None:
            return
        await sock.close()
        self.socket = None
147