#! /usr/bin/env python3
#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Hash Equivalence Client: command line front end for a hashserv server."""

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time
import warnings
import netrc
import json
import statistics
import textwrap
warnings.simplefilter("default")

try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in for ``tqdm.tqdm`` used when tqdm is not installed."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            pass

        def update(self, n=1):
            # Accept the same optional increment argument as tqdm.update().
            pass

sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))

import hashserv
import bb.asyncrpc

DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'


def print_user(u):
    """Print a user record: username, plus permissions/token when present."""
    print(f"Username: {u['username']}")
    if "permissions" in u:
        print("Permissions: " + " ".join(u["permissions"]))
    if "token" in u:
        print(f"Token: {u['token']}")


def main():
    """Parse the command line, connect and run the selected sub-command.

    Returns the process exit code: whatever the handler returns, 1 when the
    server reports an ``InvokeError``, and 0 when no sub-command was given.
    """
    def handle_get(args, client):
        """Print all properties stored for ``args.taskhash`` as JSON.

        Prints nothing when the server returns no result.
        """
        result = client.get_taskhash(args.method, args.taskhash, all_properties=True)
        if not result:
            return 0

        print(json.dumps(result, sort_keys=True, indent=4))
        return 0

    def handle_get_outhash(args, client):
        """Print the output-hash information for ``args.outhash`` as JSON.

        Prints nothing when the server returns no result.
        """
        result = client.get_outhash(args.method, args.outhash, args.taskhash)
        if not result:
            return 0

        print(json.dumps(result, sort_keys=True, indent=4))
        return 0

    def handle_stats(args, client):
        """Print server statistics as JSON.

        With ``--reset``, print whatever ``reset_stats()`` returns instead.
        """
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        print(json.dumps(s, sort_keys=True, indent=4))
        return 0

    def handle_stress(args, client):
        """Run a concurrent ``get_unihash`` load test and print timing statistics.

        Starts ``args.clients`` threads. Each thread opens its own connection
        and looks up ``args.requests`` task hashes derived from
        ``args.taskhash_seed`` and the request index. With ``--report``, it
        then reports a unihash for each of those task hashes over ``client``.
        Returns 1 if no request completed, otherwise 0.
        """
        def thread_main(pbar, lock):
            nonlocal found_hashes
            nonlocal missed_hashes

            with hashserv.create_client(args.address) as client:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    start_time = time.perf_counter()
                    l = client.get_unihash(METHOD, taskhash.hexdigest())
                    elapsed = time.perf_counter() - start_time

                    # Counters, the times list and the progress bar are shared
                    # by all worker threads.
                    with lock:
                        if l:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        times.append(elapsed)
                        pbar.update()

        found_hashes = 0
        missed_hashes = 0
        lock = threading.Lock()
        times = []
        start_time = time.perf_counter()
        with ProgressBar(total=args.clients * args.requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()
            total_elapsed = time.perf_counter() - start_time

        if not times:
            # Nothing was measured: for example --clients 0 or --requests 0,
            # or every worker thread failed before completing a request.
            # statistics.mean() would raise on an empty list.
            print("No requests were completed")
            return 1

        # All worker threads have been joined, so no lock is needed here.
        mean = statistics.mean(times)
        median = statistics.median(times)
        stddev = statistics.pstdev(times)

        print(f"Number of clients: {args.clients}")
        print(f"Requests per client: {args.requests}")
        print(f"Number of requests: {len(times)}")
        print(f"Total elapsed time: {total_elapsed:.3f}s")
        print(f"Total request rate: {len(times)/total_elapsed:.3f} req/s")
        print(f"Average request time: {mean:.3f}s")
        print(f"Median request time: {median:.3f}s")
        print(f"Request time std dev: {stddev:.3f}s")
        print(f"Maximum request time: {max(times):.3f}s")
        print(f"Minimum request time: {min(times):.3f}s")
        print(f"Hashes found: {found_hashes}")
        print(f"Hashes missed: {missed_hashes}")

        if args.report:
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    outhash = hashlib.sha256()
                    outhash.update(args.outhash_seed.encode('utf-8'))
                    outhash.update(str(i).encode('utf-8'))

                    # The task hash doubles as the reported unihash.
                    client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())

                    with lock:
                        pbar.update()

        return 0

    def handle_remove(args, client):
        """Remove entries matching the ``--where KEY VALUE`` conditions.

        Refuses to run without at least one condition.
        """
        where = {k: v for k, v in args.where}
        if where:
            result = client.remove(where)
            print("Removed %d row(s)" % (result["count"]))
        else:
            print("No query specified")

    def handle_clean_unused(args, client):
        """Remove unused entries older than ``args.max_age`` seconds."""
        result = client.clean_unused(args.max_age)
        print("Removed %d rows" % (result["count"]))
        return 0

    def handle_refresh_token(args, client):
        """Refresh the auth token for the current user or ``--username``."""
        r = client.refresh_token(args.username)
        print_user(r)

    def handle_set_user_permissions(args, client):
        """Replace the permissions of ``--username`` with ``args.permissions``."""
        r = client.set_user_perms(args.username, args.permissions)
        print_user(r)

    def handle_get_user(args, client):
        """Print the user record returned for ``--username``."""
        r = client.get_user(args.username)
        print_user(r)

    def handle_get_all_users(args, client):
        """Print a table of all users and their permissions."""
        users = client.get_all_users()
        print("{username:20}| {permissions}".format(username="Username", permissions="Permissions"))
        print(("-" * 20) + "+" + ("-" * 20))
        for u in users:
            print("{username:20}| {permissions}".format(username=u["username"], permissions=" ".join(u["permissions"])))

    def handle_new_user(args, client):
        """Create ``--username`` with ``args.permissions`` and print the result."""
        r = client.new_user(args.username, args.permissions)
        print_user(r)

    def handle_delete_user(args, client):
        """Delete ``--username`` and print the record the server returns."""
        r = client.delete_user(args.username)
        print_user(r)

    def handle_get_db_usage(args, client):
        """Print per-table row counts and the total row count."""
        usage = client.get_db_usage()
        print(usage)
        tables = sorted(usage.keys())
        print("{name:20}| {rows:20}".format(name="Table name", rows="Rows"))
        print(("-" * 20) + "+" + ("-" * 20))
        for t in tables:
            print("{name:20}| {rows:<20}".format(name=t, rows=usage[t]["rows"]))
        print()

        total_rows = sum(t["rows"] for t in usage.values())
        print(f"Total rows: {total_rows}")

    def handle_get_db_query_columns(args, client):
        """Print, one per line and sorted, the columns usable in queries."""
        columns = client.get_db_query_columns()
        print("\n".join(sorted(columns)))

    def handle_gc_status(args, client):
        """Print the current garbage-collection mark and keep/remove counts."""
        result = client.gc_status()
        if not result["mark"]:
            print("No Garbage collection in progress")
            return 0

        print("Current Mark: %s" % result["mark"])
        print("Total hashes to keep: %d" % result["keep"])
        print("Total hashes to remove: %s" % result["remove"])
        return 0

    def handle_gc_mark(args, client):
        """Mark entries matching ``--where`` to be kept under ``args.mark``."""
        where = {k: v for k, v in args.where}
        result = client.gc_mark(args.mark, where)
        print("New hashes marked: %d" % result["count"])
        return 0

    def handle_gc_mark_stream(args, client):
        """Mark hashes read from stdin, one ``key value ...`` condition per line.

        Uses the streaming API. If that raises ConnectionError, falls back to
        one ``gc_mark`` call per remaining stdin line.
        """
        stdin = (l.strip() for l in sys.stdin)
        marked_hashes = 0

        try:
            result = client.gc_mark_stream(args.mark, stdin)
            marked_hashes = result["count"]
        except ConnectionError:
            logger.warning(
                "Server doesn't seem to support `gc-mark-stream`. Sending "
                "hashes sequentially using `gc-mark` API."
            )
            # NOTE(review): lines already consumed from the generator by the
            # failed gc_mark_stream call are not replayed here. Confirm the
            # client fails before reading any input.
            for line in stdin:
                pairs = line.split()
                condition = dict(zip(pairs[::2], pairs[1::2]))
                result = client.gc_mark(args.mark, condition)
                marked_hashes += result["count"]

        print("New hashes marked: %d" % marked_hashes)
        return 0

    def handle_gc_sweep(args, client):
        """Delete every entry not marked under ``args.mark``."""
        result = client.gc_sweep(args.mark)
        print("Removed %d rows" % result["count"])
        return 0

    def handle_unihash_exists(args, client):
        """Report whether ``args.unihash`` is known to the server.

        With ``--quiet``, print nothing: return 0 if it exists, 1 otherwise.
        """
        result = client.unihash_exists(args.unihash)
        if args.quiet:
            return 0 if result else 1

        print("true" if result else "false")
        return 0

    def handle_ping(args, client):
        """Ping the server ``args.count`` times and print round-trip statistics.

        Returns 1 without contacting the server if the count is less than 1,
        because the statistics below need at least one sample.
        """
        if args.count < 1:
            print(f"ERROR: ping count must be at least 1 (got {args.count})")
            return 1

        times = []
        for i in range(1, args.count + 1):
            if not args.quiet:
                print(f"Ping {i} of {args.count}... ", end="")
            start_time = time.perf_counter()
            client.ping()
            elapsed = time.perf_counter() - start_time
            times.append(elapsed)
            if not args.quiet:
                print(f"{elapsed:.3f}s")

        mean = statistics.mean(times)
        median = statistics.median(times)
        std_dev = statistics.pstdev(times)

        if not args.quiet:
            print("------------------------")
        print(f"Number of pings: {len(times)}")
        print(f"Average round trip time: {mean:.3f}s")
        print(f"Median round trip time: {median:.3f}s")
        print(f"Round trip time std dev: {std_dev:.3f}s")
        print(f"Min time is: {min(times):.3f}s")
        print(f"Max time is: {max(times):.3f}s")
        return 0

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Hash Equivalence Client',
        epilog=textwrap.dedent(
            """
            Possible ADDRESS options are:
                unix://PATH         Connect to UNIX domain socket at PATH
                ws://HOST[:PORT]    Connect to websocket at HOST:PORT (default port is 80)
                wss://HOST[:PORT]   Connect to secure websocket at HOST:PORT (default port is 443)
                HOST:PORT           Connect to TCP server at HOST:PORT
            """
        ),
    )
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')
    parser.add_argument('--login', '-l', metavar="USERNAME", help="Authenticate as USERNAME")
    parser.add_argument('--password', '-p', metavar="TOKEN", help="Authenticate using token TOKEN")
    parser.add_argument('--become', '-b', metavar="USERNAME", help="Impersonate user USERNAME (if allowed) when performing actions")
    parser.add_argument('--no-netrc', '-n', action="store_false", dest="netrc", help="Do not use .netrc")

    subparsers = parser.add_subparsers()

    get_parser = subparsers.add_parser('get', help="Get the unihash for a taskhash")
    get_parser.add_argument("method", help="Method to query")
    get_parser.add_argument("taskhash", help="Task hash to query")
    get_parser.set_defaults(func=handle_get)

    get_outhash_parser = subparsers.add_parser('get-outhash', help="Get output hash information")
    get_outhash_parser.add_argument("method", help="Method to query")
    get_outhash_parser.add_argument("outhash", help="Output hash to query")
    get_outhash_parser.add_argument("taskhash", help="Task hash to query")
    get_outhash_parser.set_defaults(func=handle_get_outhash)

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    remove_parser = subparsers.add_parser('remove', help="Remove hash entries")
    remove_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                               help="Remove entries from table where KEY == VALUE")
    remove_parser.set_defaults(func=handle_remove)

    clean_unused_parser = subparsers.add_parser('clean-unused', help="Remove unused database entries")
    clean_unused_parser.add_argument("max_age", metavar="SECONDS", type=int, help="Remove unused entries older than SECONDS old")
    clean_unused_parser.set_defaults(func=handle_clean_unused)

    refresh_token_parser = subparsers.add_parser('refresh-token', help="Refresh auth token")
    refresh_token_parser.add_argument("--username", "-u", help="Refresh the token for another user (if authorized)")
    refresh_token_parser.set_defaults(func=handle_refresh_token)

    set_user_perms_parser = subparsers.add_parser('set-user-perms', help="Set new permissions for user")
    set_user_perms_parser.add_argument("--username", "-u", help="Username", required=True)
    set_user_perms_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    set_user_perms_parser.set_defaults(func=handle_set_user_permissions)

    get_user_parser = subparsers.add_parser('get-user', help="Get user")
    get_user_parser.add_argument("--username", "-u", help="Username")
    get_user_parser.set_defaults(func=handle_get_user)

    get_all_users_parser = subparsers.add_parser('get-all-users', help="List all users")
    get_all_users_parser.set_defaults(func=handle_get_all_users)

    new_user_parser = subparsers.add_parser('new-user', help="Create new user")
    new_user_parser.add_argument("--username", "-u", help="Username", required=True)
    new_user_parser.add_argument("permissions", metavar="PERM", nargs="*", default=[], help="New permissions")
    new_user_parser.set_defaults(func=handle_new_user)

    delete_user_parser = subparsers.add_parser('delete-user', help="Delete user")
    delete_user_parser.add_argument("--username", "-u", help="Username", required=True)
    delete_user_parser.set_defaults(func=handle_delete_user)

    db_usage_parser = subparsers.add_parser('get-db-usage', help="Database Usage")
    db_usage_parser.set_defaults(func=handle_get_db_usage)

    db_query_columns_parser = subparsers.add_parser('get-db-query-columns', help="Show columns that can be used in database queries")
    db_query_columns_parser.set_defaults(func=handle_get_db_query_columns)

    gc_status_parser = subparsers.add_parser("gc-status", help="Show garbage collection status")
    gc_status_parser.set_defaults(func=handle_gc_status)

    gc_mark_parser = subparsers.add_parser('gc-mark', help="Mark hashes to be kept for garbage collection")
    gc_mark_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_mark_parser.add_argument("--where", "-w", metavar="KEY VALUE", nargs=2, action="append", default=[],
                                help="Keep entries in table where KEY == VALUE")
    gc_mark_parser.set_defaults(func=handle_gc_mark)

    gc_mark_parser_stream = subparsers.add_parser(
        'gc-mark-stream',
        help=(
            "Mark multiple hashes to be retained for garbage collection. Input should be provided via stdin, "
            "with each line formatted as key-value pairs separated by spaces, for example 'column1 foo column2 bar'."
        )
    )
    gc_mark_parser_stream.add_argument("mark", help="Mark for this garbage collection operation")
    gc_mark_parser_stream.set_defaults(func=handle_gc_mark_stream)

    gc_sweep_parser = subparsers.add_parser('gc-sweep', help="Perform garbage collection and delete any entries that are not marked")
    gc_sweep_parser.add_argument("mark", help="Mark for this garbage collection operation")
    gc_sweep_parser.set_defaults(func=handle_gc_sweep)

    unihash_exists_parser = subparsers.add_parser('unihash-exists', help="Check if a unihash is known to the server")
    unihash_exists_parser.add_argument("--quiet", action="store_true", help="Don't print status. Instead, exit with 0 if unihash exists and 1 if it does not")
    unihash_exists_parser.add_argument("unihash", help="Unihash to check")
    unihash_exists_parser.set_defaults(func=handle_unihash_exists)

    ping_parser = subparsers.add_parser('ping', help="Ping server")
    ping_parser.add_argument("-n", "--count", type=int, help="Number of pings. Default is %(default)s", default=10)
    ping_parser.add_argument("-q", "--quiet", action="store_true", help="Don't print each ping; only print results")
    ping_parser.set_defaults(func=handle_ping)

    args = parser.parse_args()

    # Assigned before any handler runs; handle_gc_mark_stream reads it through
    # its closure.
    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    login = args.login
    password = args.password

    # Fall back to ~/.netrc credentials keyed by the server address unless a
    # login was given explicitly or --no-netrc was passed.
    if login is None and args.netrc:
        try:
            n = netrc.netrc()
            auth = n.authenticators(args.address)
            if auth is not None:
                login, _, password = auth
        except FileNotFoundError:
            pass
        except netrc.NetrcParseError as e:
            sys.stderr.write(f"Error parsing {e.filename}:{e.lineno}: {e.msg}\n")

    func = getattr(args, 'func', None)
    if func:
        try:
            with hashserv.create_client(args.address, login, password) as client:
                if args.become:
                    client.become_user(args.become)
                return func(args, client)
        except bb.asyncrpc.InvokeError as e:
            print(f"ERROR: {e}")
            return 1

    return 0


if __name__ == '__main__':
    try:
        ret = main()
    except Exception:
        ret = 1
        import traceback
        traceback.print_exc()
    sys.exit(ret)