1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6import logging
7import socket
8import bb.asyncrpc
9import json
10from . import create_async_client
11
12
# Module logger; also handed to bb.asyncrpc.AsyncClient by AsyncClient.__init__
logger: logging.Logger = logging.getLogger("hashserv.client")
14
15
class AsyncClient(bb.asyncrpc.AsyncClient):
    """Asynchronous client for the hash equivalence ("OEHASHEQUIV") server.

    Besides ordinary request/response calls made through ``invoke()``, the
    connection can be put into one of two streaming modes
    (``MODE_GET_STREAM`` / ``MODE_EXIST_STREAM``) in which plain text lines
    are exchanged directly over the socket. ``self.mode`` tracks which mode
    the connection is currently in.
    """

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        # Credentials used by setup_connection() to (re-)authenticate
        self.username = username
        self.password = password
        # User impersonated via become_user(); re-applied by setup_connection()
        self.saved_become_user = None

    async def setup_connection(self):
        """Initialize per-connection state after a connection is established.

        A new connection always starts in normal mode. If a username is
        configured, the client authenticates again and restores any user
        previously selected with become_user().
        """
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream(self, mode, msg):
        """Send ``msg`` while in streaming ``mode`` and return the reply line.

        The mode switch, send and receive are handed to ``_send_wrapper()``
        together as one unit.
        """

        async def proc():
            await self._set_mode(mode)
            await self.socket.send(msg)
            return await self.socket.recv()

        return await self._send_wrapper(proc)

    async def invoke(self, *args, **kwargs):
        """Make a normal-mode request, leaving any streaming mode first."""
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        """Switch the connection into ``new_mode``.

        Streaming modes are only entered from normal mode, so switching
        between two streaming modes passes through normal mode first.

        Raises:
            ConnectionError: the server rejected a mode transition.
            Exception: ``new_mode`` is not one of the ``MODE_*`` constants.
        """
        # Command used to enter each streaming mode from normal mode
        stream_commands = {
            self.MODE_GET_STREAM: "get-stream",
            self.MODE_EXIST_STREAM: "exists-stream",
        }

        async def stream_to_normal():
            # If the connection was re-established while this was being
            # attempted, setup_connection() has already reset the mode to
            # normal; sending "END" to a normal-mode connection is wrong.
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            r = await self.invoke({command: None})
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )

            self.logger.debug("Mode is now %s", command)

        if new_mode == self.mode:
            return

        # Validate before touching the connection so an invalid request does
        # not needlessly drop the current streaming mode.
        if new_mode != self.MODE_NORMAL and new_mode not in stream_commands:
            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")

        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            # Record the intermediate state immediately. normal_to_stream()
            # goes through invoke(), which would otherwise try to leave the
            # old streaming mode a second time. It also keeps self.mode
            # accurate if entering the new stream fails below.
            self.mode = self.MODE_NORMAL
            self.logger.debug("Mode is now normal")

        if new_mode in stream_commands:
            await normal_to_stream(stream_commands[new_mode])

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Look up the unihash for ``taskhash``/``method`` over the get stream.

        Returns None when the server replies with an empty line.
        """
        r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash))
        if not r:
            return None
        return r

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a "report" request for a task's outhash/unihash.

        ``extra`` is an optional dict of additional fields. It is copied, never
        modified. The named arguments override same-named keys in it.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """Send a "report-equiv" request mapping ``taskhash`` to ``unihash``.

        ``extra`` is an optional dict of additional fields. It is copied, never
        modified. The named arguments override same-named keys in it.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        """Send a "get" request for ``taskhash``/``method``."""
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        """Return True if the exists stream answers "true" for ``unihash``."""
        r = await self.send_stream(self.MODE_EXIST_STREAM, unihash)
        return r == "true"

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        """Send a "get-outhash" request."""
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        """Send a "get-stats" request."""
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        """Send a "reset-stats" request."""
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        """Send a "backfill-wait" request and return the reply's "tasks" field."""
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        """Send a "remove" request with the given ``where`` clause."""
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        """Send a "clean-unused" request; ``max_age`` is passed as seconds."""
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        """Authenticate and remember the credentials for later reconnects.

        Clears any user previously selected with become_user().
        """
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        """Request a new token, for ``username`` if given.

        If the new token belongs to the user this client authenticates as, and
        no other user is being impersonated, the stored password is updated so
        that later re-authentication uses the new token.
        """
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        """Send a "set-user-perms" request."""
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        """Send a "get-user" request, for ``username`` if given."""
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        """Send a "get-all-users" request and return the reply's "users" field."""
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        """Send a "new-user" request."""
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        """Send a "delete-user" request."""
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        """Act as ``username``. The choice is re-applied by setup_connection().

        Becoming the authenticated user again clears the saved selection.
        """
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        """Send a "get-db-usage" request and return the reply's "usage" field."""
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        """Send a "get-db-query-columns" request; return its "columns" field."""
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        """Send a "gc-status" request."""
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will be automatically marked with "mark"
        """
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this to
        cleanup any dangling outhashes
        """
        return await self.invoke({"gc-sweep": {"mark": mark}})
