# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import logging
import socket
import asyncio
import bb.asyncrpc
import json
from . import create_async_client


logger = logging.getLogger("hashserv.client")


class Batch(object):
    def __init__(self):
        self.done = False
        self.cond = asyncio.Condition()
        self.pending = []
        self.results = []
        self.sent_count = 0

    async def recv(self, socket):
        while True:
            async with self.cond:
                await self.cond.wait_for(lambda: self.pending or self.done)

                if not self.pending:
                    if self.done:
                        return
                    continue

            r = await socket.recv()
            self.results.append(r)

            async with self.cond:
                self.pending.pop(0)

    async def send(self, socket, msgs):
        try:
            # In the event of a restart due to a reconnect, all in-flight
            # messages need to be resent first to keep the result count in sync
            for m in self.pending:
                await socket.send(m)

            for m in msgs:
                # Add the message to the pending list before attempting to send
                # it so that if the send fails it will be retried
                async with self.cond:
                    self.pending.append(m)
                    self.cond.notify()
                    self.sent_count += 1

                await socket.send(m)

        finally:
            async with self.cond:
                self.done = True
                self.cond.notify()

    async def process(self, socket, msgs):
        await asyncio.gather(
            self.recv(socket),
            self.send(socket, msgs),
        )

        if len(self.results) != self.sent_count:
            raise ValueError(
                f"Result count mismatch: got {len(self.results)}, expected {self.sent_count}"
            )

        return self.results

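# Illustrative sketch of how Batch pairs queries with replies (the real caller
# is AsyncClient.send_stream_batch() below; "sock" and "queries" are
# hypothetical names):
#
#   batch = Batch()
#   replies = await batch.process(sock, (f"{m} {h}" for m, h in queries))
#   # one reply string is collected per message yielded by the generator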

class AsyncClient(bb.asyncrpc.AsyncClient):
    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2
    MODE_MARK_STREAM = 3

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        self.username = username
        self.password = password
        self.saved_become_user = None

    async def setup_connection(self):
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream_batch(self, mode, msgs):
        """
        Does a "batch" process of stream messages. This sends the query
        messages as fast as possible, and simultaneously attempts to read the
        responses back. This helps to mitigate the effects of latency to the
        hash equivalence server by allowing multiple queries to be "in-flight"
        at once.

        The implementation tracks the count of sent messages so that `msgs`
        can be a generator (i.e. its length need not be known in advance).
        """
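        # For example (see get_unihash_batch() below), in "get-stream" mode each
        # sent message is "<method> <taskhash>" and each reply is the
        # corresponding unihash; get_unihash_batch() maps an empty reply to None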

        b = Batch()

        async def proc():
            nonlocal b

            await self._set_mode(mode)
            return await b.process(self.socket, msgs)

        return await self._send_wrapper(proc)

    async def invoke(self, *args, skip_mode=False, **kwargs):
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        if not skip_mode:
            await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        async def stream_to_normal():
            # Check if already in normal mode (e.g. due to a connection reset)
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            r = await self.invoke({command: None}, skip_mode=True)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )
            self.logger.debug("Mode is now %s", command)

        if new_mode == self.mode:
            return

        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            self.logger.debug("Mode is now normal")

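        # e.g. a MODE_GET_STREAM -> MODE_EXIST_STREAM switch first sends "END"
        # and waits for the server's "ok", then issues "exists-stream" below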
        if new_mode == self.MODE_GET_STREAM:
            await normal_to_stream("get-stream")
        elif new_mode == self.MODE_EXIST_STREAM:
            await normal_to_stream("exists-stream")
        elif new_mode == self.MODE_MARK_STREAM:
            await normal_to_stream("gc-mark-stream")
        elif new_mode != self.MODE_NORMAL:
            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        r = await self.get_unihash_batch([(method, taskhash)])
        return r[0]

    async def get_unihash_batch(self, args):
        result = await self.send_stream_batch(
            self.MODE_GET_STREAM,
            (f"{method} {taskhash}" for method, taskhash in args),
        )
        return [r if r else None for r in result]

    async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
        m = extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
        m = extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        r = await self.unihash_exists_batch([unihash])
        return r[0]

    async def unihash_exists_batch(self, unihashes):
        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
        return [r == "true" for r in result]

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will automatically be marked with "mark".
        """
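        # "where" is a column/value mapping; e.g. (illustrative values, any
        # column reported by get_db_query_columns() may be usable):
        #   await client.gc_mark("ABC", {"unihash": unihash})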
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_mark_stream(self, mark, rows):
        """
        Similar to `gc-mark`, but accepts a list of "where" key-value pair
        conditions. It uses stream mode to mark hashes, which helps reduce
        the impact of latency when communicating with the hash equivalence
        server.
        """
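        # Each row is a space-separated "<column> <value> ..." string which
        # row_to_dict() below turns into a "where" mapping; e.g. (illustrative
        # values):
        #   rows = [f"unihash {unihash} method {method}"]
        #   await client.gc_mark_stream("ABC", rows)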
        def row_to_dict(row):
            pairs = row.split()
            return dict(zip(pairs[::2], pairs[1::2]))

        responses = await self.send_stream_batch(
            self.MODE_MARK_STREAM,
            (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
        )

        return {"count": sum(int(json.loads(r)["count"]) for r in responses)}

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this
        to clean up any dangling outhashes.
        """
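        # The "clean unused" step mentioned above corresponds to
        # clean_unused(max_age) earlier in this class (max_age is in seconds)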
        return await self.invoke({"gc-sweep": {"mark": mark}})


class Client(bb.asyncrpc.Client):
    def __init__(self, username=None, password=None):
        self.username = username
        self.password = password

        super().__init__()
        self._add_methods(
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "get_unihash_batch",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "unihash_exists_batch",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_mark_stream",
            "gc_sweep",
        )

    def _get_async_client(self):
        return AsyncClient(self.username, self.password)
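

# Minimal synchronous usage sketch (the address, port, and hash values below
# are placeholders, not defaults from this module):
#
#   client = Client()
#   client.connect_tcp("localhost", 8686)
#   unihash = client.get_unihash(method, taskhash)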
381