1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
import os
from pathlib import Path
import datetime
import time
import difflib
import subprocess
import contextlib
import json
import termios
import shutil
import sys
import multiprocessing
from multiprocessing import Pool
from contextlib import contextmanager
from typing import List, Optional, Iterator, Any, Sequence, Dict, \
        ContextManager
34
35from testenv import TestEnv
36
37
def silent_unlink(path: Path) -> None:
    """Remove @path, ignoring any OSError (e.g. if it does not exist)."""
    with contextlib.suppress(OSError):
        path.unlink()
43
44
def file_diff(file1: str, file2: str) -> List[str]:
    """Return unified-diff lines turning @file1 into @file2.

    Trailing whitespace is stripped from every input line and from every
    produced diff line, so an empty list means the files match up to
    trailing whitespace.
    """
    # Many iotests emit stray spaces at line ends, so compare loosely.
    # TODO: fix all tests to not produce extra spaces, fix all .out files
    # and use strict diff here!
    with open(file1, encoding="utf-8") as ref_fp, \
         open(file2, encoding="utf-8") as out_fp:
        ref_lines = [text.rstrip() for text in ref_fp]
        out_lines = [text.rstrip() for text in out_fp]

    hunks = difflib.unified_diff(ref_lines, out_lines, file1, file2)
    return [text.rstrip() for text in hunks]
57
58
# Preserve the terminal settings across the test run, since an aborting
# qemu call may leave the tty in a broken state.
@contextmanager
def savetty() -> Iterator[None]:
    """Restore stdin's termios attributes when the with-block is left.

    Does nothing if stdin is not a terminal.
    """
    saved = None
    if sys.stdin.isatty():
        stdin_fd = sys.stdin.fileno()
        saved = (stdin_fd, termios.tcgetattr(stdin_fd))

    try:
        yield
    finally:
        if saved is not None:
            termios.tcsetattr(saved[0], termios.TCSADRAIN, saved[1])
