1#! /usr/bin/env python3
2#
3# Copyright (C) 2019 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import argparse
9import hashlib
10import logging
11import os
12import pprint
13import sys
14import threading
15import time
16
17try:
18    import tqdm
19    ProgressBar = tqdm.tqdm
20except ImportError:
21    class ProgressBar(object):
22        def __init__(self, *args, **kwargs):
23            pass
24
25        def __enter__(self):
26            return self
27
28        def __exit__(self, *args, **kwargs):
29            pass
30
31        def update(self):
32            pass
33
34sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))
35
36import hashserv
37
38DEFAULT_ADDRESS = 'unix://./hashserve.sock'
39METHOD = 'stress.test.method'
40
41
42def main():
43    def handle_stats(args, client):
44        if args.reset:
45            s = client.reset_stats()
46        else:
47            s = client.get_stats()
48        pprint.pprint(s)
49        return 0
50
51    def handle_stress(args, client):
52        def thread_main(pbar, lock):
53            nonlocal found_hashes
54            nonlocal missed_hashes
55            nonlocal max_time
56
57            client = hashserv.create_client(args.address)
58
59            for i in range(args.requests):
60                taskhash = hashlib.sha256()
61                taskhash.update(args.taskhash_seed.encode('utf-8'))
62                taskhash.update(str(i).encode('utf-8'))
63
64                start_time = time.perf_counter()
65                l = client.get_unihash(METHOD, taskhash.hexdigest())
66                elapsed = time.perf_counter() - start_time
67
68                with lock:
69                    if l:
70                        found_hashes += 1
71                    else:
72                        missed_hashes += 1
73
74                    max_time = max(elapsed, max_time)
75                    pbar.update()
76
77        max_time = 0
78        found_hashes = 0
79        missed_hashes = 0
80        lock = threading.Lock()
81        total_requests = args.clients * args.requests
82        start_time = time.perf_counter()
83        with ProgressBar(total=total_requests) as pbar:
84            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
85            for t in threads:
86                t.start()
87
88            for t in threads:
89                t.join()
90
91        elapsed = time.perf_counter() - start_time
92        with lock:
93            print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, total_requests / elapsed))
94            print("Average request time %.8fs" % (elapsed / total_requests))
95            print("Max request time was %.8fs" % max_time)
96            print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))
97
98        if args.report:
99            with ProgressBar(total=args.requests) as pbar:
100                for i in range(args.requests):
101                    taskhash = hashlib.sha256()
102                    taskhash.update(args.taskhash_seed.encode('utf-8'))
103                    taskhash.update(str(i).encode('utf-8'))
104
105                    outhash = hashlib.sha256()
106                    outhash.update(args.outhash_seed.encode('utf-8'))
107                    outhash.update(str(i).encode('utf-8'))
108
109                    client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())
110
111                    with lock:
112                        pbar.update()
113
114    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
115    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
116    parser.add_argument('--log', default='WARNING', help='Set logging level')
117
118    subparsers = parser.add_subparsers()
119
120    stats_parser = subparsers.add_parser('stats', help='Show server stats')
121    stats_parser.add_argument('--reset', action='store_true',
122                              help='Reset server stats')
123    stats_parser.set_defaults(func=handle_stats)
124
125    stress_parser = subparsers.add_parser('stress', help='Run stress test')
126    stress_parser.add_argument('--clients', type=int, default=10,
127                               help='Number of simultaneous clients')
128    stress_parser.add_argument('--requests', type=int, default=1000,
129                               help='Number of requests each client will perform')
130    stress_parser.add_argument('--report', action='store_true',
131                               help='Report new hashes')
132    stress_parser.add_argument('--taskhash-seed', default='',
133                               help='Include string in taskhash')
134    stress_parser.add_argument('--outhash-seed', default='',
135                               help='Include string in outhash')
136    stress_parser.set_defaults(func=handle_stress)
137
138    args = parser.parse_args()
139
140    logger = logging.getLogger('hashserv')
141
142    level = getattr(logging, args.log.upper(), None)
143    if not isinstance(level, int):
144        raise ValueError('Invalid log level: %s' % args.log)
145
146    logger.setLevel(level)
147    console = logging.StreamHandler()
148    console.setLevel(level)
149    logger.addHandler(console)
150
151    func = getattr(args, 'func', None)
152    if func:
153        client = hashserv.create_client(args.address)
154
155        return func(args, client)
156
157    return 0
158
159
160if __name__ == '__main__':
161    try:
162        ret = main()
163    except Exception:
164        ret = 1
165        import traceback
166        traceback.print_exc()
167    sys.exit(ret)
168