1#! /usr/bin/env python3
2#
3# Copyright (C) 2018-2019 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8from . import create_server, create_client
9from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
10from bb.asyncrpc import InvokeError
11from .client import ClientPool
12import hashlib
13import logging
14import multiprocessing
15import os
16import sys
17import tempfile
18import threading
19import unittest
20import socket
21import time
22import signal
23import subprocess
24import json
25import re
26from pathlib import Path
27
28
# Directory containing this test module
THIS_DIR = Path(__file__).parent
# "bin" directory two levels above this module's directory; run_hashclient()
# launches the bitbake-hashclient script from here
BIN_DIR = THIS_DIR.parent.parent / "bin"
31
def server_prefunc(server, idx):
    """Prepare a test server process for running.

    Runs inside the server child process: configures DEBUG logging into a
    per-server ``bbhashserv-<idx>.log`` file and redirects both stdout and
    stderr into a single ``bbhashserv-stdout-<idx>.log`` file.
    """
    log_name = 'bbhashserv-%d.log' % idx
    logging.basicConfig(
        level=logging.DEBUG,
        filename=log_name,
        filemode='w',
        format='%(levelname)s %(filename)s:%(lineno)d %(message)s',
    )
    server.logger.debug("Running server %d" % idx)
    # One shared file object so stdout and stderr output interleave in order
    sys.stdout = sys.stderr = open('bbhashserv-stdout-%d.log' % idx, 'w')
38
class HashEquivalenceTestSetup(object):
    """Fixture mixin that starts hash equivalence servers and clients.

    Expects to be mixed into a ``unittest.TestCase``: it relies on
    ``addCleanup``, ``skipTest`` and the ``assert*`` methods. Every server
    and client it starts is registered with ``addCleanup`` so it is shut
    down when the test finishes.
    """

    # Method name used for every hash reported/queried by the tests
    METHOD = 'TestMethod'

    # Counter giving each started server a unique index (used for its log
    # files and default database path)
    server_index = 0
    # Counter used to generate unique usernames for auth tests
    client_index = 0

    def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
        """Start a hash equivalence server in a child process.

        :param dbpath: SQLite database path. When None, a fresh per-server
            path inside the test's temporary directory is used.
        :param upstream: address of an upstream server for this server to
            defer to.
        :param read_only: start the server in read-only mode.
        :param prefunc: callable run in the server process before it starts
            serving; receives the server and its index.
        :param anon_perms: permissions granted to unauthenticated clients.
        :param admin_username: admin username passed to ``create_server``.
        :param admin_password: admin password passed to ``create_server``.
        :returns: the server object, with its ``dbpath`` attached so other
            servers can share the same database.
        """
        self.server_index += 1
        if dbpath is None:
            dbpath = self.make_dbpath()

        def cleanup_server(server):
            # Nothing to do if the process already exited on its own
            if server.process.exitcode is not None:
                return

            server.process.terminate()
            server.process.join()

        server = create_server(self.get_server_addr(self.server_index),
                               dbpath,
                               upstream=upstream,
                               read_only=read_only,
                               anon_perms=anon_perms,
                               admin_username=admin_username,
                               admin_password=admin_password)
        server.dbpath = dbpath

        server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
        self.addCleanup(cleanup_server, server)

        return server

    def make_dbpath(self):
        """Return the default database path for the current server index."""
        return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)

    def start_client(self, server_address, username=None, password=None):
        """Connect a client to *server_address*, closing it on test cleanup."""
        def cleanup_client(client):
            client.close()

        client = create_client(server_address, username=username, password=password)
        self.addCleanup(cleanup_client, client)

        return client

    def start_test_server(self):
        """Start the default (anonymous) test server and return its address."""
        self.server = self.start_server()
        return self.server.address

    def start_auth_server(self):
        """Start an authenticated server sharing the test server's database.

        Anonymous users get no permissions; an "admin" account is configured
        and a client logged in as that admin is returned (and stored as
        ``self.admin_client``).
        """
        auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
        self.auth_server_address = auth_server.address
        self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
        return self.admin_client

    def auth_client(self, user):
        """Return a client on the auth server logged in as *user*.

        *user* is a user record with ``username`` and ``token`` keys.
        """
        return self.start_client(self.auth_server_address, user["username"], user["token"])

    def setUp(self):
        """Create a temp directory, start the test server and a client."""
        if sys.version_info < (3, 5, 0):
            self.skipTest('Python 3.5 or later required')

        self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
        self.addCleanup(self.temp_dir.cleanup)

        self.server_address = self.start_test_server()

        self.client = self.start_client(self.server_address)

    def assertClientGetHash(self, client, taskhash, unihash):
        """Assert that *client* maps *taskhash* to *unihash* (None = unknown)."""
        result = client.get_unihash(self.METHOD, taskhash)
        self.assertEqual(result, unihash)

    def assertUserPerms(self, user, permissions):
        """Assert that *user* sees exactly *permissions* via ``get_user``."""
        with self.auth_client(user) as client:
            info = client.get_user()
            self.assertEqual(info, {
                "username": user["username"],
                "permissions": permissions,
            })

    def assertUserCanAuth(self, user):
        """Assert that *user*'s credentials are accepted by the auth server."""
        with self.start_client(self.auth_server_address) as client:
            client.auth(user["username"], user["token"])

    def assertUserCannotAuth(self, user):
        """Assert that *user*'s credentials are rejected by the auth server."""
        with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
            client.auth(user["username"], user["token"])

    def create_test_hash(self, client):
        """Report a fixed test hash through *client*.

        Checks the taskhash is unknown beforehand and that the server echoes
        the reported unihash.

        :returns: ``(taskhash, outhash, unihash)`` that was reported.
        """
        taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
        outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
        unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'

        self.assertClientGetHash(client, taskhash, None)

        result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
        return taskhash, outhash, unihash

    def run_hashclient(self, args, **kwargs):
        """Run the ``bitbake-hashclient`` command line tool with *args*.

        stdout and stderr are merged and captured as UTF-8 text, echoed so
        they appear in the test log, and the ``CompletedProcess`` is returned.
        Extra keyword arguments are forwarded to ``subprocess.run``; with
        ``check=True`` a failing command raises ``CalledProcessError`` after
        its captured output has been printed.
        """
        try:
            p = subprocess.run(
                [BIN_DIR / "bitbake-hashclient"] + args,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                encoding="utf-8",
                **kwargs
            )
        except subprocess.CalledProcessError as e:
            print(e.output)
            # Bare raise re-raises with the original traceback intact
            raise

        print(p.stdout)
        return p
