1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import os
20from pathlib import Path
21import datetime
22import time
23import difflib
24import subprocess
25import contextlib
26import json
27import termios
28import sys
29from contextlib import contextmanager
30from typing import List, Optional, Iterator, Any, Sequence, Dict, \
31        ContextManager
32
33from testenv import TestEnv
34
35
def silent_unlink(path: Path) -> None:
    """Remove the file at path, ignoring any OSError.

    A missing file (or any other OS-level failure to remove it) is deliberately
    not reported: this is a best-effort cleanup helper.
    """
    with contextlib.suppress(OSError):
        path.unlink()
41
42
def file_diff(file1: str, file2: str) -> List[str]:
    """Return the unified diff between two UTF-8 text files.

    Trailing whitespace of every line is ignored.  The result is a list of
    diff lines without line terminators; it is empty when the files are
    considered equal.
    """
    def stripped_lines(path: str) -> List[str]:
        with open(path, encoding="utf-8") as stream:
            return [text.rstrip() for text in stream]

    # We want to ignore spaces at line ends. There is a lot of mess around
    # that in iotests.
    # TODO: fix all tests to not produce extra spaces, fix all .out files
    # and use strict diff here!
    expected = stripped_lines(file1)
    actual = stripped_lines(file2)

    # unified_diff() terminates its header lines with '\n'; strip that too.
    delta = difflib.unified_diff(expected, actual, file1, file2)
    return [entry.rstrip() for entry in delta]
