1# Test utilities for fetching & caching assets
2#
3# Copyright 2024 Red Hat, Inc.
4#
5# This work is licensed under the terms of the GNU GPL, version 2 or
6# later.  See the COPYING file in the top-level directory.
7
8import hashlib
9import logging
10import os
11import stat
12import subprocess
13import sys
14import unittest
15import urllib.request
16from time import sleep
17from pathlib import Path
18from shutil import copyfileobj
19
20
# Instances of this class must be declared as class-level variables
# whose names start with "ASSET_". This enables the pre-caching logic
# to easily find all referenced assets and download them prior to
# execution of the tests.
class Asset:
    """A file identified by URL and hash, downloaded once and cached.

    The cached copy lives in ``<cache_dir>/<hashsum>``, where the cache
    directory is ``$QEMU_TEST_CACHE_DIR/download`` when that environment
    variable is set and ``~/.cache/qemu/download`` otherwise.
    """

    def __init__(self, url, hashsum):
        """Record the asset location and compute its cache path.

        :param url: location passed to ``urllib.request.urlopen``
        :param hashsum: expected hex digest; its length selects the
                        algorithm (64 chars: SHA-256, 128: SHA-512)
        """
        self.url = url
        self.hash = hashsum
        cache_dir_env = os.getenv('QEMU_TEST_CACHE_DIR')
        if cache_dir_env:
            self.cache_dir = Path(cache_dir_env, "download")
        else:
            self.cache_dir = Path(Path("~").expanduser(),
                                  ".cache", "qemu", "download")
        self.cache_file = Path(self.cache_dir, hashsum)
        self.log = logging.getLogger('qemu-test')

    def __repr__(self):
        return "Asset: url=%s hash=%s cache=%s" % (
            self.url, self.hash, self.cache_file)

    @staticmethod
    def _unlink_if_present(path):
        """Delete *path*, ignoring the case where it is already gone."""
        try:
            path.unlink()
        except FileNotFoundError:
            # Somebody else removed or renamed it first; nothing to do.
            pass

    def _check(self, cache_file):
        """Return True if the content of *cache_file* matches self.hash.

        A hash of None always passes.

        :raises Exception: if the hash length matches neither SHA-256
                           nor SHA-512
        """
        if self.hash is None:
            return True
        if len(self.hash) == 64:
            hl = hashlib.sha256()
        elif len(self.hash) == 128:
            hl = hashlib.sha512()
        else:
            raise Exception("unknown hash type")

        # Hash the file in 1 MiB chunks so that large images do not have
        # to be held in memory all at once
        with open(cache_file, 'rb') as file:
            for chunk in iter(lambda: file.read(1 << 20), b''):
                hl.update(chunk)

        return self.hash == hl.hexdigest()

    def valid(self):
        """Return True if the cache file exists and passes the hash check."""
        return self.cache_file.exists() and self._check(self.cache_file)

    def _wait_for_other_download(self, tmp_cache_file):
        """Wait for a download that somebody else already started.

        Polls the size of *tmp_cache_file* once per second, for at most
        600 seconds.

        :returns: True if the temporary file disappeared and the final
                  cache file exists; False if the size did not change
                  for more than 90 seconds (the download seems stuck)
        :raises OSError: if the temporary file cannot be examined and
                         the final cache file does not exist either
        :raises TimeoutError: if the temporary file kept changing for
                              the whole waiting period
        """
        # Only catch OSError here: a bare "except" would also intercept
        # KeyboardInterrupt and could silently turn it into "return True"
        try:
            current_size = tmp_cache_file.stat().st_size
        except OSError:
            if self.cache_file.exists():
                return True
            raise
        waittime = lastchange = 600
        while waittime > 0:
            sleep(1)
            waittime -= 1
            try:
                new_size = tmp_cache_file.stat().st_size
            except OSError:
                if self.cache_file.exists():
                    return True
                raise
            if new_size != current_size:
                # Still making progress, remember when we saw it
                lastchange = waittime
                current_size = new_size
            elif lastchange - waittime > 90:
                return False

        self.log.debug("Time out while waiting for %s!", tmp_cache_file)
        # Raise an explicit error: a bare "raise" here has no exception
        # of its own to re-raise
        raise TimeoutError("Timed out while waiting for %s" % tmp_cache_file)

    def fetch(self):
        """Make sure the asset is in the cache and return its path.

        A valid cached copy is used as-is. Otherwise the URL is
        downloaded into a ``.download`` temporary file, verified against
        the hash, renamed to the final cache file name and made
        read-only.

        :returns: path of the cached file, as a string
        :raises Exception: if the cache is invalid and the environment
                           variable QEMU_TEST_NO_DOWNLOAD is set, if the
                           download fails or runs out of retries, or if
                           the downloaded data does not match the hash
        """
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if self.valid():
            self.log.debug("Using cached asset %s for %s",
                           self.cache_file, self.url)
            return str(self.cache_file)

        if os.environ.get("QEMU_TEST_NO_DOWNLOAD", False):
            raise Exception("Asset cache is invalid and downloads disabled")

        self.log.info("Downloading %s to %s...", self.url, self.cache_file)
        tmp_cache_file = self.cache_file.with_suffix(".download")

        for _ in range(3):
            try:
                # Mode "x" fails if the file already exists, so only one
                # downloader at a time gets to create the temporary file
                with tmp_cache_file.open("xb") as dst:
                    with urllib.request.urlopen(self.url) as resp:
                        copyfileobj(resp, dst)
                break
            except FileExistsError:
                self.log.debug("%s already exists, "
                               "waiting for other thread to finish...",
                               tmp_cache_file)
                if self._wait_for_other_download(tmp_cache_file):
                    return str(self.cache_file)
                self.log.debug("%s seems to be stale, "
                               "deleting and retrying download...",
                               tmp_cache_file)
                # The file may vanish between the check and the removal
                self._unlink_if_present(tmp_cache_file)
            except Exception as e:
                self.log.error("Unable to download %s: %s", self.url, e)
                # Do not let a missing temporary file mask the real error
                self._unlink_if_present(tmp_cache_file)
                raise
        else:
            # Every attempt ended with a stale temporary file, so there
            # is nothing downloaded that could be verified below
            raise Exception("Retries exceeded while downloading %s" %
                            self.url)

        try:
            # Set these just for informational purposes
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-url",
                        self.url.encode('utf8'))
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-hash",
                        self.hash.encode('utf8'))
        except Exception as e:
            # Best effort only, a failure here must not fail the fetch
            self.log.debug("Unable to set xattr on %s: %s", tmp_cache_file, e)

        if not self._check(tmp_cache_file):
            tmp_cache_file.unlink()
            raise Exception("Hash of %s does not match %s" %
                            (self.url, self.hash))
        tmp_cache_file.replace(self.cache_file)
        # Remove write perms to stop tests accidentally modifying them
        os.chmod(self.cache_file, stat.S_IRUSR | stat.S_IRGRP)

        self.log.info("Cached %s at %s", self.url, self.cache_file)
        return str(self.cache_file)

    @staticmethod
    def precache_test(test):
        """Fetch every ASSET_* class attribute of the class of *test*.

        While fetching, 'qemu-test' log messages are echoed to stdout.
        Only attributes set directly on the class of *test* are
        considered (not inherited ones).
        """
        log = logging.getLogger('qemu-test')
        log.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        log.addHandler(handler)
        try:
            for name, asset in vars(test.__class__).items():
                if name.startswith("ASSET_") and isinstance(asset, Asset):
                    log.info("Attempting to cache '%s'" % asset)
                    asset.fetch()
        finally:
            # Detach the handler even if a fetch failed, otherwise each
            # failing call would leave one more stdout handler behind
            log.removeHandler(handler)

    @staticmethod
    def precache_suite(suite):
        """Recursively pre-cache the assets of all tests in *suite*."""
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                Asset.precache_suite(test)
            elif isinstance(test, unittest.TestCase):
                Asset.precache_test(test)

    @staticmethod
    def precache_suites(path, cacheTstamp):
        """Load the tests named by *path* and pre-cache their assets.

        :param path: test name handed to the unittest loader
        :param cacheTstamp: file that is created (or truncated) before
                            pre-caching starts; nothing is written to it
        """
        loader = unittest.loader.defaultTestLoader
        tests = loader.loadTestsFromNames([path], None)

        # NOTE(review): the stamp file is created before the assets are
        # fetched, so it also exists when pre-caching fails - confirm
        # that this is what the build system expects.
        with open(cacheTstamp, "w"):
            Asset.precache_suite(tests)
181