1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
import asyncio
import collections
import json
import logging
import socket

import bb.asyncrpc

from . import create_async_client
12
13
# Module-wide logger; AsyncClient hands it to the base RPC client.
logger = logging.getLogger(name="hashserv.client")
15
16
class Batch(object):
    """
    Pipelines a sequence of stream-mode query messages over a socket and
    collects the replies in order.

    send() pushes messages as fast as it can while recv() concurrently reads
    one reply per in-flight message. A message stays in ``pending`` until its
    reply has been read. If a connection fails, the same Batch can therefore
    be handed to process() again with a new socket. The unanswered messages
    are then re-sent before any new ones are taken from the input.
    """

    def __init__(self):
        # Set once the sender has finished, so recv() can stop when drained
        self.done = False
        self.cond = asyncio.Condition()
        # Messages sent (or about to be sent) whose reply has not been read
        self.pending = collections.deque()
        # Replies, in the order the corresponding messages were sent
        self.results = []
        # Number of distinct messages taken from the input so far
        self.sent_count = 0

    async def recv(self, socket):
        """Read one reply per pending message until the sender is done."""
        while True:
            async with self.cond:
                await self.cond.wait_for(lambda: self.pending or self.done)

                # wait_for() only returns once the predicate holds, so an
                # empty queue here means the sender has finished
                if not self.pending:
                    return

            r = await socket.recv()
            # Record the reply and retire its message with no await in
            # between, so a cancellation can never do one without the other
            self.results.append(r)
            self.pending.popleft()

    async def send(self, socket, msgs):
        """
        Send every message from ``msgs``. Any messages still pending from a
        previous, interrupted attempt are re-sent first.
        """
        try:
            # In the event of a restart due to a reconnect, all in-flight
            # messages need to be resent first to keep the result count in
            # sync. Iterate over a snapshot: recv() removes answered messages
            # from the front of the queue while we await each send.
            for m in list(self.pending):
                await socket.send(m)

            for m in msgs:
                # Add the message to the pending list before attempting to
                # send it, so that if the send fails it will be retried. This
                # happens before any await, so a cancellation cannot drop a
                # message that was already taken from ``msgs``.
                self.pending.append(m)
                self.sent_count += 1
                async with self.cond:
                    self.cond.notify_all()

                await socket.send(m)

        finally:
            async with self.cond:
                self.done = True
                self.cond.notify_all()

    async def process(self, socket, msgs):
        """
        Run one send/receive attempt of this batch on ``socket``.

        May be called again with a new socket after a failure. Messages still
        awaiting a reply are re-sent first, and consumption of ``msgs``
        resumes where it stopped.

        Returns the list of all replies collected so far, across attempts.
        Socket errors propagate to the caller. Raises ValueError if the number
        of replies does not match the number of messages sent.
        """
        # A previous failed attempt leaves ``done`` set. Clear it so recv()
        # does not stop before this attempt's messages have been answered.
        self.done = False

        tasks = [
            asyncio.ensure_future(self.recv(socket)),
            asyncio.ensure_future(self.send(socket, msgs)),
        ]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel the other task when one fails. Stop it
            # and let it unwind, so it cannot keep consuming ``msgs`` or
            # mutating this batch while the caller reconnects and retries.
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        if len(self.results) != self.sent_count:
            raise ValueError(
                f"Received result count {len(self.results)}. Expected {self.sent_count}"
            )

        return self.results