154
155
156class HashEquivalenceCommonTests(object):
157    def auth_perms(self, *permissions):
158        self.client_index += 1
159        user = self.create_user(f"user-{self.client_index}", permissions)
160        return self.auth_client(user)
161
162    def create_user(self, username, permissions, *, client=None):
163        def remove_user(username):
164            try:
165                self.admin_client.delete_user(username)
166            except bb.asyncrpc.InvokeError:
167                pass
168
169        if client is None:
170            client = self.admin_client
171
172        user = client.new_user(username, permissions)
173        self.addCleanup(remove_user, username)
174
175        return user
176
177    def test_create_hash(self):
178        return self.create_test_hash(self.client)
179
180    def test_create_equivalent(self):
181        # Tests that a second reported task with the same outhash will be
182        # assigned the same unihash
183        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
184        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
185        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
186
187        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
188        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
189
190        # Report a different task with the same outhash. The returned unihash
191        # should match the first task
192        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
193        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
194        result = self.client.report_unihash(taskhash2, self.METHOD, outhash, unihash2)
195        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
196
197    def test_duplicate_taskhash(self):
198        # Tests that duplicate reports of the same taskhash with different
199        # outhash & unihash always return the unihash from the first reported
200        # taskhash
201        taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
202        outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
203        unihash = '218e57509998197d570e2c98512d0105985dffc9'
204        self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
205
206        self.assertClientGetHash(self.client, taskhash, unihash)
207
208        outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
209        unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
210        self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
211
212        self.assertClientGetHash(self.client, taskhash, unihash)
213
214        outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
215        unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603'
216        self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3)
217
218        self.assertClientGetHash(self.client, taskhash, unihash)
219
220    def test_remove_taskhash(self):
221        taskhash, outhash, unihash = self.create_test_hash(self.client)
222        result = self.client.remove({"taskhash": taskhash})
223        self.assertGreater(result["count"], 0)
224        self.assertClientGetHash(self.client, taskhash, None)
225
226        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
227        self.assertIsNone(result_outhash)
228
229    def test_remove_unihash(self):
230        taskhash, outhash, unihash = self.create_test_hash(self.client)
231        result = self.client.remove({"unihash": unihash})
232        self.assertGreater(result["count"], 0)
233        self.assertClientGetHash(self.client, taskhash, None)
234
235    def test_remove_outhash(self):
236        taskhash, outhash, unihash = self.create_test_hash(self.client)
237        result = self.client.remove({"outhash": outhash})
238        self.assertGreater(result["count"], 0)
239
240        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
241        self.assertIsNone(result_outhash)
242
243    def test_remove_method(self):
244        taskhash, outhash, unihash = self.create_test_hash(self.client)
245        result = self.client.remove({"method": self.METHOD})
246        self.assertGreater(result["count"], 0)
247        self.assertClientGetHash(self.client, taskhash, None)
248
249        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
250        self.assertIsNone(result_outhash)
251
252    def test_clean_unused(self):
253        taskhash, outhash, unihash = self.create_test_hash(self.client)
254
255        # Clean the database, which should not remove anything because all hashes an in-use
256        result = self.client.clean_unused(0)
257        self.assertEqual(result["count"], 0)
258        self.assertClientGetHash(self.client, taskhash, unihash)
259
260        # Remove the unihash. The row in the outhash table should still be present
261        self.client.remove({"unihash": unihash})
262        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
263        self.assertIsNotNone(result_outhash)
264
265        # Now clean with no minimum age which will remove the outhash
266        result = self.client.clean_unused(0)
267        self.assertEqual(result["count"], 1)
268        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
269        self.assertIsNone(result_outhash)
270
271    def test_huge_message(self):
272        # Simple test that hashes can be created
273        taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
274        outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
275        unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
276
277        self.assertClientGetHash(self.client, taskhash, None)
278
279        siginfo = "0" * (self.client.max_chunk * 4)
280
281        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, {
282            'outhash_siginfo': siginfo
283        })
284        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
285
286        result_unihash = self.client.get_taskhash(self.METHOD, taskhash, True)
287        self.assertEqual(result_unihash['taskhash'], taskhash)
288        self.assertEqual(result_unihash['unihash'], unihash)
289        self.assertEqual(result_unihash['method'], self.METHOD)
290
291        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
292        self.assertEqual(result_outhash['taskhash'], taskhash)
293        self.assertEqual(result_outhash['method'], self.METHOD)
294        self.assertEqual(result_outhash['unihash'], unihash)
295        self.assertEqual(result_outhash['outhash'], outhash)
296        self.assertEqual(result_outhash['outhash_siginfo'], siginfo)
297
298    def test_stress(self):
299        def query_server(failures):
300            client = Client(self.server_address)
301            try:
302                for i in range(1000):
303                    taskhash = hashlib.sha256()
304                    taskhash.update(str(i).encode('utf-8'))
305                    taskhash = taskhash.hexdigest()
306                    result = client.get_unihash(self.METHOD, taskhash)
307                    if result != taskhash:
308                        failures.append("taskhash mismatch: %s != %s" % (result, taskhash))
309            finally:
310                client.close()
311
312        # Report hashes
313        for i in range(1000):
314            taskhash = hashlib.sha256()
315            taskhash.update(str(i).encode('utf-8'))
316            taskhash = taskhash.hexdigest()
317            self.client.report_unihash(taskhash, self.METHOD, taskhash, taskhash)
318
319        failures = []
320        threads = [threading.Thread(target=query_server, args=(failures,)) for t in range(100)]
321
322        for t in threads:
323            t.start()
324
325        for t in threads:
326            t.join()
327
328        self.assertFalse(failures)
329
    def test_upstream_server(self):
        # Tests upstream server support. This is done by creating two servers
        # that share a database file. The downstream server has its upstream
        # set to the test server, whereas the side server doesn't. This allows
        # verification that the hash requests are being proxied to the upstream
        # server by verifying that they appear on the downstream client, but not
        # the side client. It also verifies that the results are pulled into
        # the downstream database by checking that the downstream and side servers
        # match after the downstream is done waiting for all backfill tasks
        down_server = self.start_server(upstream=self.server_address)
        down_client = self.start_client(down_server.address)
        side_server = self.start_server(dbpath=down_server.dbpath)
        side_client = self.start_client(side_server.address)

        def check_hash(taskhash, unihash, old_sidehash):
            # Verify *taskhash* resolves to *unihash* upstream and downstream.
            # *old_sidehash* is what the side server must return before the
            # downstream server has backfilled the shared database.
            # (The closure only reads these names; nonlocal is not strictly
            # required.)
            nonlocal down_client
            nonlocal side_client

            # check upstream server
            self.assertClientGetHash(self.client, taskhash, unihash)

            # Hash should *not* be present on the side server
            self.assertClientGetHash(side_client, taskhash, old_sidehash)

            # Hash should be present on the downstream server, since it
            # will defer to the upstream server. This will trigger
            # the backfill in the downstream server
            self.assertClientGetHash(down_client, taskhash, unihash)

            # After waiting for the downstream client to finish backfilling the
            # task from the upstream server, it should appear in the side server
            # since the database is populated
            down_client.backfill_wait()
            self.assertClientGetHash(side_client, taskhash, unihash)

        # Basic report
        taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
        outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
        unihash = '218e57509998197d570e2c98512d0105985dffc9'
        self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)

        check_hash(taskhash, unihash, None)

        # Duplicated taskhash with multiple output hashes and unihashes.
        # All servers should agree with the originally reported hash
        outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
        unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
        self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)

        check_hash(taskhash, unihash, unihash)

        # Report an equivalent task. The sideload will originally report
        # no unihash until backfilled
        taskhash3 = "044c2ec8aaf480685a00ff6ff49e6162e6ad34e1"
        unihash3 = "def64766090d28f627e816454ed46894bb3aab36"
        self.client.report_unihash(taskhash3, self.METHOD, outhash, unihash3)

        check_hash(taskhash3, unihash, None)

        # Test that reporting a unihash in the downstream client isn't
        # propagating to the upstream server
        taskhash4 = "e3da00593d6a7fb435c7e2114976c59c5fd6d561"
        outhash4 = "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a"
        unihash4 = "3b5d3d83f07f259e9086fcb422c855286e18a57d"
        down_client.report_unihash(taskhash4, self.METHOD, outhash4, unihash4)
        down_client.backfill_wait()

        self.assertClientGetHash(down_client, taskhash4, unihash4)
        self.assertClientGetHash(side_client, taskhash4, unihash4)
        self.assertClientGetHash(self.client, taskhash4, None)

        # Test that reporting a unihash in the downstream is able to find a
        # match which was previously reported to the upstream server
        taskhash5 = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
        outhash5 = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
        unihash5 = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
        result = self.client.report_unihash(taskhash5, self.METHOD, outhash5, unihash5)

        taskhash6 = '35788efcb8dfb0a02659d81cf2bfd695fb30fafa'
        unihash6 = 'f46d3fbb439bd9b921095da657a4de906510d2ce'
        result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6)
        self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream')

        # Test read-through: a hash reported only to the upstream server is
        # returned in full by get_taskhash(..., True) on the downstream server
        taskhash7 = '9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74'
        outhash7 = '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69'
        unihash7 = '05d2a63c81e32f0a36542ca677e8ad852365c538'
        self.client.report_unihash(taskhash7, self.METHOD, outhash7, unihash7)

        result = down_client.get_taskhash(self.METHOD, taskhash7, True)
        self.assertEqual(result['unihash'], unihash7, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['outhash'], outhash7, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['taskhash'], taskhash7, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['method'], self.METHOD)

        # Same read-through check, via get_outhash
        taskhash8 = '86978a4c8c71b9b487330b0152aade10c1ee58aa'
        outhash8 = 'ca8c128e9d9e4a28ef24d0508aa20b5cf880604eacd8f65c0e366f7e0cc5fbcf'
        unihash8 = 'd8bcf25369d40590ad7d08c84d538982f2023e01'
        self.client.report_unihash(taskhash8, self.METHOD, outhash8, unihash8)

        result = down_client.get_outhash(self.METHOD, outhash8, taskhash8)
        self.assertEqual(result['unihash'], unihash8, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['outhash'], outhash8, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['taskhash'], taskhash8, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['method'], self.METHOD)

        # Read-through via get_taskhash with False as the third argument;
        # only unihash/taskhash/method are checked in this form (presumably
        # that flag requests all properties, e.g. outhash — confirm in client)
        taskhash9 = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c'
        outhash9 = 'afc78172c81880ae10a1fec994b5b4ee33d196a001a1b66212a15ebe573e00b5'
        unihash9 = '6662e699d6e3d894b24408ff9a4031ef9b038ee8'
        self.client.report_unihash(taskhash9, self.METHOD, outhash9, unihash9)

        result = down_client.get_taskhash(self.METHOD, taskhash9, False)
        self.assertEqual(result['unihash'], unihash9, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['taskhash'], taskhash9, 'Server failed to copy unihash from upstream')
        self.assertEqual(result['method'], self.METHOD)
