1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import os
20from pathlib import Path
21import datetime
22import time
23import difflib
24import subprocess
25import contextlib
26import json
27import termios
28import shutil
29import sys
30from multiprocessing import Pool
31from contextlib import contextmanager
32from typing import List, Optional, Iterator, Any, Sequence, Dict, \
33        ContextManager
34
35from testenv import TestEnv
36
37
def silent_unlink(path: Path) -> None:
    """ Remove ``path``, ignoring any OSError (e.g. when it does not exist) """
    with contextlib.suppress(OSError):
        path.unlink()
43
44
def file_diff(file1: str, file2: str) -> List[str]:
    """ Return a unified diff of two text files as a list of lines

    Trailing whitespace is ignored when comparing lines and is stripped from
    every produced diff line.  An empty list means the files match.
    """
    # We want to ignore spaces at line ends. There are a lot of mess about
    # it in iotests.
    # TODO: fix all tests to not produce extra spaces, fix all .out files
    # and use strict diff here!
    with open(file1, encoding="utf-8") as old, \
         open(file2, encoding="utf-8") as new:
        old_lines = list(map(str.rstrip, old))
        new_lines = list(map(str.rstrip, new))

    diff_lines = difflib.unified_diff(old_lines, new_lines, file1, file2)
    return [entry.rstrip() for entry in diff_lines]