228
229
class Client(bb.asyncrpc.Client):
    """Blocking counterpart of AsyncClient.

    Every name registered through ``_add_methods()`` is exposed on this
    object as a call to the AsyncClient method of the same name. Presumably
    bb.asyncrpc.Client drives it synchronously (TODO confirm).
    """

    def __init__(self, username=None, password=None):
        # Set before the base constructor runs, because _get_async_client()
        # reads them (it may be invoked from bb.asyncrpc.Client.__init__).
        self.username, self.password = username, password

        super().__init__()

        forwarded = (
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_sweep",
        )
        self._add_methods(*forwarded)

    def _get_async_client(self):
        """Build the AsyncClient this wrapper delegates to."""
        return AsyncClient(username=self.username, password=self.password)
267
268
class ClientPool(bb.asyncrpc.ClientPool):
    """Pool of up to ``max_clients`` connections for parallel bulk queries.

    Every pooled client connects to ``address`` with the given credentials
    and, if ``become`` is set, impersonates that user.
    """

    def __init__(
        self,
        address,
        max_clients,
        *,
        username=None,
        password=None,
        become=None,
    ):
        super().__init__(max_clients)
        self.address = address
        self.username = username
        self.password = password
        self.become = become

    async def _new_client(self):
        """Create a connected client, impersonating ``self.become`` if set.

        If impersonation fails, the new client is closed before the error
        propagates, so its connection is not leaked.
        """
        client = await create_async_client(
            self.address,
            username=self.username,
            password=self.password,
        )
        if self.become:
            try:
                await client.become_user(self.become)
            except BaseException:
                await client.close()
                raise
        return client

    def _run_key_tasks(self, queries, call):
        """Run ``await call(client, args)`` for each ``key, args`` in ``queries``.

        Returns a dict with the same keys as ``queries``. Each value is the
        result of the matching call, or None if no result was stored for that
        key.
        """
        results = dict.fromkeys(queries)

        def make_task(key, args):
            # Factory so each task captures its own key/args (a closure
            # created directly in the loop would late-bind them).
            async def task(client):
                results[key] = await call(client, args)

            return task

        self.run_tasks(make_task(key, args) for key, args in queries.items())
        return results

    def get_unihashes(self, queries):
        """
        Query multiple unihashes in parallel.

        The queries argument is a dictionary with arbitrary key. The values
        must be a tuple of (method, taskhash).

        Returns a dictionary with a corresponding key for each input key, and
        the value is the queried unihash (which might be None if the query
        failed)
        """

        async def call(client, args):
            method, taskhash = args
            return await client.get_unihash(method, taskhash)

        return self._run_key_tasks(queries, call)

    def unihashes_exist(self, queries):
        """
        Query multiple unihash existence checks in parallel.

        The queries argument is a dictionary with arbitrary key. The values
        must be a unihash.

        Returns a dictionary with a corresponding key for each input key, and
        the value is True or False if the unihash is known by the server (or
        None if there was a failure)
        """

        async def call(client, unihash):
            return await client.unihash_exists(unihash)

        return self._run_key_tasks(queries, call)
347