445
446    def test_unihash_exsits(self):
447        taskhash, outhash, unihash = self.create_test_hash(self.client)
448        self.assertTrue(self.client.unihash_exists(unihash))
449        self.assertFalse(self.client.unihash_exists('6662e699d6e3d894b24408ff9a4031ef9b038ee8'))
450
    def test_ro_server(self):
        # A read-only server sharing a database with a read-write server can
        # serve hashes stored by it, but must not store new reports itself
        rw_server = self.start_server()
        rw_client = self.start_client(rw_server.address)

        ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
        ro_client = self.start_client(ro_server.address)

        # Report a hash via the read-write server
        taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
        outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
        unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'

        result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')

        # Check the hash via the read-only server
        self.assertClientGetHash(ro_client, taskhash, unihash)

        # Reporting via the read-only server does not raise: it echoes back
        # the client's proposed unihash without storing anything
        taskhash2 = 'c665584ee6817aa99edfc77a44dd853828279370'
        outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
        unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'

        result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
        self.assertEqual(result['unihash'], unihash2)

        # Ensure that the database was not modified
        self.assertClientGetHash(rw_client, taskhash2, None)
479
480
    def test_slow_server_start(self):
        # Ensures that the server will exit correctly even if it gets a SIGTERM
        # before entering the main loop

        event = multiprocessing.Event()

        def prefunc(server, idx):
            # Runs in the server process: block before the main loop until
            # the test sets the event
            nonlocal event
            server_prefunc(server, idx)
            event.wait()

        def do_nothing(signum, frame):
            pass

        # Ignore SIGTERM in this (test) process for the duration of the test;
        # the previous handler is restored on cleanup
        old_signal = signal.signal(signal.SIGTERM, do_nothing)
        self.addCleanup(signal.signal, signal.SIGTERM, old_signal)

        server = self.start_server(prefunc=prefunc)
        # Deliver SIGTERM while the server is still blocked in prefunc, wait
        # so it arrives before the main loop, then release the server
        server.process.terminate()
        time.sleep(30)
        event.set()
        server.process.join(300)
        self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