73
74
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """Cache of per-test elapsed times, shown again during the next run.

    Times are kept in a JSON file as nested mappings:
    test name -> image protocol -> image format -> elapsed seconds.

    get() may be called at any time.  Changes made with update() only reach
    the file through save(), which is called automatically when leaving a
    with-block.
    """
    def __init__(self, cache_file: str, env: TestEnv) -> None:
        self.env = env
        self.cache_file = cache_file
        # An unreadable file or invalid JSON simply yields an empty cache.
        self.cache: Dict[str, Dict[str, Dict[str, float]]] = {}
        with contextlib.suppress(OSError, ValueError):
            with open(cache_file, encoding="utf-8") as cache_fp:
                self.cache = json.load(cache_fp)

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """Return the recorded time of @test for the current protocol and
        format, or @default if there is none."""
        if test in self.cache and self.env.imgproto in self.cache[test]:
            return self.cache[test][self.env.imgproto].get(self.env.imgfmt,
                                                           default)
        return default

    def update(self, test: str, elapsed: float) -> None:
        """Record @elapsed for @test under the current protocol and format
        (in memory only; see save())."""
        per_format = self.cache.setdefault(test, {}).setdefault(
            self.env.imgproto, {})
        per_format[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """Write the whole cache back to the JSON file."""
        with open(self.cache_file, 'w', encoding="utf-8") as cache_fp:
            json.dump(self.cache, cache_fp)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Saved unconditionally, also when the with-block raised.
        self.save()
116
117
class TestResult:
    """Outcome of running a single test.

    status: 'pass', 'fail' or 'not run'
    description: short human-readable explanation of the status
    elapsed: run time in seconds, or None if it was not measured
    diff: unified-diff lines against the reference output
    casenotrun: contents of the test's .casenotrun report, if any
    interrupted: True if the run was aborted by the user
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        # Verdict
        self.status = status
        self.description = description
        # Timing
        self.elapsed = elapsed
        # Details for failed or partially run tests
        self.diff = diff
        self.casenotrun = casenotrun
        # Set when the user stopped the run (KeyboardInterrupt)
        self.interrupted = interrupted
128
129
130class TestRunner(ContextManager['TestRunner']):
131    shared_self = None
132
133    @staticmethod
134    def proc_run_test(test: str, test_field_width: int) -> TestResult:
135        # We are in a subprocess, we can't change the runner object!
136        runner = TestRunner.shared_self
137        assert runner is not None
138        return runner.run_test(test, test_field_width, mp=True)
139
140    def run_tests_pool(self, tests: List[str],
141                       test_field_width: int, jobs: int) -> List[TestResult]:
142
143        # passing self directly to Pool.starmap() just doesn't work, because
144        # it's a context manager.
145        assert TestRunner.shared_self is None
146        TestRunner.shared_self = self
147
148        with Pool(jobs) as p:
149            results = p.starmap(self.proc_run_test,
150                                zip(tests, [test_field_width] * len(tests)))
151
152        TestRunner.shared_self = None
153
154        return results
155
156    def __init__(self, env: TestEnv, tap: bool = False,
157                 color: str = 'auto') -> None:
158        self.env = env
159        self.tap = tap
160        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
161
162        assert color in ('auto', 'on', 'off')
163        self.color = (color == 'on') or (color == 'auto' and
164                                         sys.stdout.isatty())
165
166        self._stack: contextlib.ExitStack
167
168    def __enter__(self) -> 'TestRunner':
169        self._stack = contextlib.ExitStack()
170        self._stack.enter_context(self.env)
171        self._stack.enter_context(self.last_elapsed)
172        self._stack.enter_context(savetty())
173        return self
174
175    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
176        self._stack.close()
177
178    def test_print_one_line(self, test: str,
179                            test_field_width: int,
180                            starttime: str,
181                            endtime: Optional[str] = None, status: str = '...',
182                            lasttime: Optional[float] = None,
183                            thistime: Optional[float] = None,
184                            description: str = '',
185                            end: str = '\n') -> None:
186        """ Print short test info before/after test run """
187        test = os.path.basename(test)
188
189        if test_field_width is None:
190            test_field_width = 8
191
192        if self.tap:
193            if status == 'pass':
194                print(f'ok {self.env.imgfmt} {test}')
195            elif status == 'fail':
196                print(f'not ok {self.env.imgfmt} {test}')
197            elif status == 'not run':
198                print(f'ok {self.env.imgfmt} {test} # SKIP')
199            return
200
201        if lasttime:
202            lasttime_s = f' (last: {lasttime:.1f}s)'
203        else:
204            lasttime_s = ''
205        if thistime:
206            thistime_s = f'{thistime:.1f}s'
207        else:
208            thistime_s = '...'
209
210        if endtime:
211            endtime = f'[{endtime}]'
212        else:
213            endtime = ''
214
215        if self.color:
216            if status == 'pass':
217                col = '\033[32m'
218            elif status == 'fail':
219                col = '\033[1m\033[31m'
220            elif status == 'not run':
221                col = '\033[33m'
222            else:
223                col = ''
224
225            col_end = '\033[0m'
226        else:
227            col = ''
228            col_end = ''
229
230        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
231              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
232              f'{description}', end=end)
233
234    def find_reference(self, test: str) -> str:
235        if self.env.cachemode == 'none':
236            ref = f'{test}.out.nocache'
237            if os.path.isfile(ref):
238                return ref
239
240        ref = f'{test}.out.{self.env.imgfmt}'
241        if os.path.isfile(ref):
242            return ref
243
244        ref = f'{test}.{self.env.qemu_default_machine}.out'
245        if os.path.isfile(ref):
246            return ref
247
248        return f'{test}.out'
249
250    def do_run_test(self, test: str, mp: bool) -> TestResult:
251        """
252        Run one test
253
254        :param test: test file path
255        :param mp: if true, we are in a multiprocessing environment, use
256                   personal subdirectories for test run
257
258        Note: this method may be called from subprocess, so it does not
259        change ``self`` object in any way!
260        """
261
262        f_test = Path(test)
263        f_reference = Path(self.find_reference(test))
264
265        if not f_test.exists():
266            return TestResult(status='fail',
267                              description=f'No such test file: {f_test}')
268
269        if not os.access(str(f_test), os.X_OK):
270            sys.exit(f'Not executable: {f_test}')
271
272        if not f_reference.exists():
273            return TestResult(status='not run',
274                              description='No qualified output '
275                                          f'(expected {f_reference})')
276
277        args = [str(f_test.resolve())]
278        env = self.env.prepare_subprocess(args)
279        if mp:
280            # Split test directories, so that tests running in parallel don't
281            # break each other.
282            for d in ['TEST_DIR', 'SOCK_DIR']:
283                env[d] = os.path.join(env[d], f_test.name)
284                Path(env[d]).mkdir(parents=True, exist_ok=True)
285
286        test_dir = env['TEST_DIR']
287        f_bad = Path(test_dir, f_test.name + '.out.bad')
288        f_notrun = Path(test_dir, f_test.name + '.notrun')
289        f_casenotrun = Path(test_dir, f_test.name + '.casenotrun')
290
291        for p in (f_notrun, f_casenotrun):
292            silent_unlink(p)
293
294        t0 = time.time()
295        with f_bad.open('w', encoding="utf-8") as f:
296            with subprocess.Popen(args, cwd=str(f_test.parent), env=env,
297                                  stdout=f, stderr=subprocess.STDOUT) as proc:
298                try:
299                    proc.wait()
300                except KeyboardInterrupt:
301                    proc.terminate()
302                    proc.wait()
303                    return TestResult(status='not run',
304                                      description='Interrupted by user',
305                                      interrupted=True)
306                ret = proc.returncode
307
308        elapsed = round(time.time() - t0, 1)
309
310        if ret != 0:
311            return TestResult(status='fail', elapsed=elapsed,
312                              description=f'failed, exit status {ret}',
313                              diff=file_diff(str(f_reference), str(f_bad)))
314
315        if f_notrun.exists():
316            return TestResult(
317                status='not run',
318                description=f_notrun.read_text(encoding='utf-8').strip())
319
320        casenotrun = ''
321        if f_casenotrun.exists():
322            casenotrun = f_casenotrun.read_text(encoding='utf-8')
323
324        diff = file_diff(str(f_reference), str(f_bad))
325        if diff:
326            if os.environ.get("QEMU_IOTESTS_REGEN", None) is not None:
327                shutil.copyfile(str(f_bad), str(f_reference))
328                print("########################################")
329                print("#####    REFERENCE FILE UPDATED    #####")
330                print("########################################")
331            return TestResult(status='fail', elapsed=elapsed,
332                              description=f'output mismatch (see {f_bad})',
333                              diff=diff, casenotrun=casenotrun)
334        else:
335            f_bad.unlink()
336            return TestResult(status='pass', elapsed=elapsed,
337                              casenotrun=casenotrun)
338
339    def run_test(self, test: str,
340                 test_field_width: int,
341                 mp: bool = False) -> TestResult:
342        """
343        Run one test and print short status
344
345        :param test: test file path
346        :param test_field_width: width for first field of status format
347        :param mp: if true, we are in a multiprocessing environment, don't try
348                   to rewrite things in stdout
349
350        Note: this method may be called from subprocess, so it does not
351        change ``self`` object in any way!
352        """
353
354        last_el = self.last_elapsed.get(test)
355        start = datetime.datetime.now().strftime('%H:%M:%S')
356
357        if not self.tap:
358            self.test_print_one_line(test=test,
359                                     test_field_width=test_field_width,
360                                     status = 'started' if mp else '...',
361                                     starttime=start,
362                                     lasttime=last_el,
363                                     end = '\n' if mp else '\r')
364        else:
365            testname = os.path.basename(test)
366            print(f'# running {self.env.imgfmt} {testname}')
367
368        res = self.do_run_test(test, mp)
369
370        end = datetime.datetime.now().strftime('%H:%M:%S')
371        self.test_print_one_line(test=test,
372                                 test_field_width=test_field_width,
373                                 status=res.status,
374                                 starttime=start, endtime=end,
375                                 lasttime=last_el, thistime=res.elapsed,
376                                 description=res.description)
377
378        if res.casenotrun:
379            if self.tap:
380                print('#' + res.casenotrun.replace('\n', '\n#'))
381            else:
382                print(res.casenotrun)
383
384        sys.stdout.flush()
385        return res
386
387    def run_tests(self, tests: List[str], jobs: int = 1) -> bool:
388        n_run = 0
389        failed = []
390        notrun = []
391        casenotrun = []
392
393        if self.tap:
394            self.env.print_env('# ')
395            print('1..%d' % len(tests))
396        else:
397            self.env.print_env()
398
399        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
400
401        if jobs > 1:
402            results = self.run_tests_pool(tests, test_field_width, jobs)
403
404        for i, t in enumerate(tests):
405            name = os.path.basename(t)
406
407            if jobs > 1:
408                res = results[i]
409            else:
410                res = self.run_test(t, test_field_width)
411
412            assert res.status in ('pass', 'fail', 'not run')
413
414            if res.casenotrun:
415                casenotrun.append(t)
416
417            if res.status != 'not run':
418                n_run += 1
419
420            if res.status == 'fail':
421                failed.append(name)
422                if res.diff:
423                    if self.tap:
424                        print('\n'.join(res.diff), file=sys.stderr)
425                    else:
426                        print('\n'.join(res.diff))
427            elif res.status == 'not run':
428                notrun.append(name)
429            elif res.status == 'pass':
430                assert res.elapsed is not None
431                self.last_elapsed.update(t, res.elapsed)
432
433            sys.stdout.flush()
434            if res.interrupted:
435                break
436
437        if not self.tap:
438            if notrun:
439                print('Not run:', ' '.join(notrun))
440
441            if casenotrun:
442                print('Some cases not run in:', ' '.join(casenotrun))
443
444            if failed:
445                print('Failures:', ' '.join(failed))
446                print(f'Failed {len(failed)} of {n_run} iotests')
447            else:
448                print(f'Passed all {n_run} iotests')
449        return not failed
450