#! /usr/bin/env python3
#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time

try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in for tqdm.tqdm, used when tqdm is not installed."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            pass

        def update(self, n=1):
            # Accept the same optional increment argument as tqdm's update()
            pass

# Make the sibling "lib" directory (one level above this script's directory)
# importable so that the in-tree hashserv package is found.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))

import hashserv

DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'


def main():
    """Parse the command line and dispatch to the selected subcommand.

    Returns the process exit code: the subcommand handler's return value,
    or 0 when no subcommand was given.

    Raises:
        ValueError: if --log does not name a logging level.
    """

    def handle_stats(args, client):
        """Print (and optionally reset) the server statistics; return 0."""
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        pprint.pprint(s)
        return 0

    def handle_stress(args, client):
        """Run a concurrent get_unihash load test against the server.

        Starts args.clients threads, each opening its own connection and
        issuing args.requests lookups for deterministic seeded taskhashes.
        With --report, afterwards reports a unihash for each of those
        taskhashes over the shared *client*.

        Returns 0 on success, 1 if any client thread raised an exception.
        """

        def seeded_hash(seed, i):
            # Deterministic sha256 hex digest of the seed followed by the index
            h = hashlib.sha256()
            h.update(seed.encode('utf-8'))
            h.update(str(i).encode('utf-8'))
            return h.hexdigest()

        def thread_main(pbar, lock):
            """Worker: issue args.requests lookups on a private connection."""
            nonlocal found_hashes
            nonlocal missed_hashes
            nonlocal max_time
            nonlocal total_time
            nonlocal failed_clients

            try:
                thread_client = hashserv.create_client(args.address)

                for i in range(args.requests):
                    taskhash = seeded_hash(args.taskhash_seed, i)

                    start_time = time.perf_counter()
                    unihash = thread_client.get_unihash(METHOD, taskhash)
                    elapsed = time.perf_counter() - start_time

                    with lock:
                        if unihash:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        # Sum per-request latency so the true mean can be
                        # reported; wall time / count is not latency when
                        # several clients run concurrently.
                        total_time += elapsed
                        max_time = max(elapsed, max_time)
                        pbar.update()
            except Exception:
                with lock:
                    failed_clients += 1
                # Re-raise so threading prints the traceback for this worker
                raise

        max_time = 0
        total_time = 0.0
        found_hashes = 0
        missed_hashes = 0
        failed_clients = 0
        lock = threading.Lock()
        total_requests = args.clients * args.requests
        start_time = time.perf_counter()
        with ProgressBar(total=total_requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()

        elapsed = time.perf_counter() - start_time
        with lock:
            # Only count requests that actually completed; equals
            # total_requests when no client failed.
            completed = found_hashes + missed_hashes
            rate = completed / elapsed if elapsed > 0 else 0.0
            mean_time = total_time / completed if completed else 0.0

            print("%d requests in %.1fs. %.1f requests per second" % (completed, elapsed, rate))
            print("Average request time %.8fs" % mean_time)
            print("Max request time was %.8fs" % max_time)
            print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

            if failed_clients:
                print("%d of %d clients failed; see errors above" % (failed_clients, args.clients))

        if args.report:
            # All worker threads have been joined, so this loop is
            # single-threaded and needs no locking.
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = seeded_hash(args.taskhash_seed, i)
                    outhash = seeded_hash(args.outhash_seed, i)

                    # The taskhash is also reported as the unihash
                    client.report_unihash(taskhash, METHOD, outhash, taskhash)
                    pbar.update()

        return 1 if failed_clients else 0

    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')

    subparsers = parser.add_subparsers()

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    func = getattr(args, 'func', None)
    if func:
        client = hashserv.create_client(args.address)

        return func(args, client)

    return 0


if __name__ == '__main__':
    try:
        ret = main()
    except Exception:
        ret = 1
        import traceback
        traceback.print_exc()
    sys.exit(ret)