504
    def test_diverging_report_race(self):
        # Tests that a reported task will correctly pick up an updated unihash.
        # The order of the reports below is what is being tested.

        # This is a baseline report added to the database to ensure that there
        # is something to match against as equivalent
        outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
        taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
        unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
        result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)

        # Add a report that is equivalent to Task 1. It should ignore the
        # provided unihash and report the unihash from task 1
        taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
        unihash2 = taskhash2
        result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
        self.assertEqual(result['unihash'], unihash1)

        # Add another report for Task 2, but with a different outhash (e.g. the
        # task is non-deterministic). It should still be marked with the Task 1
        # unihash because it has the Task 2 taskhash, which is equivalent to
        # Task 1
        outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
        result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
        self.assertEqual(result['unihash'], unihash1)
529
530
    def test_diverging_report_reverse_race(self):
        # Same idea as the previous test, but Tasks 2 and 3 are reported in
        # the opposite order

        outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
        taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
        unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
        result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)

        taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
        unihash2 = taskhash2

        # Report Task 3 first. Since there is nothing else in the database
        # matching its outhash, it will use the client provided unihash
        outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
        result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
        self.assertEqual(result['unihash'], unihash2)

        # Report Task 2. This is equivalent to Task 1 but there is already a mapping for
        # taskhash2 so it will report unihash2
        result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
        self.assertEqual(result['unihash'], unihash2)

        # The originally reported unihash for Task 3 should be unchanged even if it
        # shares a taskhash with Task 2
        self.assertClientGetHash(self.client, taskhash2, unihash2)
557
558
    def test_client_pool_get_unihashes(self):
        # ClientPool.get_unihashes() resolves a batch of (method, taskhash)
        # queries keyed by arbitrary ids, returning None for unknown taskhashes
        TEST_INPUT = (
            # taskhash                                   outhash                                                            unihash
            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
            # Duplicated taskhash with multiple output hashes and unihashes.
            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
            # Equivalent hash
            ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
            ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
            ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
            ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
            ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
        )
        # Taskhashes never reported; expected to resolve to None
        EXTRA_QUERIES = (
            "6b6be7a84ab179b4240c4302518dc3f6",
        )

        with ClientPool(self.server_address, 10) as client_pool:
            for taskhash, outhash, unihash in TEST_INPUT:
                self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)

            # Query ids 0..len(TEST_INPUT)-1 map to TEST_INPUT rows; the extra
            # queries follow on with the next ids
            query = {idx: (self.METHOD, data[0]) for idx, data in enumerate(TEST_INPUT)}
            for idx, taskhash in enumerate(EXTRA_QUERIES):
                query[idx + len(TEST_INPUT)] = (self.METHOD, taskhash)

            result = client_pool.get_unihashes(query)

            # Rows 1 and 2 resolve to row 0's unihash (duplicate taskhash and
            # equivalent outhash); row 5 resolves to row 4's unihash (same outhash)
            self.assertDictEqual(result, {
                0: "218e57509998197d570e2c98512d0105985dffc9",
                1: "218e57509998197d570e2c98512d0105985dffc9",
                2: "218e57509998197d570e2c98512d0105985dffc9",
                3: "3b5d3d83f07f259e9086fcb422c855286e18a57d",
                4: "f46d3fbb439bd9b921095da657a4de906510d2cd",
                5: "f46d3fbb439bd9b921095da657a4de906510d2cd",
                6: "05d2a63c81e32f0a36542ca677e8ad852365c538",
                7: None,
            })
596
    def test_client_pool_unihash_exists(self):
        # ClientPool.unihashes_exist() reports, per query id, whether each
        # unihash is known to the server. Only unihashes the server actually
        # returned from report_unihash are expected to exist.
        TEST_INPUT = (
            # taskhash                                   outhash                                                            unihash
            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
            # Duplicated taskhash with multiple output hashes and unihashes.
            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
            # Equivalent hash
            ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
            ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
            ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
            ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
            ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
        )
        # Unihashes never reported; expected not to exist
        EXTRA_QUERIES = (
            "6b6be7a84ab179b4240c4302518dc3f6",
        )

        # Unihashes the server actually assigned (not the ones proposed)
        result_unihashes = set()


        with ClientPool(self.server_address, 10) as client_pool:
            for taskhash, outhash, unihash in TEST_INPUT:
                result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
                result_unihashes.add(result["unihash"])

            query = {}
            expected = {}

            for _, _, unihash in TEST_INPUT:
                idx = len(query)
                query[idx] = unihash
                expected[idx] = unihash in result_unihashes


            for unihash in EXTRA_QUERIES:
                idx = len(query)
                query[idx] = unihash
                expected[idx] = False

            result = client_pool.unihashes_exist(query)
            self.assertDictEqual(result, expected)
638
639
    def test_auth_read_perms(self):
        # Reading hashes through the authenticated server requires "@read"
        admin_client = self.start_auth_server()

        # Create hashes with non-authenticated server
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        # Validate hash can be retrieved using authenticated client
        with self.auth_perms("@read") as client:
            self.assertClientGetHash(client, taskhash, unihash)

        # A user with no permissions is refused with InvokeError
        with self.auth_perms() as client, self.assertRaises(InvokeError):
            self.assertClientGetHash(client, taskhash, unihash)