55
56
57# We want to save current tty settings during test run,
58# since an aborting qemu call may leave things screwed up.
@contextmanager
def savetty() -> Iterator[None]:
    """Context manager restoring stdin's terminal attributes on exit.

    When stdin is not a terminal this does nothing.  Otherwise the attributes
    are read on entry and written back on exit, whether or not the body raised.
    """
    if not sys.stdin.isatty():
        # Nothing to save or restore.
        yield
        return

    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    try:
        yield
    finally:
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
72
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """ Cache for elapsed time for tests, to show it during new test run

    It is safe to use get() at any time.  To use update(), you must either
    use it inside with-block or use save() after update().

    The cache is a nested mapping
        test -> image protocol -> image format -> elapsed seconds
    stored as JSON in cache_file.  The protocol and format used for lookups
    and updates are taken from env (env.imgproto and env.imgfmt).
    """
    def __init__(self, cache_file: str, env: 'TestEnv') -> None:
        """Load the cache from cache_file.

        A missing, unreadable or malformed cache file is not an error: the
        cache simply starts empty.
        """
        self.env = env
        self.cache_file = cache_file
        self.cache: Dict[str, Dict[str, Dict[str, float]]] = {}

        try:
            with open(cache_file, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both broken JSON and broken UTF-8
            return

        # Valid JSON is not necessarily a JSON object; anything else would
        # break get() and update() later, so treat it like a missing cache.
        if isinstance(loaded, dict):
            self.cache = loaded

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """Return the cached elapsed time of test for the current
        protocol/format pair, or default if there is no usable entry."""
        by_proto = self.cache.get(test)
        if not isinstance(by_proto, dict):
            return default

        by_fmt = by_proto.get(self.env.imgproto)
        if not isinstance(by_fmt, dict):
            return default

        return by_fmt.get(self.env.imgfmt, default)

    def update(self, test: str, elapsed: float) -> None:
        """Remember elapsed for test under the current protocol/format pair.

        Only the in-memory cache is changed; see save().  Intermediate
        entries that are not mappings (corrupted cache file) are replaced.
        """
        by_proto = self.cache.get(test)
        if not isinstance(by_proto, dict):
            by_proto = self.cache[test] = {}

        by_fmt = by_proto.get(self.env.imgproto)
        if not isinstance(by_fmt, dict):
            by_fmt = by_proto[self.env.imgproto] = {}

        by_fmt[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """Write the whole cache to cache_file as JSON."""
        with open(self.cache_file, 'w', encoding="utf-8") as f:
            json.dump(self.cache, f)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Save unconditionally, even when the with-block is left by exception
        self.save()
114
115
class TestResult:
    """Plain record describing the outcome of one test run.

    status       -- outcome of the run, as a short string
    description  -- human readable details about the outcome
    elapsed      -- run time in seconds, None when not available
    diff         -- diff lines to show to the user
    casenotrun   -- text about test cases that were not run
    interrupted  -- True when the run was interrupted
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        # Outcome itself
        self.status = status
        self.interrupted = interrupted
        self.elapsed = elapsed

        # Texts to report
        self.description = description
        self.casenotrun = casenotrun
        self.diff = diff
126
127
128class TestRunner(ContextManager['TestRunner']):
129    def __init__(self, env: TestEnv, makecheck: bool = False,
130                 color: str = 'auto') -> None:
131        self.env = env
132        self.makecheck = makecheck
133        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
134
135        assert color in ('auto', 'on', 'off')
136        self.color = (color == 'on') or (color == 'auto' and
137                                         sys.stdout.isatty())
138
139        self._stack: contextlib.ExitStack
140
141    def __enter__(self) -> 'TestRunner':
142        self._stack = contextlib.ExitStack()
143        self._stack.enter_context(self.env)
144        self._stack.enter_context(self.last_elapsed)
145        self._stack.enter_context(savetty())
146        return self
147
148    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
149        self._stack.close()
150
151    def test_print_one_line(self, test: str, starttime: str,
152                            endtime: Optional[str] = None, status: str = '...',
153                            lasttime: Optional[float] = None,
154                            thistime: Optional[float] = None,
155                            description: str = '',
156                            test_field_width: Optional[int] = None,
157                            end: str = '\n') -> None:
158        """ Print short test info before/after test run """
159        test = os.path.basename(test)
160
161        if test_field_width is None:
162            test_field_width = 8
163
164        if self.makecheck and status != '...':
165            if status and status != 'pass':
166                status = f' [{status}]'
167            else:
168                status = ''
169
170            print(f'  TEST   iotest-{self.env.imgfmt}: {test}{status}')
171            return
172
173        if lasttime:
174            lasttime_s = f' (last: {lasttime:.1f}s)'
175        else:
176            lasttime_s = ''
177        if thistime:
178            thistime_s = f'{thistime:.1f}s'
179        else:
180            thistime_s = '...'
181
182        if endtime:
183            endtime = f'[{endtime}]'
184        else:
185            endtime = ''
186
187        if self.color:
188            if status == 'pass':
189                col = '\033[32m'
190            elif status == 'fail':
191                col = '\033[1m\033[31m'
192            elif status == 'not run':
193                col = '\033[33m'
194            else:
195                col = ''
196
197            col_end = '\033[0m'
198        else:
199            col = ''
200            col_end = ''
201
202        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
203              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
204              f'{description}', end=end)
205
206    def find_reference(self, test: str) -> str:
207        if self.env.cachemode == 'none':
208            ref = f'{test}.out.nocache'
209            if os.path.isfile(ref):
210                return ref
211
212        ref = f'{test}.out.{self.env.imgfmt}'
213        if os.path.isfile(ref):
214            return ref
215
216        ref = f'{test}.{self.env.qemu_default_machine}.out'
217        if os.path.isfile(ref):
218            return ref
219
220        return f'{test}.out'
221
222    def do_run_test(self, test: str) -> TestResult:
223        f_test = Path(test)
224        f_bad = Path(f_test.name + '.out.bad')
225        f_notrun = Path(f_test.name + '.notrun')
226        f_casenotrun = Path(f_test.name + '.casenotrun')
227        f_reference = Path(self.find_reference(test))
228
229        if not f_test.exists():
230            return TestResult(status='fail',
231                              description=f'No such test file: {f_test}')
232
233        if not os.access(str(f_test), os.X_OK):
234            sys.exit(f'Not executable: {f_test}')
235
236        if not f_reference.exists():
237            return TestResult(status='not run',
238                              description='No qualified output '
239                                          f'(expected {f_reference})')
240
241        for p in (f_bad, f_notrun, f_casenotrun):
242            silent_unlink(p)
243
244        args = [str(f_test.resolve())]
245        env = self.env.prepare_subprocess(args)
246
247        t0 = time.time()
248        with f_bad.open('w', encoding="utf-8") as f:
249            with subprocess.Popen(args, cwd=str(f_test.parent), env=env,
250                                  stdout=f, stderr=subprocess.STDOUT) as proc:
251                try:
252                    proc.wait()
253                except KeyboardInterrupt:
254                    proc.terminate()
255                    proc.wait()
256                    return TestResult(status='not run',
257                                      description='Interrupted by user',
258                                      interrupted=True)
259                ret = proc.returncode
260
261        elapsed = round(time.time() - t0, 1)
262
263        if ret != 0:
264            return TestResult(status='fail', elapsed=elapsed,
265                              description=f'failed, exit status {ret}',
266                              diff=file_diff(str(f_reference), str(f_bad)))
267
268        if f_notrun.exists():
269            return TestResult(status='not run',
270                              description=f_notrun.read_text().strip())
271
272        casenotrun = ''
273        if f_casenotrun.exists():
274            casenotrun = f_casenotrun.read_text()
275
276        diff = file_diff(str(f_reference), str(f_bad))
277        if diff:
278            return TestResult(status='fail', elapsed=elapsed,
279                              description=f'output mismatch (see {f_bad})',
280                              diff=diff, casenotrun=casenotrun)
281        else:
282            f_bad.unlink()
283            self.last_elapsed.update(test, elapsed)
284            return TestResult(status='pass', elapsed=elapsed,
285                              casenotrun=casenotrun)
286
287    def run_test(self, test: str,
288                 test_field_width: Optional[int] = None) -> TestResult:
289        last_el = self.last_elapsed.get(test)
290        start = datetime.datetime.now().strftime('%H:%M:%S')
291
292        if not self.makecheck:
293            self.test_print_one_line(test=test, starttime=start,
294                                     lasttime=last_el, end='\r',
295                                     test_field_width=test_field_width)
296
297        res = self.do_run_test(test)
298
299        end = datetime.datetime.now().strftime('%H:%M:%S')
300        self.test_print_one_line(test=test, status=res.status,
301                                 starttime=start, endtime=end,
302                                 lasttime=last_el, thistime=res.elapsed,
303                                 description=res.description,
304                                 test_field_width=test_field_width)
305
306        if res.casenotrun:
307            print(res.casenotrun)
308
309        return res
310
311    def run_tests(self, tests: List[str]) -> bool:
312        n_run = 0
313        failed = []
314        notrun = []
315        casenotrun = []
316
317        if not self.makecheck:
318            self.env.print_env()
319
320        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
321
322        for t in tests:
323            name = os.path.basename(t)
324            res = self.run_test(t, test_field_width=test_field_width)
325
326            assert res.status in ('pass', 'fail', 'not run')
327
328            if res.casenotrun:
329                casenotrun.append(t)
330
331            if res.status != 'not run':
332                n_run += 1
333
334            if res.status == 'fail':
335                failed.append(name)
336                if self.makecheck:
337                    self.env.print_env()
338                if res.diff:
339                    print('\n'.join(res.diff))
340            elif res.status == 'not run':
341                notrun.append(name)
342
343            if res.interrupted:
344                break
345
346        if notrun:
347            print('Not run:', ' '.join(notrun))
348
349        if casenotrun:
350            print('Some cases not run in:', ' '.join(casenotrun))
351
352        if failed:
353            print('Failures:', ' '.join(failed))
354            print(f'Failed {len(failed)} of {n_run} iotests')
355            return False
356        else:
357            print(f'Passed all {n_run} iotests')
358            return True
359