xref: /openbmc/openbmc/poky/bitbake/bin/bitbake-hashclient (revision c9537f57ab488bf5d90132917b0184e2527970a5)
1#! /usr/bin/env python3
2#
3# Copyright (C) 2019 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import argparse
9import hashlib
10import logging
11import os
12import pprint
13import sys
14import threading
15import time
16import warnings
17import netrc
18import json
19import statistics
20import textwrap
21warnings.simplefilter("default")
22
try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in used when tqdm is not installed.

        Implements the small subset of the tqdm API this script relies on
        (context-manager protocol plus update()) so callers never need to
        check whether the real library is available.
        """
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            pass

        def update(self, n=1):
            # Accept tqdm's optional increment argument so the fallback
            # stays call-compatible with tqdm.tqdm.update(n=1).
            pass
39
40sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
41
42import hashserv
43import bb.asyncrpc
44
45DEFAULT_ADDRESS = 'unix://./hashserve.sock'
46METHOD = 'stress.test.method'
47
def print_user(u):
    """Pretty-print a user record returned by the hash server.

    Always shows the username; permissions and token lines appear only
    when the corresponding keys are present in the record.
    """
    print(f"Username:    {u['username']}")
    if "permissions" in u:
        print(f"Permissions: {' '.join(u['permissions'])}")
    if "token" in u:
        print(f"Token:       {u['token']}")
54
55
56def main():
57    def handle_get(args, client):
58        result = client.get_taskhash(args.method, args.taskhash, all_properties=True)
59        if not result:
60            return 0
61
62        print(json.dumps(result, sort_keys=True, indent=4))
63        return 0
64
65    def handle_get_outhash(args, client):
66        result = client.get_outhash(args.method, args.outhash, args.taskhash)
67        if not result:
68            return 0
69
70        print(json.dumps(result, sort_keys=True, indent=4))
71        return 0
72
73    def handle_stats(args, client):
74        if args.reset:
75            s = client.reset_stats()
76        else:
77            s = client.get_stats()
78        print(json.dumps(s, sort_keys=True, indent=4))
79        return 0
80
81    def handle_stress(args, client):
82        def thread_main(pbar, lock):
83            nonlocal found_hashes
84            nonlocal missed_hashes
85            nonlocal max_time
86            nonlocal times
87
88            with hashserv.create_client(args.address) as client:
89                for i in range(args.requests):
90                    taskhash = hashlib.sha256()
91                    taskhash.update(args.taskhash_seed.encode('utf-8'))
92                    taskhash.update(str(i).encode('utf-8'))
93
94                    start_time = time.perf_counter()
95                    l = client.get_unihash(METHOD, taskhash.hexdigest())
96                    elapsed = time.perf_counter() - start_time
97
98                    with lock:
99                        if l:
100                            found_hashes += 1
101                        else:
102                            missed_hashes += 1
103
104                        times.append(elapsed)
105                        pbar.update()
106
107        max_time = 0
108        found_hashes = 0
109        missed_hashes = 0
110        lock = threading.Lock()
111        times = []
112        start_time = time.perf_counter()
113        with ProgressBar(total=args.clients * args.requests) as pbar:
114            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
115            for t in threads:
116                t.start()
117
118            for t in threads:
119                t.join()
120        total_elapsed = time.perf_counter() - start_time
121
122        with lock:
123            mean = statistics.mean(times)
124            median = statistics.median(times)
125            stddev = statistics.pstdev(times)
126
127            print(f"Number of clients:    {args.clients}")
128            print(f"Requests per client:  {args.requests}")
129            print(f"Number of requests:   {len(times)}")
130            print(f"Total elapsed time:   {total_elapsed:.3f}s")
131            print(f"Total request rate:   {len(times)/total_elapsed:.3f} req/s")
132            print(f"Average request time: {mean:.3f}s")
133            print(f"Median request time:  {median:.3f}s")
134            print(f"Request time std dev: {stddev:.3f}s")
135            print(f"Maximum request time: {max(times):.3f}s")
136            print(f"Minimum request time: {min(times):.3f}s")
137            print(f"Hashes found:         {found_hashes}")
138            print(f"Hashes missed:        {missed_hashes}")
139
140        if args.report:
141            with ProgressBar(total=args.requests) as pbar:
142                for i in range(args.requests):
143                    taskhash = hashlib.sha256()
144                    taskhash.update(args.taskhash_seed.encode('utf-8'))
145                    taskhash.update(str(i).encode('utf-8'))
146
147                    outhash = hashlib.sha256()
148                    outhash.update(args.outhash_seed.encode('utf-8'))
149                    outhash.update(str(i).encode('utf-8'))
150
151                    client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())
152
153                    with lock:
154                        pbar.update()
155
156    def handle_remove(args, client):
157        where = {k: v for k, v in args.where}
158        if where:
159            result = client.remove(where)
160            print("Removed %d row(s)" % (result["count"]))
161        else:
162            print("No query specified")
163
164    def handle_clean_unused(args, client):
165        result = client.clean_unused(args.max_age)
166        print("Removed %d rows" % (result["count"]))
167        return 0
168
169    def handle_refresh_token(args, client):
170        r = client.refresh_token(args.username)
171        print_user(r)
172
173    def handle_set_user_permissions(args, client):
174        r = client.set_user_perms(args.username, args.permissions)
175        print_user(r)
176
177    def handle_get_user(args, client):
178        r = client.get_user(args.username)
179        print_user(r)
180
181    def handle_get_all_users(args, client):
182        users = client.get_all_users()
183        print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
184        print(("-" * 20) + "+" + ("-" * 20))
185        for u in users:
186            print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))
187
188    def handle_new_user(args, client):
189        r = client.new_user(args.username, args.permissions)
190        print_user(r)
191
192    def handle_delete_user(args, client):
193        r = client.delete_user(args.username)
194        print_user(r)
195
196    def handle_get_db_usage(args, client):
197        usage = client.get_db_usage()
198        print(usage)
199        tables = sorted(usage.keys())
200        print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
201        print(("-" * 20) + "+" + ("-" * 20))
202        for t in tables:
203            print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
204        print()
205
206        total_rows = sum(t["rows"] for t in usage.values())
207        print(f"Total rows: {total_rows}")
208
209    def handle_get_db_query_columns(args, client):
210        columns = client.get_db_query_columns()
211        print("\n".join(sorted(columns)))
212
213    def handle_gc_status(args, client):
214        result = client.gc_status()
215        if not result["mark"]:
216            print("No Garbage collection in progress")
217            return 0
218
219        print("Current Mark: %s" % result["mark"])
220        print("Total hashes to keep: %d" % result["keep"])
221        print("Total hashes to remove: %s" % result["remove"])
222        return 0
223
224    def handle_gc_mark(args, client):
225        where = {k: v for k, v in args.where}
226        result = client.gc_mark(args.mark, where)
227        print("New hashes marked: %d" % result["count"])
228        return 0
229
    def handle_gc_mark_stream(args, client):
        # Mark hashes read from stdin — one "KEY VALUE [KEY VALUE ...]"
        # query per line — as live under GC mark args.mark.
        stdin = (l.strip() for l in sys.stdin)
        marked_hashes = 0

        try:
            # Preferred path: stream all lines in a single server call.
            result = client.gc_mark_stream(args.mark, stdin)
            marked_hashes = result["count"]
        except ConnectionError:
            # Older servers lack the streaming API; fall back to one
            # gc-mark call per input line.
            # NOTE(review): `stdin` is a generator, so any lines already
            # consumed by the failed streaming attempt are not replayed
            # here — verify the fallback sees the full input.
            # NOTE: `logger` is main()'s local, captured by closure; it is
            # assigned before any handler is dispatched.
            logger.warning(
                "Server doesn't seem to support `gc-mark-stream`. Sending "
                "hashes sequentially using `gc-mark` API."
            )
            for line in stdin:
                pairs = line.split()
                # Interleaved tokens -> {KEY: VALUE} condition dict.
                condition = dict(zip(pairs[::2], pairs[1::2]))
                result = client.gc_mark(args.mark, condition)
                marked_hashes += result["count"]

        print("New hashes marked: %d" % marked_hashes)
        return 0
250
251    def handle_gc_sweep(args, client):
252        result = client.gc_sweep(args.mark)
253        print("Removed %d rows" % result["count"])
254        return 0
255
256    def handle_unihash_exists(args, client):
257        result = client.unihash_exists(args.unihash)
258        if args.quiet:
259            return 0 if result else 1
260
261        print("true" if result else "false")
262        return 0
263
264    def handle_ping(args, client):
265        times = []
266        for i in range(1, args.count + 1):
267            if not args.quiet:
268                print(f"Ping {i} of {args.count}... ", end="")
269            start_time = time.perf_counter()
270            client.ping()
271            elapsed = time.perf_counter() - start_time
272            times.append(elapsed)
273            if not args.quiet:
274                print(f"{elapsed:.3f}s")
275
276        mean = statistics.mean(times)
277        median = statistics.median(times)
278        std_dev = statistics.pstdev(times)
279
280        if not args.quiet:
281            print("------------------------")
282        print(f"Number of pings:         {len(times)}")
283        print(f"Average round trip time: {mean:.3f}s")
284        print(f"Median round trip time:  {median:.3f}s")
285        print(f"Round trip time std dev: {std_dev:.3f}s")
286        print(f"Min time is:             {min(times):.3f}s")
287        print(f"Max time is:             {max(times):.3f}s")
288        return 0
289
    # ------------------------------------------------------------------
    # Command-line interface: one subcommand per handler defined above.
    # Each subparser binds its handler via set_defaults(func=...).
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Hash Equivalence Client',
        epilog=textwrap.dedent(
            """
            Possible ADDRESS options are:
                unix://PATH         Connect to UNIX domain socket at PATH
                ws://HOST[:PORT]    Connect to websocket at HOST:PORT (default port is 80)
                wss://HOST[:PORT]   Connect to secure websocket at HOST:PORT (default port is 443)
                HOST:PORT           Connect to TCP server at HOST:PORT
            """
        ),
    )
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')
    parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
    parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
    parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
    parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")

    subparsers = parser.add_subparsers()

    get_parser = subparsers.add_parser('get', help="Get the unihash for a taskhash")
    get_parser.add_argument("method", help="Method to query")
    get_parser.add_argument("taskhash", help="Task hash to query")
    get_parser.set_defaults(func=handle_get)

    get_outhash_parser = subparsers.add_parser('get-outhash', help="Get output hash information")
    get_outhash_parser.add_argument("method", help="Method to query")
    get_outhash_parser.add_argument("outhash", help="Output hash to query")
    get_outhash_parser.add_argument("taskhash", help="Task hash to query")
    get_outhash_parser.set_defaults(func=handle_get_outhash)

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    remove_parser = subparsers.add_parser('remove', help="Remove hash entries")
    remove_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                               help="Remove entries from table where KEY == VALUE")
    remove_parser.set_defaults(func=handle_remove)

    clean_unused_parser = subparsers.add_parser('clean-unused', help="Remove unused database entries")
    clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
    clean_unused_parser.set_defaults(func=handle_clean_unused)

    refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
    refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
    refresh_token_parser.set_defaults(func=handle_refresh_token)

    set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
    set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
    set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    set_user_perms_parser.set_defaults(func=handle_set_user_permissions)

    get_user_parser = subparsers.add_parser('get-user', help="Get user")
    get_user_parser.add_argument("--username", "-u", help="Username")
    get_user_parser.set_defaults(func=handle_get_user)

    get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
    get_all_users_parser.set_defaults(func=handle_get_all_users)

    new_user_parser = subparsers.add_parser('new-user', help="Create new user")
    new_user_parser.add_argument("--username", "-u", help="Username", required=True)
    new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    new_user_parser.set_defaults(func=handle_new_user)

    delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
    delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
    delete_user_parser.set_defaults(func=handle_delete_user)

    db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
    db_usage_parser.set_defaults(func=handle_get_db_usage)

    db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
    db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)

    gc_status_parser = subparsers.add_parser("gc-status", help="Show garbage collection status")
    gc_status_parser.set_defaults(func=handle_gc_status)

    gc_mark_parser = subparsers.add_parser('gc-mark', help="Mark hashes to be kept for garbage collection")
    gc_mark_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_mark_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                             help="Keep entries in table where KEY == VALUE")
    gc_mark_parser.set_defaults(func=handle_gc_mark)

    gc_mark_parser_stream = subparsers.add_parser(
        'gc-mark-stream',
        help=(
            "Mark multiple hashes to be retained for garbage collection. Input should be provided via stdin, "
            "with each line formatted as key-value pairs separated by spaces, for example 'column1 foo column2 bar'."
        )
    )
    gc_mark_parser_stream.add_argument("mark", help="Mark for this garbage collection operation")
    gc_mark_parser_stream.set_defaults(func=handle_gc_mark_stream)

    gc_sweep_parser = subparsers.add_parser('gc-sweep', help="Perform garbage collection and delete any entries that are not marked")
    gc_sweep_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_sweep_parser.set_defaults(func=handle_gc_sweep)

    unihash_exists_parser = subparsers.add_parser('unihash-exists', help="Check if a unihash is known to the server")
    unihash_exists_parser.add_argument("--quiet", action="store_true", help="Don't print status. Instead, exit with 0 if unihash exists and 1 if it does not")
    unihash_exists_parser.add_argument("unihash", help="Unihash to check")
    unihash_exists_parser.set_defaults(func=handle_unihash_exists)

    ping_parser = subparsers.add_parser('ping', help="Ping server")
    ping_parser.add_argument("-n", "--count", type=int, help="Number of pings. Default is %(default)s", default=10)
    ping_parser.add_argument("-q", "--quiet", action="store_true", help="Don't print each ping; only print results")
    ping_parser.set_defaults(func=handle_ping)

    args = parser.parse_args()

    # Configure logging before dispatch; the handlers above reference this
    # `logger` local via closure.
    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    login = args.login
    password = args.password

    # Fall back to ~/.netrc credentials (keyed by server address) when no
    # --login was given and netrc use was not disabled.
    if login is None and args.netrc:
        try:
            n = netrc.netrc()
            auth = n.authenticators(args.address)
            if auth is not None:
                login, _, password = auth
        except FileNotFoundError:
            # No .netrc file: silently continue unauthenticated.
            pass
        except netrc.NetrcParseError as e:
            sys.stderr.write(f"Error parsing {e.filename}:{e.lineno}: {e.msg}\n")

    # Dispatch to the selected subcommand, if any (no subcommand -> exit 0).
    func = getattr(args, 'func', None)
    if func:
        try:
            with hashserv.create_client(args.address, login, password) as client:
                if args.become:
                    client.become_user(args.become)
                return func(args, client)
        except bb.asyncrpc.InvokeError as e:
            # Server-side errors are reported as a message, not a traceback.
            print(f"ERROR: {e}")
            return 1

    return 0
453
454
if __name__ == '__main__':
    # Run the CLI; on an unexpected exception print the traceback and
    # exit with status 1, otherwise propagate main()'s return code.
    try:
        exit_code = main()
    except Exception:
        import traceback
        traceback.print_exc()
        exit_code = 1
    sys.exit(exit_code)
463