652
653    def test_auth_report_perms(self):
654        admin_client = self.start_auth_server()
655
656        # Without read permission, the user is completely denied
657        with self.auth_perms() as client, self.assertRaises(InvokeError):
658            self.create_test_hash(client)
659
660        # Read permission allows the call to succeed, but it doesn't record
661        # anythin in the database
662        with self.auth_perms("@read") as client:
663            taskhash, outhash, unihash = self.create_test_hash(client)
664            self.assertClientGetHash(client, taskhash, None)
665
666        # Report permission alone is insufficient
667        with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
668            self.create_test_hash(client)
669
670        # Read and report permission actually modify the database
671        with self.auth_perms("@read", "@report") as client:
672            taskhash, outhash, unihash = self.create_test_hash(client)
673            self.assertClientGetHash(client, taskhash, unihash)
674
675    def test_auth_no_token_refresh_from_anon_user(self):
676        self.start_auth_server()
677
678        with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
679            client.refresh_token()
680
681    def test_auth_self_token_refresh(self):
682        admin_client = self.start_auth_server()
683
684        # Create a new user with no permissions
685        user = self.create_user("test-user", [])
686
687        with self.auth_client(user) as client:
688            new_user = client.refresh_token()
689
690        self.assertEqual(user["username"], new_user["username"])
691        self.assertNotEqual(user["token"], new_user["token"])
692        self.assertUserCanAuth(new_user)
693        self.assertUserCannotAuth(user)
694
695        # Explicitly specifying with your own username is fine also
696        with self.auth_client(new_user) as client:
697            new_user2 = client.refresh_token(user["username"])
698
699        self.assertEqual(user["username"], new_user2["username"])
700        self.assertNotEqual(user["token"], new_user2["token"])
701        self.assertUserCanAuth(new_user2)
702        self.assertUserCannotAuth(new_user)
703        self.assertUserCannotAuth(user)
704
705    def test_auth_token_refresh(self):
706        admin_client = self.start_auth_server()
707
708        user = self.create_user("test-user", [])
709
710        with self.auth_perms() as client, self.assertRaises(InvokeError):
711            client.refresh_token(user["username"])
712
713        with self.auth_perms("@user-admin") as client:
714            new_user = client.refresh_token(user["username"])
715
716        self.assertEqual(user["username"], new_user["username"])
717        self.assertNotEqual(user["token"], new_user["token"])
718        self.assertUserCanAuth(new_user)
719        self.assertUserCannotAuth(user)
720
721    def test_auth_self_get_user(self):
722        admin_client = self.start_auth_server()
723
724        user = self.create_user("test-user", [])
725        user_info = user.copy()
726        del user_info["token"]
727
728        with self.auth_client(user) as client:
729            info = client.get_user()
730            self.assertEqual(info, user_info)
731
732            # Explicitly asking for your own username is fine also
733            info = client.get_user(user["username"])
734            self.assertEqual(info, user_info)
735
736    def test_auth_get_user(self):
737        admin_client = self.start_auth_server()
738
739        user = self.create_user("test-user", [])
740        user_info = user.copy()
741        del user_info["token"]
742
743        with self.auth_perms() as client, self.assertRaises(InvokeError):
744            client.get_user(user["username"])
745
746        with self.auth_perms("@user-admin") as client:
747            info = client.get_user(user["username"])
748            self.assertEqual(info, user_info)
749
750            info = client.get_user("nonexist-user")
751            self.assertIsNone(info)
752
753    def test_auth_reconnect(self):
754        admin_client = self.start_auth_server()
755
756        user = self.create_user("test-user", [])
757        user_info = user.copy()
758        del user_info["token"]
759
760        with self.auth_client(user) as client:
761            info = client.get_user()
762            self.assertEqual(info, user_info)
763
764            client.disconnect()
765
766            info = client.get_user()
767            self.assertEqual(info, user_info)
768
769    def test_auth_delete_user(self):
770        admin_client = self.start_auth_server()
771
772        user = self.create_user("test-user", [])
773
774        # self service
775        with self.auth_client(user) as client:
776            client.delete_user(user["username"])
777
778        self.assertIsNone(admin_client.get_user(user["username"]))
779        user = self.create_user("test-user", [])
780
781        with self.auth_perms() as client, self.assertRaises(InvokeError):
782            client.delete_user(user["username"])
783
784        with self.auth_perms("@user-admin") as client:
785            client.delete_user(user["username"])
786
787        # User doesn't exist, so even though the permission is correct, it's an
788        # error
789        with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
790            client.delete_user(user["username"])
791
792    def test_auth_set_user_perms(self):
793        admin_client = self.start_auth_server()
794
795        user = self.create_user("test-user", [])
796
797        self.assertUserPerms(user, [])
798
799        # No self service to change permissions
800        with self.auth_client(user) as client, self.assertRaises(InvokeError):
801            client.set_user_perms(user["username"], ["@all"])
802        self.assertUserPerms(user, [])
803
804        with self.auth_perms() as client, self.assertRaises(InvokeError):
805            client.set_user_perms(user["username"], ["@all"])
806        self.assertUserPerms(user, [])
807
808        with self.auth_perms("@user-admin") as client:
809            client.set_user_perms(user["username"], ["@all"])
810        self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
811
812        # Bad permissions
813        with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
814            client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
815        self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
816
817    def test_auth_get_all_users(self):
818        admin_client = self.start_auth_server()
819
820        user = self.create_user("test-user", [])
821
822        with self.auth_client(user) as client, self.assertRaises(InvokeError):
823            client.get_all_users()
824
825        # Give the test user the correct permission
826        admin_client.set_user_perms(user["username"], ["@user-admin"])
827
828        with self.auth_client(user) as client:
829            all_users = client.get_all_users()
830
831        # Convert to a dictionary for easier comparison
832        all_users = {u["username"]: u for u in all_users}
833
834        self.assertEqual(all_users,
835            {
836                "admin": {
837                    "username": "admin",
838                    "permissions": sorted(list(ALL_PERMISSIONS)),
839                },
840                "test-user": {
841                    "username": "test-user",
842                    "permissions": ["@user-admin"],
843                }
844            }
845        )
846
847    def test_auth_new_user(self):
848        self.start_auth_server()
849
850        permissions = ["@read", "@report", "@db-admin", "@user-admin"]
851        permissions.sort()
852
853        with self.auth_perms() as client, self.assertRaises(InvokeError):
854            self.create_user("test-user", permissions, client=client)
855
856        with self.auth_perms("@user-admin") as client:
857            user = self.create_user("test-user", permissions, client=client)
858            self.assertIn("token", user)
859            self.assertEqual(user["username"], "test-user")
860            self.assertEqual(user["permissions"], permissions)
861
862    def test_auth_become_user(self):
863        admin_client = self.start_auth_server()
864
865        user = self.create_user("test-user", ["@read", "@report"])
866        user_info = user.copy()
867        del user_info["token"]
868
869        with self.auth_perms() as client, self.assertRaises(InvokeError):
870            client.become_user(user["username"])
871
872        with self.auth_perms("@user-admin") as client:
873            become = client.become_user(user["username"])
874            self.assertEqual(become, user_info)
875
876            info = client.get_user()
877            self.assertEqual(info, user_info)
878
879            # Verify become user is preserved across disconnect
880            client.disconnect()
881
882            info = client.get_user()
883            self.assertEqual(info, user_info)
884
885            # test-user doesn't have become_user permissions, so this should
886            # not work
887            with self.assertRaises(InvokeError):
888                client.become_user(user["username"])
889
890        # No self-service of become
891        with self.auth_client(user) as client, self.assertRaises(InvokeError):
892            client.become_user(user["username"])
893
894        # Give test user permissions to become
895        admin_client.set_user_perms(user["username"], ["@user-admin"])
896
897        # It's possible to become yourself (effectively a noop)
898        with self.auth_perms("@user-admin") as client:
899            become = client.become_user(client.username)
900
901    def test_auth_gc(self):
902        admin_client = self.start_auth_server()
903
904        with self.auth_perms() as client, self.assertRaises(InvokeError):
905            client.gc_mark("ABC", {"unihash": "123"})
906
907        with self.auth_perms() as client, self.assertRaises(InvokeError):
908            client.gc_status()
909
910        with self.auth_perms() as client, self.assertRaises(InvokeError):
911            client.gc_sweep("ABC")
912
913        with self.auth_perms("@db-admin") as client:
914            client.gc_mark("ABC", {"unihash": "123"})
915
916        with self.auth_perms("@db-admin") as client:
917            client.gc_status()
918
919        with self.auth_perms("@db-admin") as client:
920            client.gc_sweep("ABC")
921
922    def test_get_db_usage(self):
923        usage = self.client.get_db_usage()
924
925        self.assertTrue(isinstance(usage, dict))
926        for name in usage.keys():
927            self.assertTrue(isinstance(usage[name], dict))
928            self.assertIn("rows", usage[name])
929            self.assertTrue(isinstance(usage[name]["rows"], int))
930
931    def test_get_db_query_columns(self):
932        columns = self.client.get_db_query_columns()
933
934        self.assertTrue(isinstance(columns, list))
935        self.assertTrue(len(columns) > 0)
936
937        for col in columns:
938            self.client.remove({col: ""})
939
940    def test_auth_is_owner(self):
941        admin_client = self.start_auth_server()
942
943        user = self.create_user("test-user", ["@read", "@report"])
944        with self.auth_client(user) as client:
945            taskhash, outhash, unihash = self.create_test_hash(client)
946            data = client.get_taskhash(self.METHOD, taskhash, True)
947            self.assertEqual(data["owner"], user["username"])
948
949    def test_gc(self):
950        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
951        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
952        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
953
954        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
955        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
956
957        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
958        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
959        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
960
961        result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
962        self.assertClientGetHash(self.client, taskhash2, unihash2)
963
964        # Mark the first unihash to be kept
965        ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
966        self.assertEqual(ret, {"count": 1})
967
968        ret = self.client.gc_status()
969        self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
970
971        # Second hash is still there; mark doesn't delete hashes
972        self.assertClientGetHash(self.client, taskhash2, unihash2)
973
974        ret = self.client.gc_sweep("ABC")
975        self.assertEqual(ret, {"count": 1})
976
977        # Hash is gone. Taskhash is returned for second hash
978        self.assertClientGetHash(self.client, taskhash2, None)
979        # First hash is still present
980        self.assertClientGetHash(self.client, taskhash, unihash)
981
982    def test_gc_switch_mark(self):
983        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
984        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
985        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
986
987        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
988        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
989
990        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
991        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
992        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
993
994        result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
995        self.assertClientGetHash(self.client, taskhash2, unihash2)
996
997        # Mark the first unihash to be kept
998        ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
999        self.assertEqual(ret, {"count": 1})
1000
1001        ret = self.client.gc_status()
1002        self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
1003
1004        # Second hash is still there; mark doesn't delete hashes
1005        self.assertClientGetHash(self.client, taskhash2, unihash2)
1006
1007        # Switch to a different mark and mark the second hash. This will start
1008        # a new collection cycle
1009        ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD})
1010        self.assertEqual(ret, {"count": 1})
1011
1012        ret = self.client.gc_status()
1013        self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1})
1014
1015        # Both hashes are still present
1016        self.assertClientGetHash(self.client, taskhash2, unihash2)
1017        self.assertClientGetHash(self.client, taskhash, unihash)
1018
1019        # Sweep with the new mark
1020        ret = self.client.gc_sweep("DEF")
1021        self.assertEqual(ret, {"count": 1})
1022
1023        # First hash is gone, second is kept
1024        self.assertClientGetHash(self.client, taskhash2, unihash2)
1025        self.assertClientGetHash(self.client, taskhash, None)
1026
1027    def test_gc_switch_sweep_mark(self):
1028        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1029        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1030        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1031
1032        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1033        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1034
1035        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1036        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1037        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1038
1039        result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1040        self.assertClientGetHash(self.client, taskhash2, unihash2)
1041
1042        # Mark the first unihash to be kept
1043        ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1044        self.assertEqual(ret, {"count": 1})
1045
1046        ret = self.client.gc_status()
1047        self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
1048
1049        # Sweeping with a different mark raises an error
1050        with self.assertRaises(InvokeError):
1051            self.client.gc_sweep("DEF")
1052
1053        # Both hashes are present
1054        self.assertClientGetHash(self.client, taskhash2, unihash2)
1055        self.assertClientGetHash(self.client, taskhash, unihash)
1056
1057    def test_gc_new_hashes(self):
1058        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1059        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1060        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1061
1062        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1063        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1064
1065        # Start a new garbage collection
1066        ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1067        self.assertEqual(ret, {"count": 1})
1068
1069        ret = self.client.gc_status()
1070        self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0})
1071
1072        # Add second hash. It should inherit the mark from the current garbage
1073        # collection operation
1074
1075        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1076        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1077        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1078
1079        result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1080        self.assertClientGetHash(self.client, taskhash2, unihash2)
1081
1082        # Sweep should remove nothing
1083        ret = self.client.gc_sweep("ABC")
1084        self.assertEqual(ret, {"count": 0})
1085
1086        # Both hashes are present
1087        self.assertClientGetHash(self.client, taskhash2, unihash2)
1088        self.assertClientGetHash(self.client, taskhash, unihash)
1089
1090
class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
    """Exercise the command line hash client (via ``run_hashclient``).

    Each test runs the client against a live server on a unix socket and checks
    either its stdout or the resulting server state through ``self.client``.
    """

    # Matches the "Token: <value>" line printed by refresh-token / new-user
    _TOKEN_RE = re.compile(r'Token: +(.*)$')

    def get_server_addr(self, server_idx):
        return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)

    def _parse_token(self, output):
        """Return the value of the last ``Token:`` line in *output*, or None."""
        token = None
        for line in output.splitlines():
            m = self._TOKEN_RE.match(line.rstrip())
            if m is not None:
                token = m.group(1)
        return token

    def test_get(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        # check=True so a failing command is reported as such rather than as
        # a confusing JSON decode error on empty output
        p = self.run_hashclient(["--address", self.server_address, "get", self.METHOD, taskhash], check=True)
        data = json.loads(p.stdout)
        self.assertEqual(data["unihash"], unihash)
        self.assertEqual(data["outhash"], outhash)
        self.assertEqual(data["taskhash"], taskhash)
        self.assertEqual(data["method"], self.METHOD)

    def test_get_outhash(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        p = self.run_hashclient(["--address", self.server_address, "get-outhash", self.METHOD, outhash, taskhash], check=True)
        data = json.loads(p.stdout)
        self.assertEqual(data["unihash"], unihash)
        self.assertEqual(data["outhash"], outhash)
        self.assertEqual(data["taskhash"], taskhash)
        self.assertEqual(data["method"], self.METHOD)

    def test_stats(self):
        p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
        json.loads(p.stdout)

    def test_stress(self):
        self.run_hashclient(["--address", self.server_address, "stress"], check=True)

    def test_unihash_exsits(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        p = self.run_hashclient([
            "--address", self.server_address,
            "unihash-exists", unihash,
        ], check=True)
        self.assertEqual(p.stdout.strip(), "true")

        p = self.run_hashclient([
            "--address", self.server_address,
            "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
        ], check=True)
        self.assertEqual(p.stdout.strip(), "false")

    def test_unihash_exsits_quiet(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        # In quiet mode the answer is conveyed only through the exit status
        p = self.run_hashclient([
            "--address", self.server_address,
            "unihash-exists", unihash,
            "--quiet",
        ])
        self.assertEqual(p.returncode, 0)
        self.assertEqual(p.stdout.strip(), "")

        p = self.run_hashclient([
            "--address", self.server_address,
            "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
            "--quiet",
        ])
        self.assertEqual(p.returncode, 1)
        self.assertEqual(p.stdout.strip(), "")

    def test_remove_taskhash(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)
        self.run_hashclient([
            "--address", self.server_address,
            "remove",
            "--where", "taskhash", taskhash,
        ], check=True)
        self.assertClientGetHash(self.client, taskhash, None)

        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
        self.assertIsNone(result_outhash)

    def test_remove_unihash(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)
        self.run_hashclient([
            "--address", self.server_address,
            "remove",
            "--where", "unihash", unihash,
        ], check=True)
        self.assertClientGetHash(self.client, taskhash, None)

    def test_remove_outhash(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)
        self.run_hashclient([
            "--address", self.server_address,
            "remove",
            "--where", "outhash", outhash,
        ], check=True)

        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
        self.assertIsNone(result_outhash)

    def test_remove_method(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)
        self.run_hashclient([
            "--address", self.server_address,
            "remove",
            "--where", "method", self.METHOD,
        ], check=True)
        self.assertClientGetHash(self.client, taskhash, None)

        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
        self.assertIsNone(result_outhash)

    def test_clean_unused(self):
        taskhash, outhash, unihash = self.create_test_hash(self.client)

        # Clean the database, which should not remove anything because all hashes are in use
        self.run_hashclient([
            "--address", self.server_address,
            "clean-unused", "0",
        ], check=True)
        self.assertClientGetHash(self.client, taskhash, unihash)

        # Remove the unihash. The row in the outhash table should still be present
        self.run_hashclient([
            "--address", self.server_address,
            "remove",
            "--where", "unihash", unihash,
        ], check=True)
        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
        self.assertIsNotNone(result_outhash)

        # Now clean with no minimum age which will remove the outhash
        self.run_hashclient([
            "--address", self.server_address,
            "clean-unused", "0",
        ], check=True)
        result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
        self.assertIsNone(result_outhash)

    def test_refresh_token(self):
        admin_client = self.start_auth_server()

        user = admin_client.new_user("test-user", ["@read", "@report"])

        p = self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", user["username"],
            "--password", user["token"],
            "refresh-token"
        ], check=True)

        new_token = self._parse_token(p.stdout)
        self.assertTrue(new_token)

        print("New token is %r" % new_token)

        # The refreshed token must be usable for logging in
        self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", user["username"],
            "--password", new_token,
            "get-user"
        ], check=True)

    def test_set_user_perms(self):
        admin_client = self.start_auth_server()

        user = admin_client.new_user("test-user", ["@read"])

        self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", admin_client.username,
            "--password", admin_client.password,
            "set-user-perms",
            "-u", user["username"],
            "@read", "@report",
        ], check=True)

        new_user = admin_client.get_user(user["username"])

        self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})

    def test_get_user(self):
        admin_client = self.start_auth_server()

        user = admin_client.new_user("test-user", ["@read"])

        # An admin looking up another user
        p = self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", admin_client.username,
            "--password", admin_client.password,
            "get-user",
            "-u", user["username"],
        ], check=True)

        self.assertIn("Username:", p.stdout)
        self.assertIn("Permissions:", p.stdout)

        # A user looking up themself
        p = self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", user["username"],
            "--password", user["token"],
            "get-user",
        ], check=True)

        self.assertIn("Username:", p.stdout)
        self.assertIn("Permissions:", p.stdout)

    def test_get_all_users(self):
        admin_client = self.start_auth_server()

        admin_client.new_user("test-user1", ["@read"])
        admin_client.new_user("test-user2", ["@read"])

        p = self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", admin_client.username,
            "--password", admin_client.password,
            "get-all-users",
        ], check=True)

        self.assertIn("admin", p.stdout)
        self.assertIn("test-user1", p.stdout)
        self.assertIn("test-user2", p.stdout)

    def test_new_user(self):
        admin_client = self.start_auth_server()

        p = self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", admin_client.username,
            "--password", admin_client.password,
            "new-user",
            "-u", "test-user",
            "@read", "@report",
        ], check=True)

        new_token = self._parse_token(p.stdout)
        self.assertTrue(new_token)

        user = {
            "username": "test-user",
            "token": new_token,
        }

        self.assertUserPerms(user, ["@read", "@report"])

    def test_delete_user(self):
        admin_client = self.start_auth_server()

        user = admin_client.new_user("test-user", ["@read"])

        self.run_hashclient([
            "--address", self.auth_server_address,
            "--login", admin_client.username,
            "--password", admin_client.password,
            "delete-user",
            "-u", user["username"],
        ], check=True)

        self.assertIsNone(admin_client.get_user(user["username"]))

    def test_get_db_usage(self):
        # Only checks that the command completes successfully
        self.run_hashclient([
            "--address", self.server_address,
            "get-db-usage",
        ], check=True)

    def test_get_db_query_columns(self):
        # Only checks that the command completes successfully
        self.run_hashclient([
            "--address", self.server_address,
            "get-db-query-columns",
        ], check=True)

    def test_gc(self):
        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'

        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')

        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
        outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'

        self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
        self.assertClientGetHash(self.client, taskhash2, unihash2)

        # Mark the first unihash to be kept
        self.run_hashclient([
            "--address", self.server_address,
            "gc-mark", "ABC",
            "--where", "unihash", unihash,
            "--where", "method", self.METHOD
        ], check=True)

        # Second hash is still there; mark doesn't delete hashes
        self.assertClientGetHash(self.client, taskhash2, unihash2)

        self.run_hashclient([
            "--address", self.server_address,
            "gc-sweep", "ABC",
        ], check=True)

        # The unmarked second hash no longer resolves
        self.assertClientGetHash(self.client, taskhash2, None)
        # First hash is still present
        self.assertClientGetHash(self.client, taskhash, unihash)
