1# Class for actually running tests.
2#
3# Copyright (c) 2020-2021 Virtuozzo International GmbH
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import os
20from pathlib import Path
21import datetime
22import time
23import difflib
24import subprocess
25import contextlib
26import json
27import termios
28import sys
29from multiprocessing import Pool
30from contextlib import contextmanager
31from typing import List, Optional, Iterator, Any, Sequence, Dict, \
32        ContextManager
33
34from testenv import TestEnv
35
36
def silent_unlink(path: Path) -> None:
    """Remove *path*, silently ignoring any OSError raised by the removal.

    This is a best-effort cleanup helper: a missing file (or any other
    OS-level failure to unlink) is not reported to the caller.
    """
    with contextlib.suppress(OSError):
        path.unlink()
42
43
def file_diff(file1: str, file2: str) -> List[str]:
    """Return the unified diff between two UTF-8 text files.

    Trailing whitespace is stripped from every input line before comparing,
    and from every line of the produced diff.  The result is an empty list
    when the files are equal under that comparison.

    :param file1: path of the "from" file (also used as its diff label)
    :param file2: path of the "to" file (also used as its diff label)
    """
    def stripped_lines(path: str) -> List[str]:
        with open(path, encoding="utf-8") as stream:
            return [text.rstrip() for text in stream]

    # We want to ignore spaces at line ends: iotests are rather messy about
    # them.
    # TODO: fix all tests to not produce extra spaces, fix all .out files
    # and use strict diff here!
    before = stripped_lines(file1)
    after = stripped_lines(file2)

    delta = difflib.unified_diff(before, after, file1, file2)
    return list(map(str.rstrip, delta))
56
57
# Remember the current tty settings for the duration of a test run and put
# them back afterwards, since an aborting qemu call may leave things screwed
# up.
@contextmanager
def savetty() -> Iterator[None]:
    """Context manager that restores stdin's tty attributes on exit.

    When stdin is not a terminal nothing is saved or restored.
    """
    if not sys.stdin.isatty():
        yield
        return

    tty_fd = sys.stdin.fileno()
    saved_attr = termios.tcgetattr(tty_fd)
    try:
        yield
    finally:
        # Restore even if the body raised.
        termios.tcsetattr(tty_fd, termios.TCSADRAIN, saved_attr)
