xref: /openbmc/qemu/tests/functional/qemu_test/asset.py (revision 7ae004869aff46fc3195d280b25dc9b94a447be7)
1# Test utilities for fetching & caching assets
2#
3# Copyright 2024 Red Hat, Inc.
4#
5# This work is licensed under the terms of the GNU GPL, version 2 or
6# later.  See the COPYING file in the top-level directory.
7
8import hashlib
9import logging
10import os
11import stat
12import sys
13import time
14import unittest
15import urllib.request
16from time import sleep
17from pathlib import Path
18from shutil import copyfileobj
19from urllib.error import HTTPError, URLError
20
class AssetError(Exception):
    """Raised when an asset cannot be fetched or fails verification.

    :param asset: the asset the error refers to; only its ``url``
                  attribute is used
    :param msg: human-readable description of the failure
    :param transient: True if the failure is likely temporary (e.g. a
                      network problem); the pre-caching code logs and skips
                      transient errors instead of re-raising them
    """

    def __init__(self, asset, msg, transient=False):
        # Hand the message to Exception so that e.args and repr(e) carry it
        super().__init__(msg)
        self.url = asset.url
        self.msg = msg
        self.transient = transient

    def __str__(self):
        return "%s: %s" % (self.url, self.msg)
29
# Instances of this class must be declared as class-level variables
# whose names start with "ASSET_". This enables the pre-caching logic
# to easily find all referenced assets and download them prior to
# execution of the tests.
class Asset:
    """A downloadable test asset that is cached locally and verified by hash.

    The cache directory is ``$QEMU_TEST_CACHE_DIR/download`` when that
    environment variable is set, otherwise ``~/.cache/qemu/download``.
    The cached file is named after the asset's hash sum.
    """

    def __init__(self, url, hashsum):
        self.url = url
        self.hash = hashsum
        cache_dir_env = os.getenv('QEMU_TEST_CACHE_DIR')
        if cache_dir_env:
            self.cache_dir = Path(cache_dir_env, "download")
        else:
            self.cache_dir = Path(Path("~").expanduser(),
                                  ".cache", "qemu", "download")
        self.cache_file = Path(self.cache_dir, hashsum)
        self.log = logging.getLogger('qemu-test')

    def __repr__(self):
        return "Asset: url=%s hash=%s cache=%s" % (
            self.url, self.hash, self.cache_file)

    def __str__(self):
        return str(self.cache_file)

    def _check(self, cache_file):
        """Return True if the digest of cache_file matches self.hash.

        The algorithm is chosen from the length of the hex digest: 64
        characters select SHA-256, 128 select SHA-512. Always returns True
        when no hash is set.

        :raises AssetError: if the hash length matches neither algorithm
        """
        if self.hash is None:
            return True
        if len(self.hash) == 64:
            hl = hashlib.sha256()
        elif len(self.hash) == 128:
            hl = hashlib.sha512()
        else:
            raise AssetError(self, "unknown hash type")

        # Hash in 1 MiB chunks so large images are not read into memory
        with open(cache_file, 'rb') as file:
            while True:
                chunk = file.read(1 << 20)
                if not chunk:
                    break
                hl.update(chunk)

        return self.hash == hl.hexdigest()

    def valid(self):
        """Return True if the cached copy exists and matches the hash.

        Setting QEMU_TEST_REFRESH_CACHE (to any value) makes this return
        False, forcing a fresh download.
        """
        if os.getenv("QEMU_TEST_REFRESH_CACHE", None) is not None:
            self.log.info("Force refresh of asset %s", self.url)
            return False

        return self.cache_file.exists() and self._check(self.cache_file)

    def fetchable(self):
        """Return True unless QEMU_TEST_NO_DOWNLOAD is set to a non-empty
        value."""
        return not os.environ.get("QEMU_TEST_NO_DOWNLOAD", False)

    def available(self):
        """Return True if the asset is cached or can be downloaded."""
        return self.valid() or self.fetchable()

    def _wait_for_other_download(self, tmp_cache_file):
        """Wait for a concurrent download of this asset to finish.

        Another thread or process already created tmp_cache_file and is
        presumably writing to it. Its size is polled once per second.

        :returns: True if the asset is present in the cache; False if the
                  other download looks stale (size unchanged for more than
                  90 polls) or its temporary file vanished without the
                  asset appearing. In both False cases the caller should
                  retry the download itself.
        :raises TimeoutError: if the other download is still growing after
                              600 polls
        """
        try:
            current_size = tmp_cache_file.stat().st_size
        except FileNotFoundError:
            # The temporary file was either renamed into place (success)
            # or deleted because the other download failed.
            return self.cache_file.exists()

        waittime = lastchange = 600
        while waittime > 0:
            sleep(1)
            waittime -= 1
            try:
                new_size = tmp_cache_file.stat().st_size
            except FileNotFoundError:
                return self.cache_file.exists()
            if new_size != current_size:
                lastchange = waittime
                current_size = new_size
            elif lastchange - waittime > 90:
                return False

        self.log.debug("Time out while waiting for %s!", tmp_cache_file)
        raise TimeoutError(f"Time out while waiting for {tmp_cache_file}")

    def _save_time_stamp(self):
        '''
        Update the time stamp of the asset in the cache. Unfortunately, we
        cannot use the modification or access time of the asset file itself,
        since e.g. the functional jobs in the gitlab CI reload the files
        from the gitlab cache and thus always have recent file time stamps,
        so we have to save our asset time stamp to a separate file instead.
        '''
        self.cache_file.with_suffix(".stamp").write_text(f"{int(time.time())}")

    @staticmethod
    def _unlink_if_exists(path):
        """Delete path, ignoring the case where it does not exist."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def fetch(self):
        """Return the path of the cached asset, downloading it if needed.

        A valid cached copy is returned as-is, after its time stamp has
        been refreshed. Otherwise the asset is downloaded to a temporary
        "<hash>.download" file with up to three attempts. The file is
        verified against the hash, moved into place and made read-only.

        :returns: the path of the cached asset as a string
        :raises AssetError: if the asset cannot be downloaded or its hash
                            does not match; ``transient`` is set for
                            failures that are likely temporary
        :raises TimeoutError: if a concurrent download of the same asset
                              does not finish in time
        """
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if self.valid():
            self.log.debug("Using cached asset %s for %s",
                           self.cache_file, self.url)
            self._save_time_stamp()
            return str(self.cache_file)

        if not self.fetchable():
            raise AssetError(self,
                             "Asset cache is invalid and downloads disabled")

        self.log.info("Downloading %s to %s...", self.url, self.cache_file)
        tmp_cache_file = self.cache_file.with_suffix(".download")

        for _retries in range(3):
            try:
                # Exclusive creation: raises FileExistsError if someone else
                # is already downloading this asset.
                with tmp_cache_file.open("xb") as dst:
                    with urllib.request.urlopen(self.url) as resp:
                        copyfileobj(resp, dst)
                        # resp.headers is available for HTTP responses and
                        # for file: URL responses alike (getheader() is not)
                        length_hdr = resp.headers.get("Content-Length")

                # Verify downloaded file size against length metadata, if
                # available.
                if length_hdr is not None:
                    length = int(length_hdr)
                    fsize = tmp_cache_file.stat().st_size
                    if fsize != length:
                        self.log.error("Unable to download %s: "
                                       "connection closed before "
                                       "transfer complete (%d/%d)",
                                       self.url, fsize, length)
                        self._unlink_if_exists(tmp_cache_file)
                        continue
                break
            except FileExistsError:
                self.log.debug("%s already exists, "
                               "waiting for other thread to finish...",
                               tmp_cache_file)
                if self._wait_for_other_download(tmp_cache_file):
                    return str(self.cache_file)
                self.log.debug("%s seems to be stale, "
                               "deleting and retrying download...",
                               tmp_cache_file)
                self._unlink_if_exists(tmp_cache_file)
                continue
            except HTTPError as e:
                self._unlink_if_exists(tmp_cache_file)
                self.log.error("Unable to download %s: HTTP error %d",
                               self.url, e.code)
                # Treat 404 as fatal, since it is highly likely to
                # indicate a broken test rather than a transient
                # server or networking problem
                if e.code == 404:
                    raise AssetError(self, "Unable to download: "
                                     "HTTP error %d" % e.code) from e
                continue
            except URLError as e:
                # This is typically a network/service level error
                # eg urlopen error [Errno 110] Connection timed out>
                self._unlink_if_exists(tmp_cache_file)
                self.log.error("Unable to download %s: URL error %s",
                               self.url, e.reason)
                raise AssetError(self, "Unable to download: URL error %s" %
                                 e.reason, transient=True) from e
            except ConnectionError as e:
                # A socket connection failure, such as dropped conn
                # or refused conn
                self._unlink_if_exists(tmp_cache_file)
                self.log.error("Unable to download %s: Connection error %s",
                               self.url, e)
                continue
            except Exception as e:
                # This may come from open() itself, in which case the
                # temporary file was never created
                self._unlink_if_exists(tmp_cache_file)
                raise AssetError(self, "Unable to download: %s" % e,
                                 transient=True) from e
        else:
            # No attempt succeeded; each failed attempt removed its own file
            raise AssetError(self, "Download retries exceeded", transient=True)

        try:
            # Set these just for informational purposes
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-url",
                        self.url.encode('utf8'))
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-hash",
                        self.hash.encode('utf8'))
        except Exception as e:
            self.log.debug("Unable to set xattr on %s: %s", tmp_cache_file, e)

        try:
            matches = self._check(tmp_cache_file)
        except Exception:
            # Do not leave a .download file behind; it would make the next
            # run wait for a download that is never going to finish.
            self._unlink_if_exists(tmp_cache_file)
            raise
        if not matches:
            tmp_cache_file.unlink()
            raise AssetError(self, "Hash does not match %s" % self.hash)
        tmp_cache_file.replace(self.cache_file)
        self._save_time_stamp()
        # Remove write perms to stop tests accidentally modifying them
        os.chmod(self.cache_file, stat.S_IRUSR | stat.S_IRGRP)

        self.log.info("Cached %s at %s", self.url, self.cache_file)
        return str(self.cache_file)

    @staticmethod
    def precache_test(test):
        """Fetch every Asset stored in an ASSET_* attribute of test's class.

        Only attributes defined directly on the class are inspected
        (vars() does not include inherited attributes). Transient
        AssetErrors are logged and skipped; other AssetErrors propagate.
        While this runs, 'qemu-test' log messages are also printed to stdout.
        """
        log = logging.getLogger('qemu-test')
        log.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        log.addHandler(handler)
        try:
            for name, asset in vars(test.__class__).items():
                if name.startswith("ASSET_") and isinstance(asset, Asset):
                    try:
                        asset.fetch()
                    except AssetError as e:
                        if not e.transient:
                            raise
                        log.error("%s: skipping asset precache", e)
        finally:
            # Detach the handler even if a fetch fails, so that repeated
            # calls do not accumulate duplicate stdout handlers.
            log.removeHandler(handler)

    @staticmethod
    def precache_suite(suite):
        """Recursively pre-cache the assets of every test case in suite."""
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                Asset.precache_suite(test)
            elif isinstance(test, unittest.TestCase):
                Asset.precache_test(test)

    @staticmethod
    def precache_suites(path, cache_tstamp):
        """Load the tests named by path and pre-cache all of their assets.

        The file cache_tstamp is created (or truncated) before fetching
        starts. NOTE(review): presumably used as a build-system stamp
        output; confirm against the caller.
        """
        loader = unittest.loader.defaultTestLoader
        tests = loader.loadTestsFromNames([path], None)

        with open(cache_tstamp, "w", encoding='utf-8'):
            Asset.precache_suite(tests)
265