1408
1409
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
    """Run the common hash equivalence tests over a unix domain socket."""

    def get_server_addr(self, server_idx):
        # One socket file per server, placed in the test's temporary directory
        sock_path = os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
        return "unix://" + sock_path
1413
1414
class TestHashEquivalenceUnixServerLongPath(HashEquivalenceTestSetup, unittest.TestCase):
    """Check that a server reachable via a very long unix socket path works."""

    # Deeply nested directory; presumably sized so the full socket path exceeds
    # the usual AF_UNIX path length limit — TODO confirm intent
    DEEP_DIRECTORY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb/ccccccccccccccccccccccccccccccccccccccccccc"

    def get_server_addr(self, server_idx):
        sock_dir = os.path.join(self.temp_dir.name, self.DEEP_DIRECTORY)
        os.makedirs(sock_dir, exist_ok=True)
        return "unix://" + os.path.join(sock_dir, 'sock%d' % server_idx)

    def test_long_sock_path(self):
        # Smoke test: a hash can be reported through the long-path socket
        taskhash, outhash, unihash = (
            '35788efcb8dfb0a02659d81cf2bfd695fb30faf9',
            '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f',
            'f46d3fbb439bd9b921095da657a4de906510d2cd',
        )

        # Nothing is known about this taskhash yet
        self.assertClientGetHash(self.client, taskhash, None)

        reply = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
        self.assertEqual(reply['unihash'], unihash, 'Server returned bad unihash')