57
58
# We want to save the current tty settings during a test run and restore
# them afterwards, since an aborting QEMU process may leave the terminal in
# a broken state.
61@contextmanager
62def savetty() -> Iterator[None]:
63    isterm = sys.stdin.isatty()
64    if isterm:
65        fd = sys.stdin.fileno()
66        attr = termios.tcgetattr(fd)
67
68    try:
69        yield
70    finally:
71        if isterm:
72            termios.tcsetattr(fd, termios.TCSADRAIN, attr)
73
74
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """ Cache for elapsed time for tests, to show it during new test run

    The cache is a JSON object stored in ``cache_file`` that maps
    test name -> image protocol -> image format -> elapsed seconds.

    It is safe to use get() at any time.  To use update(), you must either
    use it inside with-block or use save() after update().
    """
    def __init__(self, cache_file: str, env: 'TestEnv') -> None:
        """
        :param cache_file: path of the JSON cache file; a missing,
                           unreadable or malformed file gives an empty cache
        :param env: test environment; its ``imgproto`` and ``imgfmt``
                    attributes select the cache entry for each test
        """
        self.env = env
        self.cache_file = cache_file
        self.cache: Dict[str, Dict[str, Dict[str, float]]]

        try:
            with open(cache_file, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            loaded = None

        # Valid JSON that is not an object (e.g. ``null`` or a list) cannot
        # be indexed by test name and would make every get()/update() fail,
        # so treat it like any other corrupt cache and start over.
        self.cache = loaded if isinstance(loaded, dict) else {}

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """ Return the cached elapsed time of ``test`` for the current
        protocol and format, or ``default`` if there is none """
        if test not in self.cache:
            return default

        if self.env.imgproto not in self.cache[test]:
            return default

        return self.cache[test][self.env.imgproto].get(self.env.imgfmt,
                                                       default)

    def update(self, test: str, elapsed: float) -> None:
        """ Record ``elapsed`` for ``test`` with the current protocol and
        format (in memory only; see save()) """
        d = self.cache.setdefault(test, {})
        d.setdefault(self.env.imgproto, {})[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """ Write the whole cache to ``cache_file`` as JSON """
        with open(self.cache_file, 'w', encoding="utf-8") as f:
            json.dump(self.cache, f)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Persist updates even when the with-block raised.
        self.save()
116
117
class TestResult:
    """ Outcome of running one test

    Plain record of attributes:
      status      -- verdict string, e.g. 'pass', 'fail' or 'not run'
      description -- short human-readable explanation of the status
      elapsed     -- run time in seconds, or None if not measured
      diff        -- unified diff lines between expected and actual output
      casenotrun  -- text about individual test cases that were not run
      interrupted -- whether the run was interrupted by the user
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        self.status, self.description = status, description
        self.elapsed, self.diff = elapsed, diff
        self.casenotrun, self.interrupted = casenotrun, interrupted
128
129
130class TestRunner(ContextManager['TestRunner']):
131    shared_self = None
132
133    @staticmethod
134    def proc_run_test(test: str, test_field_width: int) -> TestResult:
135        # We are in a subprocess, we can't change the runner object!
136        runner = TestRunner.shared_self
137        assert runner is not None
138        return runner.run_test(test, test_field_width, mp=True)
139
140    def run_tests_pool(self, tests: List[str],
141                       test_field_width: int, jobs: int) -> List[TestResult]:
142
143        # passing self directly to Pool.starmap() just doesn't work, because
144        # it's a context manager.
145        assert TestRunner.shared_self is None
146        TestRunner.shared_self = self
147
148        with Pool(jobs) as p:
149            results = p.starmap(self.proc_run_test,
150                                zip(tests, [test_field_width] * len(tests)))
151
152        TestRunner.shared_self = None
153
154        return results
155
156    def __init__(self, env: TestEnv, tap: bool = False,
157                 color: str = 'auto') -> None:
158        self.env = env
159        self.tap = tap
160        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
161
162        assert color in ('auto', 'on', 'off')
163        self.color = (color == 'on') or (color == 'auto' and
164                                         sys.stdout.isatty())
165
166        self._stack: contextlib.ExitStack
167
168    def __enter__(self) -> 'TestRunner':
169        self._stack = contextlib.ExitStack()
170        self._stack.enter_context(self.env)
171        self._stack.enter_context(self.last_elapsed)
172        self._stack.enter_context(savetty())
173        return self
174
175    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
176        self._stack.close()
177
178    def test_print_one_line(self, test: str,
179                            test_field_width: int,
180                            starttime: str,
181                            endtime: Optional[str] = None, status: str = '...',
182                            lasttime: Optional[float] = None,
183                            thistime: Optional[float] = None,
184                            description: str = '',
185                            end: str = '\n') -> None:
186        """ Print short test info before/after test run """
187        test = os.path.basename(test)
188
189        if test_field_width is None:
190            test_field_width = 8
191
192        if self.tap:
193            if status == 'pass':
194                print(f'ok {self.env.imgfmt} {test}')
195            elif status == 'fail':
196                print(f'not ok {self.env.imgfmt} {test}')
197            elif status == 'not run':
198                print(f'ok {self.env.imgfmt} {test} # SKIP')
199            return
200
201        if lasttime:
202            lasttime_s = f' (last: {lasttime:.1f}s)'
203        else:
204            lasttime_s = ''
205        if thistime:
206            thistime_s = f'{thistime:.1f}s'
207        else:
208            thistime_s = '...'
209
210        if endtime:
211            endtime = f'[{endtime}]'
212        else:
213            endtime = ''
214
215        if self.color:
216            if status == 'pass':
217                col = '\033[32m'
218            elif status == 'fail':
219                col = '\033[1m\033[31m'
220            elif status == 'not run':
221                col = '\033[33m'
222            else:
223                col = ''
224
225            col_end = '\033[0m'
226        else:
227            col = ''
228            col_end = ''
229
230        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
231              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
232              f'{description}', end=end)
233
234    def find_reference(self, test: str) -> str:
235        if self.env.cachemode == 'none':
236            ref = f'{test}.out.nocache'
237            if os.path.isfile(ref):
238                return ref
239
240        ref = f'{test}.out.{self.env.imgfmt}'
241        if os.path.isfile(ref):
242            return ref
243
244        ref = f'{test}.{self.env.qemu_default_machine}.out'
245        if os.path.isfile(ref):
246            return ref
247
248        return f'{test}.out'
249
250    def do_run_test(self, test: str, mp: bool) -> TestResult:
251        """
252        Run one test
253
254        :param test: test file path
255        :param mp: if true, we are in a multiprocessing environment, use
256                   personal subdirectories for test run
257
258        Note: this method may be called from subprocess, so it does not
259        change ``self`` object in any way!
260        """
261
262        f_test = Path(test)
263        f_reference = Path(self.find_reference(test))
264
265        if not f_test.exists():
266            return TestResult(status='fail',
267                              description=f'No such test file: {f_test}')
268
269        if not os.access(str(f_test), os.X_OK):
270            sys.exit(f'Not executable: {f_test}')
271
272        if not f_reference.exists():
273            return TestResult(status='not run',
274                              description='No qualified output '
275                                          f'(expected {f_reference})')
276
277        args = [str(f_test.resolve())]
278        env = self.env.prepare_subprocess(args)
279        if mp:
280            # Split test directories, so that tests running in parallel don't
281            # break each other.
282            for d in ['TEST_DIR', 'SOCK_DIR']:
283                env[d] = os.path.join(env[d], f_test.name)
284                Path(env[d]).mkdir(parents=True, exist_ok=True)
285
286        test_dir = env['TEST_DIR']
287        f_bad = Path(test_dir, f_test.name + '.out.bad')
288        f_notrun = Path(test_dir, f_test.name + '.notrun')
289        f_casenotrun = Path(test_dir, f_test.name + '.casenotrun')
290
291        for p in (f_notrun, f_casenotrun):
292            silent_unlink(p)
293
294        t0 = time.time()
295        with f_bad.open('w', encoding="utf-8") as f:
296            with subprocess.Popen(args, cwd=str(f_test.parent), env=env,
297                                  stdout=f, stderr=subprocess.STDOUT) as proc:
298                try:
299                    proc.wait()
300                except KeyboardInterrupt:
301                    proc.terminate()
302                    proc.wait()
303                    return TestResult(status='not run',
304                                      description='Interrupted by user',
305                                      interrupted=True)
306                ret = proc.returncode
307
308        elapsed = round(time.time() - t0, 1)
309
310        if ret != 0:
311            return TestResult(status='fail', elapsed=elapsed,
312                              description=f'failed, exit status {ret}',
313                              diff=file_diff(str(f_reference), str(f_bad)))
314
315        if f_notrun.exists():
316            return TestResult(
317                status='not run',
318                description=f_notrun.read_text(encoding='utf-8').strip())
319
320        casenotrun = ''
321        if f_casenotrun.exists():
322            casenotrun = f_casenotrun.read_text(encoding='utf-8')
323
324        diff = file_diff(str(f_reference), str(f_bad))
325        if diff:
326            if os.environ.get("QEMU_IOTESTS_REGEN", None) is not None:
327                shutil.copyfile(str(f_bad), str(f_reference))
328                print("########################################")
329                print("#####    REFERENCE FILE UPDATED    #####")
330                print("########################################")
331            return TestResult(status='fail', elapsed=elapsed,
332                              description=f'output mismatch (see {f_bad})',
333                              diff=diff, casenotrun=casenotrun)
334        else:
335            f_bad.unlink()
336            return TestResult(status='pass', elapsed=elapsed,
337                              casenotrun=casenotrun)
338
339    def run_test(self, test: str,
340                 test_field_width: int,
341                 mp: bool = False) -> TestResult:
342        """
343        Run one test and print short status
344
345        :param test: test file path
346        :param test_field_width: width for first field of status format
347        :param mp: if true, we are in a multiprocessing environment, don't try
348                   to rewrite things in stdout
349
350        Note: this method may be called from subprocess, so it does not
351        change ``self`` object in any way!
352        """
353
354        last_el = self.last_elapsed.get(test)
355        start = datetime.datetime.now().strftime('%H:%M:%S')
356
357        if not self.tap:
358            self.test_print_one_line(test=test,
359                                     test_field_width=test_field_width,
360                                     status = 'started' if mp else '...',
361                                     starttime=start,
362                                     lasttime=last_el,
363                                     end = '\n' if mp else '\r')
364
365        res = self.do_run_test(test, mp)
366
367        end = datetime.datetime.now().strftime('%H:%M:%S')
368        self.test_print_one_line(test=test,
369                                 test_field_width=test_field_width,
370                                 status=res.status,
371                                 starttime=start, endtime=end,
372                                 lasttime=last_el, thistime=res.elapsed,
373                                 description=res.description)
374
375        if res.casenotrun:
376            if self.tap:
377                print('#' + res.casenotrun.replace('\n', '\n#'))
378            else:
379                print(res.casenotrun)
380
381        return res
382
383    def run_tests(self, tests: List[str], jobs: int = 1) -> bool:
384        n_run = 0
385        failed = []
386        notrun = []
387        casenotrun = []
388
389        if self.tap:
390            self.env.print_env('# ')
391            print('1..%d' % len(tests))
392        else:
393            self.env.print_env()
394
395        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
396
397        if jobs > 1:
398            results = self.run_tests_pool(tests, test_field_width, jobs)
399
400        for i, t in enumerate(tests):
401            name = os.path.basename(t)
402
403            if jobs > 1:
404                res = results[i]
405            else:
406                res = self.run_test(t, test_field_width)
407
408            assert res.status in ('pass', 'fail', 'not run')
409
410            if res.casenotrun:
411                casenotrun.append(t)
412
413            if res.status != 'not run':
414                n_run += 1
415
416            if res.status == 'fail':
417                failed.append(name)
418                if res.diff:
419                    if self.tap:
420                        print('\n'.join(res.diff), file=sys.stderr)
421                    else:
422                        print('\n'.join(res.diff))
423            elif res.status == 'not run':
424                notrun.append(name)
425            elif res.status == 'pass':
426                assert res.elapsed is not None
427                self.last_elapsed.update(t, res.elapsed)
428
429            sys.stdout.flush()
430            if res.interrupted:
431                break
432
433        if not self.tap:
434            if notrun:
435                print('Not run:', ' '.join(notrun))
436
437            if casenotrun:
438                print('Some cases not run in:', ' '.join(casenotrun))
439
440            if failed:
441                print('Failures:', ' '.join(failed))
442                print(f'Failed {len(failed)} of {n_run} iotests')
443            else:
444                print(f'Passed all {n_run} iotests')
445        return not failed
446