1# Test utilities for fetching & caching assets
2#
3# Copyright 2024 Red Hat, Inc.
4#
5# This work is licensed under the terms of the GNU GPL, version 2 or
6# later.  See the COPYING file in the top-level directory.
7
8import hashlib
9import logging
10import os
11import subprocess
12import sys
13import unittest
14import urllib.request
15from time import sleep
16from pathlib import Path
17from shutil import copyfileobj
18
19
20# Instances of this class must be declared as class level variables
21# starting with a name "ASSET_". This enables the pre-caching logic
22# to easily find all referenced assets and download them prior to
23# execution of the tests.
class Asset:
    """A remote file that is downloaded once and cached under its hash sum.

    The cache lives in "$QEMU_TEST_CACHE_DIR/download" when that environment
    variable is set, otherwise in "~/.cache/qemu/download".  The cached file
    is named after the expected hash sum.
    """

    def __init__(self, url, hashsum):
        """Record the asset location and derive its cache file path.

        url:     location the asset is downloaded from
        hashsum: expected hex digest of the asset; its length selects the
                 algorithm (64 chars = SHA-256, 128 chars = SHA-512)
        """
        self.url = url
        self.hash = hashsum
        cache_dir_env = os.getenv('QEMU_TEST_CACHE_DIR')
        if cache_dir_env:
            self.cache_dir = Path(cache_dir_env, "download")
        else:
            self.cache_dir = Path(Path("~").expanduser(),
                                  ".cache", "qemu", "download")
        self.cache_file = Path(self.cache_dir, hashsum)
        self.log = logging.getLogger('qemu-test')

    def __repr__(self):
        return "Asset: url=%s hash=%s cache=%s" % (
            self.url, self.hash, self.cache_file)

    def _check(self, cache_file):
        """Return True if the content of cache_file matches self.hash.

        A hash of None disables the verification (always True).  Raises
        Exception if the length of the hash matches no known algorithm.
        """
        if self.hash is None:
            return True
        if len(self.hash) == 64:
            hl = hashlib.sha256()
        elif len(self.hash) == 128:
            hl = hashlib.sha512()
        else:
            raise Exception("unknown hash type")

        # Calculate the hash of the file in 1 MiB chunks, so that large
        # assets do not have to be loaded into memory completely
        with open(cache_file, 'rb') as file:
            while True:
                chunk = file.read(1 << 20)
                if not chunk:
                    break
                hl.update(chunk)

        # The digest has to be compared with the expected value; returning
        # the (always truthy) digest string would accept any file content
        return self.hash == hl.hexdigest()

    def valid(self):
        """Return True if the cache file exists and passes the hash check."""
        return self.cache_file.exists() and self._check(self.cache_file)

    @staticmethod
    def _remove_quietly(path):
        """Delete path, ignoring the case that it is already gone."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def _wait_for_other_download(self, tmp_cache_file):
        """Wait for somebody else who is writing tmp_cache_file.

        Returns True once the temporary file cannot be examined anymore and
        the final cache file exists, or False if the size of the temporary
        file did not change for more than 90 seconds (stale download).

        Raises the error from stat() if the temporary file is gone without
        the cache file being there, and TimeoutError if the other download
        is still making progress after 600 seconds.
        """
        # Another thread already seems to download the asset, so wait until
        # it is done, while also checking the size to see whether it is stuck
        try:
            current_size = tmp_cache_file.stat().st_size
        except OSError:
            if os.path.exists(self.cache_file):
                return True
            raise
        waittime = lastchange = 600
        while waittime > 0:
            sleep(1)
            waittime -= 1
            try:
                new_size = tmp_cache_file.stat().st_size
            except OSError:
                if os.path.exists(self.cache_file):
                    return True
                raise
            if new_size != current_size:
                # Still growing: remember when we saw the last progress
                lastchange = waittime
                current_size = new_size
            elif lastchange - waittime > 90:
                return False

        self.log.debug("Time out while waiting for %s!", tmp_cache_file)
        # Raise an explicit error; a bare "raise" here would depend on the
        # caller currently handling an exception
        raise TimeoutError("Time out while waiting for %s!" % tmp_cache_file)

    def fetch(self):
        """Make sure the asset is in the cache and return its path as str.

        A valid cached copy is used directly.  Otherwise the asset is
        downloaded into a "<hash>.download" temporary file, verified and
        then moved to its final name.

        Raises Exception if the cache is invalid while the environment
        variable QEMU_TEST_NO_DOWNLOAD is set, if the download could not be
        started after three attempts, or if the hash of the downloaded file
        does not match.  Download errors are logged and re-raised.
        """
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if self.valid():
            self.log.debug("Using cached asset %s for %s",
                           self.cache_file, self.url)
            return str(self.cache_file)

        if os.environ.get("QEMU_TEST_NO_DOWNLOAD", False):
            raise Exception("Asset cache is invalid and downloads disabled")

        self.log.info("Downloading %s to %s...", self.url, self.cache_file)
        tmp_cache_file = self.cache_file.with_suffix(".download")

        for _ in range(3):
            try:
                # Mode "x" fails if the file exists already, so the
                # temporary file also tells others that a download is running
                with tmp_cache_file.open("xb") as dst:
                    with urllib.request.urlopen(self.url) as resp:
                        copyfileobj(resp, dst)
                break
            except FileExistsError:
                self.log.debug("%s already exists, "
                               "waiting for other thread to finish...",
                               tmp_cache_file)
                if self._wait_for_other_download(tmp_cache_file):
                    return str(self.cache_file)
                self.log.debug("%s seems to be stale, "
                               "deleting and retrying download...",
                               tmp_cache_file)
                self._remove_quietly(tmp_cache_file)
                continue
            except Exception as e:
                self.log.error("Unable to download %s: %s", self.url, e)
                # Must not fail itself, otherwise the real error gets masked
                self._remove_quietly(tmp_cache_file)
                raise
        else:
            # All attempts ended with a stale temporary file, so there is
            # nothing that could be verified below
            raise Exception("Retries exceeded downloading %s" % self.url)

        try:
            # Set these just for informational purposes
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-url",
                        self.url.encode('utf8'))
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-hash",
                        self.hash.encode('utf8'))
        except Exception as e:
            # Best effort only, e.g. not every file system supports xattrs
            self.log.debug("Unable to set xattr on %s: %s", tmp_cache_file, e)

        if not self._check(tmp_cache_file):
            self._remove_quietly(tmp_cache_file)
            raise Exception("Hash of %s does not match %s" %
                            (self.url, self.hash))
        tmp_cache_file.replace(self.cache_file)

        self.log.info("Cached %s at %s", self.url, self.cache_file)
        return str(self.cache_file)

    @staticmethod
    def precache_test(test):
        """Fetch all Asset objects declared as ASSET_* on the class of test.

        While fetching, the 'qemu-test' logger is switched to DEBUG level
        and its messages are additionally printed to stdout.
        """
        log = logging.getLogger('qemu-test')
        log.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        log.addHandler(handler)
        try:
            for name, asset in vars(test.__class__).items():
                if name.startswith("ASSET_") and isinstance(asset, Asset):
                    log.info("Attempting to cache '%s'", asset)
                    asset.fetch()
        finally:
            # Also remove the handler when a download failed, otherwise
            # each further call would print every message once more
            log.removeHandler(handler)

    @staticmethod
    def precache_suite(suite):
        """Precache the assets of all test cases in suite, recursively."""
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                Asset.precache_suite(test)
            elif isinstance(test, unittest.TestCase):
                Asset.precache_test(test)

    @staticmethod
    def precache_suites(path, cacheTstamp):
        """Load the tests named by path and precache all their assets.

        path:        test name as accepted by unittest's loadTestsFromNames()
        cacheTstamp: file that gets created (or truncated) as a time stamp
        """
        loader = unittest.loader.defaultTestLoader
        tests = loader.loadTestsFromNames([path], None)

        # NOTE(review): the stamp file is created before the downloads
        # start, so it also exists if precaching fails - confirm that this
        # is what the build system expects
        with open(cacheTstamp, "w"):
            Asset.precache_suite(tests)
178