1432
1433
class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
    """Run the common hash equivalence tests against a plain TCP server."""

    def get_server_addr(self, server_idx):
        """Return ``<localhost IPv4 address>:0`` as the listen address.

        Port 0 presumably lets the OS choose a free port -- confirm in the
        server's address parsing.
        """
        # On some hosts asyncio misbehaves when IPv6 is disabled. Using
        # "localhost" directly is generally fine only when IPv6 is enabled,
        # so resolve it to an explicit address for reliability.
        loopback_ip = socket.gethostbyname("localhost")
        return f"{loopback_ip}:0"
1440
1441
class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
    """Run the common hash equivalence tests against a websocket server.

    The whole class is skipped when the ``websockets`` package is missing.
    """

    def setUp(self):
        # The import only probes for availability; skipTest raises, so the
        # base setUp never runs when the package is absent.
        try:
            import websockets  # noqa: F401
        except ImportError as err:
            self.skipTest(str(err))
        super().setUp()

    def get_server_addr(self, server_idx):
        """Return a ``ws://<localhost IPv4 address>:0`` address."""
        # Resolve "localhost" explicitly: some hosts make asyncio misbehave
        # when IPv6 is disabled, and the plain name is generally only safe
        # when IPv6 is enabled.
        return "ws://%s:0" % socket.gethostbyname("localhost")
