1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import os
20from pathlib import Path
21import datetime
22import time
23import difflib
24import subprocess
25import contextlib
26import json
27import termios
28import sys
29from contextlib import contextmanager
30from typing import List, Optional, Iterator, Any, Sequence, Dict, \
31        ContextManager
32
33from testenv import TestEnv
34
35
def silent_unlink(path: Path) -> None:
    """ Remove path, ignoring any OSError (for example if it is missing) """
    with contextlib.suppress(OSError):
        path.unlink()
41
42
def file_diff(file1: str, file2: str) -> List[str]:
    """ Return the unified diff of two UTF-8 text files as a list of lines

    Trailing whitespace is stripped from every input line before comparing,
    and from every produced diff line (headers included).  An empty list
    means the files match.
    """
    with open(file1, encoding="utf-8") as left, \
         open(file2, encoding="utf-8") as right:
        # Trailing whitespace is deliberately ignored: many iotests and .out
        # files disagree about it.
        # TODO: fix all tests to not produce extra spaces, fix all .out files
        # and use strict diff here!
        expected = list(map(str.rstrip, left))
        actual = list(map(str.rstrip, right))

    delta = difflib.unified_diff(expected, actual, file1, file2)
    return [entry.rstrip() for entry in delta]
55
56
# Save the tty settings for the duration of a test run: an aborting qemu
# call may leave the terminal in a broken state.
@contextmanager
def savetty() -> Iterator[None]:
    """ Restore stdin's terminal attributes when the with-block exits

    When stdin is not a terminal this is a no-op context manager.
    """
    if not sys.stdin.isatty():
        yield
        return

    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    try:
        yield
    finally:
        # TCSADRAIN: apply once pending output has been written.
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
71
72
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """ Cache for elapsed time for tests, to show it during new test run

    It is safe to use get() at any time.  To use update(), you must either
    use it inside with-block or use save() after update().

    Timings are kept as cache[test][imgproto][imgfmt] = seconds and are
    persisted as JSON in cache_file.
    """
    def __init__(self, cache_file: str, env: TestEnv) -> None:
        """
        :param cache_file: JSON file to load the cache from and save it to
        :param env: test environment; its imgproto and imgfmt select the
            entry that get() and update() work on
        """
        self.env = env
        self.cache_file = cache_file
        self.cache: Dict[str, Dict[str, Dict[str, float]]]

        loaded: Any
        try:
            with open(cache_file, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable or malformed file (ValueError also covers
            # JSON and UTF-8 decode errors): start with an empty cache.
            loaded = {}

        # Syntactically valid JSON that is not an object (e.g. "null" or
        # "[]") would make get() raise TypeError and update() raise
        # AttributeError later, so it is discarded as well.
        self.cache = loaded if isinstance(loaded, dict) else {}

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """ Return the cached elapsed time of test for the current image
        protocol and format, or default if none is cached """
        if test not in self.cache:
            return default

        if self.env.imgproto not in self.cache[test]:
            return default

        return self.cache[test][self.env.imgproto].get(self.env.imgfmt,
                                                       default)

    def update(self, test: str, elapsed: float) -> None:
        """ Record elapsed for test under the current image protocol and
        format (in memory only; it is written out by save()) """
        d = self.cache.setdefault(test, {})
        d.setdefault(self.env.imgproto, {})[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """ Overwrite cache_file with the whole cache as JSON """
        with open(self.cache_file, 'w', encoding="utf-8") as f:
            json.dump(self.cache, f)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Saved unconditionally, including when the block exits via an
        # exception.
        self.save()
114
115
class TestResult:
    """ Plain record describing the outcome of one test run

    Attributes (all taken verbatim from the constructor arguments):
      status       -- outcome string; the runner expects 'pass', 'fail' or
                      'not run'
      description  -- human readable explanation shown next to the status
      elapsed      -- run time in seconds, or None when not measured
      diff         -- lines of the output diff (empty when there is none)
      casenotrun   -- text listing test cases that were skipped
      interrupted  -- True if the run was interrupted by the user
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        self.status, self.description = status, description
        self.elapsed, self.diff = elapsed, diff
        self.casenotrun, self.interrupted = casenotrun, interrupted
126
127
128class TestRunner(ContextManager['TestRunner']):
129    def __init__(self, env: TestEnv, makecheck: bool = False,
130                 color: str = 'auto') -> None:
131        self.env = env
132        self.makecheck = makecheck
133        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
134
135        assert color in ('auto', 'on', 'off')
136        self.color = (color == 'on') or (color == 'auto' and
137                                         sys.stdout.isatty())
138
139        self._stack: contextlib.ExitStack
140
141    def __enter__(self) -> 'TestRunner':
142        self._stack = contextlib.ExitStack()
143        self._stack.enter_context(self.env)
144        self._stack.enter_context(self.last_elapsed)
145        self._stack.enter_context(savetty())
146        return self
147
148    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
149        self._stack.close()
150
151    def test_print_one_line(self, test: str, starttime: str,
152                            endtime: Optional[str] = None, status: str = '...',
153                            lasttime: Optional[float] = None,
154                            thistime: Optional[float] = None,
155                            description: str = '',
156                            test_field_width: Optional[int] = None,
157                            end: str = '\n') -> None:
158        """ Print short test info before/after test run """
159        test = os.path.basename(test)
160
161        if test_field_width is None:
162            test_field_width = 8
163
164        if self.makecheck and status != '...':
165            if status and status != 'pass':
166                status = f' [{status}]'
167            else:
168                status = ''
169
170            print(f'  TEST   iotest-{self.env.imgfmt}: {test}{status}')
171            return
172
173        if lasttime:
174            lasttime_s = f' (last: {lasttime:.1f}s)'
175        else:
176            lasttime_s = ''
177        if thistime:
178            thistime_s = f'{thistime:.1f}s'
179        else:
180            thistime_s = '...'
181
182        if endtime:
183            endtime = f'[{endtime}]'
184        else:
185            endtime = ''
186
187        if self.color:
188            if status == 'pass':
189                col = '\033[32m'
190            elif status == 'fail':
191                col = '\033[1m\033[31m'
192            elif status == 'not run':
193                col = '\033[33m'
194            else:
195                col = ''
196
197            col_end = '\033[0m'
198        else:
199            col = ''
200            col_end = ''
201
202        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
203              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
204              f'{description}', end=end)
205
206    def find_reference(self, test: str) -> str:
207        if self.env.cachemode == 'none':
208            ref = f'{test}.out.nocache'
209            if os.path.isfile(ref):
210                return ref
211
212        ref = f'{test}.out.{self.env.imgfmt}'
213        if os.path.isfile(ref):
214            return ref
215
216        ref = f'{test}.{self.env.qemu_default_machine}.out'
217        if os.path.isfile(ref):
218            return ref
219
220        return f'{test}.out'
221
222    def do_run_test(self, test: str) -> TestResult:
223        f_test = Path(test)
224        f_bad = Path(f_test.name + '.out.bad')
225        f_notrun = Path(f_test.name + '.notrun')
226        f_casenotrun = Path(f_test.name + '.casenotrun')
227        f_reference = Path(self.find_reference(test))
228
229        if not f_test.exists():
230            return TestResult(status='fail',
231                              description=f'No such test file: {f_test}')
232
233        if not os.access(str(f_test), os.X_OK):
234            sys.exit(f'Not executable: {f_test}')
235
236        if not f_reference.exists():
237            return TestResult(status='not run',
238                              description='No qualified output '
239                                          f'(expected {f_reference})')
240
241        for p in (f_bad, f_notrun, f_casenotrun):
242            silent_unlink(p)
243
244        args = [str(f_test.resolve())]
245        env = self.env.prepare_subprocess(args)
246
247        t0 = time.time()
248        with f_bad.open('w', encoding="utf-8") as f:
249            with subprocess.Popen(args, cwd=str(f_test.parent), env=env,
250                                  stdout=f, stderr=subprocess.STDOUT) as proc:
251                try:
252                    proc.wait()
253                except KeyboardInterrupt:
254                    proc.terminate()
255                    proc.wait()
256                    return TestResult(status='not run',
257                                      description='Interrupted by user',
258                                      interrupted=True)
259                ret = proc.returncode
260
261        elapsed = round(time.time() - t0, 1)
262
263        if ret != 0:
264            return TestResult(status='fail', elapsed=elapsed,
265                              description=f'failed, exit status {ret}',
266                              diff=file_diff(str(f_reference), str(f_bad)))
267
268        if f_notrun.exists():
269            return TestResult(
270                status='not run',
271                description=f_notrun.read_text(encoding='utf-8').strip())
272
273        casenotrun = ''
274        if f_casenotrun.exists():
275            casenotrun = f_casenotrun.read_text(encoding='utf-8')
276
277        diff = file_diff(str(f_reference), str(f_bad))
278        if diff:
279            return TestResult(status='fail', elapsed=elapsed,
280                              description=f'output mismatch (see {f_bad})',
281                              diff=diff, casenotrun=casenotrun)
282        else:
283            f_bad.unlink()
284            self.last_elapsed.update(test, elapsed)
285            return TestResult(status='pass', elapsed=elapsed,
286                              casenotrun=casenotrun)
287
288    def run_test(self, test: str,
289                 test_field_width: Optional[int] = None) -> TestResult:
290        last_el = self.last_elapsed.get(test)
291        start = datetime.datetime.now().strftime('%H:%M:%S')
292
293        if not self.makecheck:
294            self.test_print_one_line(test=test, starttime=start,
295                                     lasttime=last_el, end='\r',
296                                     test_field_width=test_field_width)
297
298        res = self.do_run_test(test)
299
300        end = datetime.datetime.now().strftime('%H:%M:%S')
301        self.test_print_one_line(test=test, status=res.status,
302                                 starttime=start, endtime=end,
303                                 lasttime=last_el, thistime=res.elapsed,
304                                 description=res.description,
305                                 test_field_width=test_field_width)
306
307        if res.casenotrun:
308            print(res.casenotrun)
309
310        return res
311
312    def run_tests(self, tests: List[str]) -> bool:
313        n_run = 0
314        failed = []
315        notrun = []
316        casenotrun = []
317
318        if not self.makecheck:
319            self.env.print_env()
320
321        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
322
323        for t in tests:
324            name = os.path.basename(t)
325            res = self.run_test(t, test_field_width=test_field_width)
326
327            assert res.status in ('pass', 'fail', 'not run')
328
329            if res.casenotrun:
330                casenotrun.append(t)
331
332            if res.status != 'not run':
333                n_run += 1
334
335            if res.status == 'fail':
336                failed.append(name)
337                if self.makecheck:
338                    self.env.print_env()
339                if res.diff:
340                    print('\n'.join(res.diff))
341            elif res.status == 'not run':
342                notrun.append(name)
343
344            if res.interrupted:
345                break
346
347        if notrun:
348            print('Not run:', ' '.join(notrun))
349
350        if casenotrun:
351            print('Some cases not run in:', ' '.join(casenotrun))
352
353        if failed:
354            print('Failures:', ' '.join(failed))
355            print(f'Failed {len(failed)} of {n_run} iotests')
356            return False
357        else:
358            print(f'Passed all {n_run} iotests')
359            return True
360