75
76
class AsyncClient(bb.asyncrpc.AsyncClient):
    """
    Asynchronous client for the hash equivalence server.

    Speaks the "OEHASHEQUIV" protocol, version 1.1. Most calls are plain
    request/response messages sent through invoke(). The *_batch queries
    instead switch the connection into one of the server's stream modes (the
    MODE_* constants) and pipeline many queries at once. The current mode is
    tracked in ``self.mode`` and reset to MODE_NORMAL whenever a new
    connection is set up.

    When a username is configured, the client re-authenticates on every new
    connection and re-assumes any user selected with become_user().
    """

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        # Remembered so setup_connection() can re-authenticate after a reconnect
        self.username = username
        self.password = password
        # User assumed via become_user(); re-applied after re-authenticating
        self.saved_become_user = None

    async def setup_connection(self):
        """
        Initialise a newly opened connection.

        A new connection starts in normal mode. If credentials are known, the
        client re-authenticates and then re-assumes the user previously
        selected with become_user(), if any.
        """
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream_batch(self, mode, msgs):
        """
        Does a "batch" process of stream messages. This sends the query
        messages as fast as possible, and simultaneously attempts to read the
        messages back. This helps to mitigate the effects of latency to the
        hash equivalence server by allowing multiple queries to be "in-flight"
        at once.

        The implementation does more complicated tracking using a count of
        sent messages so that `msgs` can be a generator (i.e. its length is
        unknown). Any iterable is accepted. It is consumed exactly once, even
        if the batch has to be retried after a reconnect.

        Returns the list of raw replies, one per message, in order.
        """
        b = Batch()
        # Take a single iterator up front. If _send_wrapper() runs proc()
        # again after a reconnect, sending must resume where it stopped rather
        # than restart a re-iterable input such as a list. Restarting would
        # re-send queries and return extra results.
        msg_iter = iter(msgs)

        async def proc():
            await self._set_mode(mode)
            return await b.process(self.socket, msg_iter)

        return await self._send_wrapper(proc)

    async def invoke(self, *args, skip_mode=False, **kwargs):
        """
        Send a normal (non-stream) request and return the reply.

        Unless ``skip_mode`` is set, the connection is first switched back to
        normal mode. _set_mode() passes ``skip_mode`` itself to avoid
        recursing.
        """
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        if not skip_mode:
            await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        """
        Switch the connection to ``new_mode`` (one of the MODE_* constants).

        Leaving a stream mode always goes through normal mode first. Raises
        ConnectionError if the server rejects a transition, and Exception for
        an unknown mode.
        """

        async def stream_to_normal():
            # Check if already in normal mode (e.g. due to a connection reset)
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            # skip_mode stops invoke() from trying to change the mode again
            r = await self.invoke({command: None}, skip_mode=True)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )
            self.logger.debug("Mode is now %s", command)

        stream_commands = {
            self.MODE_GET_STREAM: "get-stream",
            self.MODE_EXIST_STREAM: "exists-stream",
        }

        if new_mode == self.mode:
            return

        # Reject unknown modes before touching the connection
        if new_mode != self.MODE_NORMAL and new_mode not in stream_commands:
            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")

        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            # Record the intermediate state immediately. If entering the new
            # stream mode fails below, self.mode must not still name the old
            # stream mode, or the next transition would send a stray "END".
            self.mode = self.MODE_NORMAL
            self.logger.debug("Mode is now normal")

        if new_mode in stream_commands:
            await normal_to_stream(stream_commands[new_mode])

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Return the unihash for (method, taskhash), or None if the reply is empty."""
        r = await self.get_unihash_batch([(method, taskhash)])
        return r[0]

    async def get_unihash_batch(self, args):
        """
        Look up unihashes for an iterable of (method, taskhash) pairs.

        Returns one entry per pair, in order. An entry is None when the
        server's reply for that pair is empty.
        """
        result = await self.send_stream_batch(
            self.MODE_GET_STREAM,
            (f"{method} {taskhash}" for method, taskhash in args),
        )
        return [r if r else None for r in result]

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """
        Send a "report" request for a task's outhash and unihash.

        ``extra`` is an optional dict of additional request fields. It is
        copied, not modified, and the named arguments take precedence over
        its keys.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """
        Send a "report-equiv" request mapping a taskhash to a unihash.

        ``extra`` is an optional dict of additional request fields. It is
        copied, not modified, and the named arguments take precedence over
        its keys.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        """Query the entry for (method, taskhash); ``all_properties`` is sent as "all"."""
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        """Return True if the server reports that ``unihash`` exists."""
        r = await self.unihash_exists_batch([unihash])
        return r[0]

    async def unihash_exists_batch(self, unihashes):
        """Return one bool per unihash: True where the server replied "true"."""
        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
        return [r == "true" for r in result]

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        """Query the server's entry for an outhash ("get-outhash" request)."""
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        """Return the server's reply to "get-stats"."""
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        """Send a "reset-stats" request and return the reply."""
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        """Send "backfill-wait" and return the "tasks" field of the reply."""
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        """Send a "remove" request with the given ``where`` clause."""
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        """Send "clean-unused" with ``max_age`` given in seconds."""
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        """
        Authenticate as ``username`` with ``token``.

        The credentials are remembered so later connections re-authenticate
        automatically. Any user assumed with become_user() is forgotten.
        """
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        """
        Request a new token, for ``username`` if given.

        If the reply is for the user this client authenticated as (and no
        other user has been assumed), the stored password is updated so later
        reconnects use the new token.
        """
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        """Send a "set-user-perms" request for ``username``."""
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        """Send "get-user", naming ``username`` only if one is given."""
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        """Return the "users" field of the "get-all-users" reply."""
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        """Send a "new-user" request."""
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        """Send a "delete-user" request."""
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        """
        Act as ``username``.

        The choice is remembered (unless it is the authenticated user itself)
        so it is re-applied after a reconnect.
        """
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        """Return the "usage" field of the "get-db-usage" reply."""
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        """Return the "columns" field of the "get-db-query-columns" reply."""
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        """Return the server's reply to "gc-status"."""
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will be automatically marked with "mark"
        """
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this to
        cleanup any dangling outhashes
        """
        return await self.invoke({"gc-sweep": {"mark": mark}})
318
319
class Client(bb.asyncrpc.Client):
    """
    Synchronous front end for AsyncClient.

    The method names listed in __init__ are registered with the base class's
    _add_methods(). Presumably each one becomes a blocking wrapper around the
    AsyncClient coroutine of the same name (behaviour lives in bb.asyncrpc).
    """

    def __init__(self, username=None, password=None):
        # NOTE(review): assigned before super().__init__(), presumably because
        # the base initializer builds the async client via _get_async_client(),
        # which reads these attributes. Confirm against bb.asyncrpc.
        self.username, self.password = username, password

        super().__init__()

        wrapped_methods = (
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "get_unihash_batch",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "unihash_exists_batch",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_sweep",
        )
        self._add_methods(*wrapped_methods)

    def _get_async_client(self):
        """Create the AsyncClient that performs the calls, using the stored credentials."""
        return AsyncClient(username=self.username, password=self.password)
359