1457
1458
class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
    """Websocket server tests using an SQLAlchemy (aiosqlite) database backend.

    Skipped when either ``sqlalchemy`` or ``aiosqlite`` cannot be imported.
    """

    def setUp(self):
        # Availability probe only; skipTest raises before the parent setUp
        # (which performs its own websockets check) is reached.
        try:
            import sqlalchemy  # noqa: F401
            import aiosqlite  # noqa: F401
        except ImportError as err:
            self.skipTest(str(err))
        super().setUp()

    def make_dbpath(self):
        """Return an SQLAlchemy URL for a per-server SQLite file in the temp dir."""
        db_file = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
        return "sqlite+aiosqlite:///%s" % db_file
1471
1472
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
    """Run the common tests against an already-running external hash server.

    The server address is taken from ``BB_TEST_HASHSERV``. Optional
    credentials come from ``BB_TEST_HASHSERV_USERNAME`` and
    ``BB_TEST_HASHSERV_PASSWORD``. Tests that would need to start a local
    server are skipped.
    """

    def get_env(self, name):
        """Return environment variable *name*, skipping the test if it is unset or empty."""
        v = os.environ.get(name)
        if not v:
            self.skipTest(f'{name} not defined to test an external server')
        return v

    def start_test_server(self):
        # No local server is launched; the external address is returned
        # instead (presumably consumed by the base class setUp -- confirm).
        return self.get_env('BB_TEST_HASHSERV')

    def start_server(self, *args, **kwargs):
        self.skipTest('Cannot start local server when testing external servers')

    def start_auth_server(self):
        """Treat the external server as the auth server and return an admin client.

        The admin client authenticates with the credentials from the
        environment; the test is skipped if either variable is missing.
        """
        self.auth_server_address = self.server_address
        self.admin_client = self.start_client(
            self.server_address,
            username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
            password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
        )
        return self.admin_client

    def setUp(self):
        super().setUp()
        if "BB_TEST_HASHSERV_USERNAME" in os.environ:
            # Replace the default client with an authenticated one. Read the
            # credentials through get_env (as start_auth_server does) so a
            # missing password skips with a clear message instead of raising
            # a bare KeyError.
            self.client = self.start_client(
                self.server_address,
                username=self.get_env("BB_TEST_HASHSERV_USERNAME"),
                password=self.get_env("BB_TEST_HASHSERV_PASSWORD"),
            )
        # The external server outlives individual tests, so clear any entries
        # left under this test's METHOD before running.
        self.client.remove({"method": self.METHOD})

    def tearDown(self):
        # Leave the shared external server without this test's entries.
        self.client.remove({"method": self.METHOD})
        super().tearDown()

    def test_auth_get_all_users(self):
        self.skipTest("Cannot test all users with external server")
1513
1514