72
73
class LastElapsedTime(ContextManager['LastElapsedTime']):
    """ Cache for elapsed time for tests, to show it during new test run

    The cache is kept in a JSON file as nested objects:
    test -> image protocol -> image format -> elapsed time.  The protocol
    and format used for lookups and updates are taken from the ``env``
    object (``env.imgproto`` and ``env.imgfmt``).

    It is safe to use get() at any time.  To use update(), you must either
    use it inside with-block or use save() after update().
    """
    def __init__(self, cache_file: str, env: 'TestEnv') -> None:
        """
        Load the cache from ``cache_file``

        :param cache_file: path of the JSON cache file
        :param env: test environment providing ``imgproto`` and ``imgfmt``

        A file that is missing, unreadable, not valid JSON, or whose
        top-level JSON value is not an object yields an empty cache.
        """
        self.env = env
        self.cache_file = cache_file
        self.cache: Dict[str, Dict[str, Dict[str, float]]] = {}

        try:
            with open(cache_file, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            return

        # Valid JSON is not necessarily an object ("[]", "null", "42", ...).
        # Anything else would make get() and update() blow up later, so
        # treat it like a corrupted file and start with an empty cache.
        if isinstance(loaded, dict):
            self.cache = loaded

    def get(self, test: str,
            default: Optional[float] = None) -> Optional[float]:
        """
        Return the cached elapsed time of ``test`` for the current image
        protocol and format, or ``default`` if there is no such record.
        """
        by_proto = self.cache.get(test)
        if by_proto is None:
            return default

        by_fmt = by_proto.get(self.env.imgproto)
        if by_fmt is None:
            return default

        return by_fmt.get(self.env.imgfmt, default)

    def update(self, test: str, elapsed: float) -> None:
        """
        Record ``elapsed`` for ``test`` under the current image protocol and
        format.  Only the in-memory cache is changed; see save().
        """
        d = self.cache.setdefault(test, {})
        d.setdefault(self.env.imgproto, {})[self.env.imgfmt] = elapsed

    def save(self) -> None:
        """ Write the in-memory cache to the cache file as JSON """
        with open(self.cache_file, 'w', encoding="utf-8") as f:
            json.dump(self.cache, f)

    def __enter__(self) -> 'LastElapsedTime':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # The cache is saved whether or not the with-block raised.
        self.save()
115
116
class TestResult:
    """ Outcome of running (or trying to run) a single test

    Plain record whose attributes mirror the constructor arguments:

    status       -- result keyword; the runner uses 'pass', 'fail' and
                    'not run'
    description  -- short human-readable explanation, may be empty
    elapsed      -- run time in seconds, None when it was not measured
    diff         -- lines of the diff between reference and actual output
    casenotrun   -- text reported by the test about cases it did not run
    interrupted  -- True when the test was interrupted by the user
    """
    def __init__(self, status: str, description: str = '',
                 elapsed: Optional[float] = None, diff: Sequence[str] = (),
                 casenotrun: str = '', interrupted: bool = False) -> None:
        # What happened, and why
        self.status = status
        self.description = description

        # Measurements and collected output
        self.elapsed = elapsed
        self.diff = diff
        self.casenotrun = casenotrun

        # Tells the caller to stop running further tests
        self.interrupted = interrupted
127
128
129class TestRunner(ContextManager['TestRunner']):
130    shared_self = None
131
132    @staticmethod
133    def proc_run_test(test: str, test_field_width: int) -> TestResult:
134        # We are in a subprocess, we can't change the runner object!
135        runner = TestRunner.shared_self
136        assert runner is not None
137        return runner.run_test(test, test_field_width, mp=True)
138
139    def run_tests_pool(self, tests: List[str],
140                       test_field_width: int, jobs: int) -> List[TestResult]:
141
142        # passing self directly to Pool.starmap() just doesn't work, because
143        # it's a context manager.
144        assert TestRunner.shared_self is None
145        TestRunner.shared_self = self
146
147        with Pool(jobs) as p:
148            results = p.starmap(self.proc_run_test,
149                                zip(tests, [test_field_width] * len(tests)))
150
151        TestRunner.shared_self = None
152
153        return results
154
155    def __init__(self, env: TestEnv, makecheck: bool = False,
156                 color: str = 'auto') -> None:
157        self.env = env
158        self.makecheck = makecheck
159        self.last_elapsed = LastElapsedTime('.last-elapsed-cache', env)
160
161        assert color in ('auto', 'on', 'off')
162        self.color = (color == 'on') or (color == 'auto' and
163                                         sys.stdout.isatty())
164
165        self._stack: contextlib.ExitStack
166
167    def __enter__(self) -> 'TestRunner':
168        self._stack = contextlib.ExitStack()
169        self._stack.enter_context(self.env)
170        self._stack.enter_context(self.last_elapsed)
171        self._stack.enter_context(savetty())
172        return self
173
174    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
175        self._stack.close()
176
177    def test_print_one_line(self, test: str,
178                            test_field_width: int,
179                            starttime: str,
180                            endtime: Optional[str] = None, status: str = '...',
181                            lasttime: Optional[float] = None,
182                            thistime: Optional[float] = None,
183                            description: str = '',
184                            end: str = '\n') -> None:
185        """ Print short test info before/after test run """
186        test = os.path.basename(test)
187
188        if self.makecheck and status != '...':
189            if status and status != 'pass':
190                status = f' [{status}]'
191            else:
192                status = ''
193
194            print(f'  TEST   iotest-{self.env.imgfmt}: {test}{status}')
195            return
196
197        if lasttime:
198            lasttime_s = f' (last: {lasttime:.1f}s)'
199        else:
200            lasttime_s = ''
201        if thistime:
202            thistime_s = f'{thistime:.1f}s'
203        else:
204            thistime_s = '...'
205
206        if endtime:
207            endtime = f'[{endtime}]'
208        else:
209            endtime = ''
210
211        if self.color:
212            if status == 'pass':
213                col = '\033[32m'
214            elif status == 'fail':
215                col = '\033[1m\033[31m'
216            elif status == 'not run':
217                col = '\033[33m'
218            else:
219                col = ''
220
221            col_end = '\033[0m'
222        else:
223            col = ''
224            col_end = ''
225
226        print(f'{test:{test_field_width}} {col}{status:10}{col_end} '
227              f'[{starttime}] {endtime:13}{thistime_s:5} {lasttime_s:14} '
228              f'{description}', end=end)
229
230    def find_reference(self, test: str) -> str:
231        if self.env.cachemode == 'none':
232            ref = f'{test}.out.nocache'
233            if os.path.isfile(ref):
234                return ref
235
236        ref = f'{test}.out.{self.env.imgfmt}'
237        if os.path.isfile(ref):
238            return ref
239
240        ref = f'{test}.{self.env.qemu_default_machine}.out'
241        if os.path.isfile(ref):
242            return ref
243
244        return f'{test}.out'
245
246    def do_run_test(self, test: str, mp: bool) -> TestResult:
247        """
248        Run one test
249
250        :param test: test file path
251        :param mp: if true, we are in a multiprocessing environment, use
252                   personal subdirectories for test run
253
254        Note: this method may be called from subprocess, so it does not
255        change ``self`` object in any way!
256        """
257
258        f_test = Path(test)
259        f_bad = Path(f_test.name + '.out.bad')
260        f_notrun = Path(f_test.name + '.notrun')
261        f_casenotrun = Path(f_test.name + '.casenotrun')
262        f_reference = Path(self.find_reference(test))
263
264        if not f_test.exists():
265            return TestResult(status='fail',
266                              description=f'No such test file: {f_test}')
267
268        if not os.access(str(f_test), os.X_OK):
269            sys.exit(f'Not executable: {f_test}')
270
271        if not f_reference.exists():
272            return TestResult(status='not run',
273                              description='No qualified output '
274                                          f'(expected {f_reference})')
275
276        for p in (f_bad, f_notrun, f_casenotrun):
277            silent_unlink(p)
278
279        args = [str(f_test.resolve())]
280        env = self.env.prepare_subprocess(args)
281        if mp:
282            # Split test directories, so that tests running in parallel don't
283            # break each other.
284            for d in ['TEST_DIR', 'SOCK_DIR']:
285                env[d] = os.path.join(env[d], f_test.name)
286                Path(env[d]).mkdir(parents=True, exist_ok=True)
287
288        t0 = time.time()
289        with f_bad.open('w', encoding="utf-8") as f:
290            with subprocess.Popen(args, cwd=str(f_test.parent), env=env,
291                                  stdout=f, stderr=subprocess.STDOUT) as proc:
292                try:
293                    proc.wait()
294                except KeyboardInterrupt:
295                    proc.terminate()
296                    proc.wait()
297                    return TestResult(status='not run',
298                                      description='Interrupted by user',
299                                      interrupted=True)
300                ret = proc.returncode
301
302        elapsed = round(time.time() - t0, 1)
303
304        if ret != 0:
305            return TestResult(status='fail', elapsed=elapsed,
306                              description=f'failed, exit status {ret}',
307                              diff=file_diff(str(f_reference), str(f_bad)))
308
309        if f_notrun.exists():
310            return TestResult(
311                status='not run',
312                description=f_notrun.read_text(encoding='utf-8').strip())
313
314        casenotrun = ''
315        if f_casenotrun.exists():
316            casenotrun = f_casenotrun.read_text(encoding='utf-8')
317
318        diff = file_diff(str(f_reference), str(f_bad))
319        if diff:
320            return TestResult(status='fail', elapsed=elapsed,
321                              description=f'output mismatch (see {f_bad})',
322                              diff=diff, casenotrun=casenotrun)
323        else:
324            f_bad.unlink()
325            return TestResult(status='pass', elapsed=elapsed,
326                              casenotrun=casenotrun)
327
328    def run_test(self, test: str,
329                 test_field_width: int,
330                 mp: bool = False) -> TestResult:
331        """
332        Run one test and print short status
333
334        :param test: test file path
335        :param test_field_width: width for first field of status format
336        :param mp: if true, we are in a multiprocessing environment, don't try
337                   to rewrite things in stdout
338
339        Note: this method may be called from subprocess, so it does not
340        change ``self`` object in any way!
341        """
342
343        last_el = self.last_elapsed.get(test)
344        start = datetime.datetime.now().strftime('%H:%M:%S')
345
346        if not self.makecheck:
347            self.test_print_one_line(test=test,
348                                     test_field_width=test_field_width,
349                                     status = 'started' if mp else '...',
350                                     starttime=start,
351                                     lasttime=last_el,
352                                     end = '\n' if mp else '\r')
353
354        res = self.do_run_test(test, mp)
355
356        end = datetime.datetime.now().strftime('%H:%M:%S')
357        self.test_print_one_line(test=test,
358                                 test_field_width=test_field_width,
359                                 status=res.status,
360                                 starttime=start, endtime=end,
361                                 lasttime=last_el, thistime=res.elapsed,
362                                 description=res.description)
363
364        if res.casenotrun:
365            print(res.casenotrun)
366
367        return res
368
369    def run_tests(self, tests: List[str], jobs: int = 1) -> bool:
370        n_run = 0
371        failed = []
372        notrun = []
373        casenotrun = []
374
375        if not self.makecheck:
376            self.env.print_env()
377
378        test_field_width = max(len(os.path.basename(t)) for t in tests) + 2
379
380        if jobs > 1:
381            results = self.run_tests_pool(tests, test_field_width, jobs)
382
383        for i, t in enumerate(tests):
384            name = os.path.basename(t)
385
386            if jobs > 1:
387                res = results[i]
388            else:
389                res = self.run_test(t, test_field_width)
390
391            assert res.status in ('pass', 'fail', 'not run')
392
393            if res.casenotrun:
394                casenotrun.append(t)
395
396            if res.status != 'not run':
397                n_run += 1
398
399            if res.status == 'fail':
400                failed.append(name)
401                if self.makecheck:
402                    self.env.print_env()
403                if res.diff:
404                    print('\n'.join(res.diff))
405            elif res.status == 'not run':
406                notrun.append(name)
407            elif res.status == 'pass':
408                assert res.elapsed is not None
409                self.last_elapsed.update(t, res.elapsed)
410
411            sys.stdout.flush()
412            if res.interrupted:
413                break
414
415        if notrun:
416            print('Not run:', ' '.join(notrun))
417
418        if casenotrun:
419            print('Some cases not run in:', ' '.join(casenotrun))
420
421        if failed:
422            print('Failures:', ' '.join(failed))
423            print(f'Failed {len(failed)} of {n_run} iotests')
424            return False
425        else:
426            print(f'Passed all {n_run} iotests')
427            return True
428