1#!/usr/bin/env python3
2#
3# SPDX-License-Identifier: GPL-2.0-or-later
4#
5# Modified for use in OE by Richard Purdie, 2018
6#
7# Modified by: Corey Goldberg, 2013
8#   License: GPLv2+
9#
10# Original code from:
11#   Bazaar (bzrlib.tests.__init__.py, v2.6, copied Jun 01 2013)
12#   Copyright (C) 2005-2011 Canonical Ltd
13#   License: GPLv2+
14
15import os
16import sys
17import traceback
18import unittest
19import subprocess
20import testtools
21import threading
22import time
23import io
24import json
25import subunit
26
27from queue import Queue
28from itertools import cycle
29from subunit import ProtocolTestCase, TestProtocolClient
30from subunit.test_results import AutoTimingTestResultDecorator
31from testtools import ThreadsafeForwardingResult, iterate_tests
32from testtools.content import Content
33from testtools.content_type import ContentType
34from oeqa.utils.commands import get_test_layer
35
36import bb.utils
37import oe.path
38
# Public names of this module. The variable must be spelled "__all__"
# (two leading underscores) for Python to treat it as the export list.
__all__ = [
    'ConcurrentTestSuite',
    'fork_for_tests',
    'partition_tests',
]
44
45#
46# Patch the version from testtools to allow access to _test_start and allow
47# computation of timing information and threading progress
48#
class BBThreadsafeForwardingResult(ThreadsafeForwardingResult):
    """ThreadsafeForwardingResult that also records timing and progress.

    Before each outcome is handed to the testtools implementation, the
    test's start time and a progress string are stored on the target
    result (in its starttime, threadprogress and progressinfo mappings).
    """

    def __init__(self, target, semaphore, threadnum, totalinprocess, totaltests):
        super().__init__(target, semaphore)
        # Index of the worker this result belongs to
        self.threadnum = threadnum
        # Number of tests assigned to this worker
        self.totalinprocess = totalinprocess
        # Number of tests across all workers
        self.totaltests = totaltests

    def _add_result_with_semaphore(self, method, test, *args, **kwargs):
        # The bookkeeping touches state shared between all workers, so it is
        # done while holding the same semaphore the base class uses.
        with self.semaphore:
            started = self._test_start
            if started:
                test_id = test.id()
                shared = self.result
                shared.starttime[test_id] = started.timestamp()
                shared.threadprogress[self.threadnum].append(test_id)
                done_everywhere = sum(len(ids) for ids in shared.threadprogress.values())
                done_here = len(shared.threadprogress[self.threadnum])
                elapsed = format(time.time() - started.timestamp(), ".2f")
                shared.progressinfo[test_id] = "%s: %s/%s %s/%s (%ss) (%s)" % (
                    self.threadnum,
                    done_here,
                    self.totalinprocess,
                    done_everywhere,
                    self.totaltests,
                    elapsed,
                    test_id)
        # The semaphore is released again before delegating, as the base
        # implementation is called outside the locked section.
        super()._add_result_with_semaphore(method, test, *args, **kwargs)
75
class ProxyTestResult:
    """A very basic TestResult proxy, in order to modify add* calls.

    The add* methods are routed through _addResult() so that subclasses can
    rewrite the arguments before they reach the wrapped result. Any other
    attribute is looked up on the wrapped result.
    """

    def __init__(self, target):
        # The wrapped result object that finally receives every call
        self.result = target
        # Number of addError()/addFailure() calls seen by this proxy
        self.failed_tests = 0

    def _addResult(self, method, test, *args, exception=False, **kwargs):
        """Forward one outcome to method; the hook subclasses override.

        exception is True when the first positional argument (if any) is
        the error information passed to an add* call. It is unused here.
        """
        return method(test, *args, **kwargs)

    def addError(self, test, err=None, **kwargs):
        self.failed_tests += 1
        self._addResult(self.result.addError, test, err, exception=True, **kwargs)

    def addFailure(self, test, err=None, **kwargs):
        self.failed_tests += 1
        self._addResult(self.result.addFailure, test, err, exception=True, **kwargs)

    def addSuccess(self, test, **kwargs):
        self._addResult(self.result.addSuccess, test, **kwargs)

    def addExpectedFailure(self, test, err=None, **kwargs):
        self._addResult(self.result.addExpectedFailure, test, err, exception=True, **kwargs)

    def addUnexpectedSuccess(self, test, **kwargs):
        self._addResult(self.result.addUnexpectedSuccess, test, **kwargs)

    def wasSuccessful(self):
        """Return True if no error or failure went through this proxy."""
        return self.failed_tests == 0

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If "result" itself
        # is missing (an instance on which __init__ has not run), reading
        # self.result below would re-enter this method without end, so
        # report the attribute as missing instead.
        if attr == "result":
            raise AttributeError(attr)
        return getattr(self.result, attr)
107
class ExtraResultsDecoderTestResult(ProxyTestResult):
    """Proxy that decodes an "extraresults" detail from JSON content.

    When the details of an outcome carry an "extraresults" entry that is a
    testtools Content object, its bytes are parsed as JSON and the parsed
    value is passed on in its place.
    """

    def _addResult(self, method, test, *args, exception=False, **kwargs):
        if "details" in kwargs:
            details = kwargs["details"]
            if "extraresults" in details and isinstance(details["extraresults"], Content):
                # Work on copies so the caller's dictionaries stay untouched
                details = details.copy()
                raw = b"".join(details["extraresults"].iter_bytes())
                details["extraresults"] = json.loads(raw.decode())
                kwargs = dict(kwargs, details=details)
        return method(test, *args, **kwargs)
121
class ExtraResultsEncoderTestResult(ProxyTestResult):
    """Proxy that attaches a test's extraresults as a JSON content detail."""

    def _addResult(self, method, test, *args, exception=False, **kwargs):
        if hasattr(test, "extraresults"):
            def encoded_extras():
                # Evaluated when the content is read, not when it is created
                return [json.dumps(test.extraresults).encode()]

            # Copy any details passed in rather than modifying them in place
            details = kwargs["details"].copy() if "details" in kwargs else {}
            details["extraresults"] = Content(
                ContentType("application", "json", {'charset': 'utf8'}),
                encoded_extras)
            kwargs = dict(kwargs, details=details)

        # if using details, need to encode any exceptions into the details obj,
        # testtools does not handle "err" and "details" together.
        has_error = exception and len(args) >= 1 and args[0] is not None
        if "details" in kwargs and has_error:
            kwargs["details"]["traceback"] = testtools.content.TracebackContent(args[0], test)
            args = ()
        return method(test, *args, **kwargs)
138
139#
140# We have to patch subunit since it doesn't understand how to handle addError
# outside of a running test case. This can happen if setUpClass() fails
142# for a class of tests. This unfortunately has horrible internal knowledge.
143#
def outSideTestaddError(self, offset, line):
    """Handle an 'error:' directive read while no test is running.

    offset -- index in line where the test name starts.
    line -- the raw directive line (bytes); its last byte is dropped
        before the name is decoded.
    """
    name = line[offset:-1].decode('utf8')
    parser = self.parser
    # Pretend a test with this name is in progress so the parser has
    # something to attach the error to.
    parser._current_test = subunit.RemotedTestCase(name)
    parser.current_test_description = name
    # Switch the parser to its error-details state and replay the line there
    parser._state = parser._reading_error_details
    parser._reading_error_details.set_simple()
    parser.subunitLineReceived(line)

# Install the handler on subunit's "outside of a test" parser state
subunit._OutSideTest.addError = outSideTestaddError
154
155# Like outSideTestaddError above, we need an equivalent for skips
156# happening at the setUpClass() level, otherwise we will see "UNKNOWN"
157# as a result for concurrent tests
158#
def outSideTestaddSkip(self, offset, line):
    """Handle a 'skip:' directive read while no test is running.

    offset -- index in line where the test name starts.
    line -- the raw directive line (bytes); its last byte is dropped
        before the name is decoded.
    """
    name = line[offset:-1].decode('utf8')
    parser = self.parser
    # Pretend a test with this name is in progress so the parser has
    # something to attach the skip to.
    parser._current_test = subunit.RemotedTestCase(name)
    parser.current_test_description = name
    # Switch the parser to its skip-details state and replay the line there
    parser._state = parser._reading_skip_details
    parser._reading_skip_details.set_simple()
    parser.subunitLineReceived(line)

# Install the handler on subunit's "outside of a test" parser state
subunit._OutSideTest.addSkip = outSideTestaddSkip
169
170#
171# A dummy structure to add to io.StringIO so that the .buffer object
172# is available and accepts writes. This allows unittest with buffer=True
173# to interact ok with subunit which wants to access sys.stdout.buffer.
174#
class dummybuf:
    """Minimal stand-in for the .buffer attribute of a text stream.

    Bytes written here are decoded as UTF-8 and written to the parent
    text stream as a string.
    """

    def __init__(self, parent):
        # The text stream that receives the decoded data
        self.p = parent

    def write(self, data):
        text = data.decode("utf-8")
        self.p.write(text)
180
181#
# Taken from testtools.ConcurrentTestSuite but modified for OE use
183#
class ConcurrentTestSuite(unittest.TestSuite):
    """Suite whose tests are run by forked child processes.

    run() splits the wrapped suite with fork_for_tests() and starts one
    reader thread per child, each forwarding that child's results to the
    result object passed in.
    """

    def __init__(self, suite, processes, setupfunc, removefunc):
        super().__init__([suite])
        # Stored for fork_for_tests(), which reads setupfunc/removefunc
        # from the suite it is given
        self.processes = processes
        self.setupfunc = setupfunc
        self.removefunc = removefunc

    def run(self, result):
        children, totaltests = fork_for_tests(self.processes, self)
        try:
            active = {}
            finished_queue = Queue()
            lock = threading.Semaphore(1)
            result.threadprogress = {}
            for index, (child, child_count) in enumerate(children):
                result.threadprogress[index] = []
                forwarder = BBThreadsafeForwardingResult(
                        ExtraResultsDecoderTestResult(result),
                        lock, index, child_count, totaltests)
                # Force buffering of stdout/stderr so the console doesn't get corrupted by test output
                # as per default in parent code
                forwarder.buffer = True
                # We have to add a buffer object to stdout to keep subunit happy
                for name in ("_stderr_buffer", "_stdout_buffer"):
                    capture = io.StringIO()
                    capture.buffer = dummybuf(capture)
                    setattr(forwarder, name, capture)
                worker = threading.Thread(
                    target=self._run_test, args=(child, forwarder, finished_queue))
                active[child] = (worker, forwarder)
                worker.start()
            # Each worker announces itself on the queue once it is done
            while active:
                done = finished_queue.get()
                worker, _ = active[done]
                worker.join()
                del active[done]
        except BaseException:
            # Ask every result that is still active to stop, then re-raise
            for _, forwarder in active.values():
                forwarder.stop()
            raise
        finally:
            for child, _ in children:
                child._stream.close()

    def _run_test(self, test, process_result, queue):
        """Run test against process_result, then put test on queue."""
        try:
            test.run(process_result)
        except Exception:
            # The run logic itself failed
            holder = testtools.ErrorHolder(
                "broken-runner",
                error=sys.exc_info())
            holder.run(process_result)
        finally:
            queue.put(test)
240
def fork_for_tests(concurrency_num, suite):
    """Split suite into groups and fork one child process to run each.

    The tests are divided with partition_tests(). Every group is run in its
    own forked child, which reports its results as a subunit stream written
    to a pipe.

    concurrency_num -- number of groups to split the tests into.
    suite -- the suite to split. Its tests are removed from it. In each
        child suite.setupfunc() is called and returns a
        (builddir, newbuilddir) pair; suite.removefunc(newbuilddir) is
        called afterwards if the run was successful or raised an exception.

    Returns (cases, totaltests): cases is a list of
    (ProtocolTestCase, number_of_tests) pairs, one per child, reading from
    that child's pipe; totaltests is the number of tests over all groups.
    """
    result = []
    # Bind this unconditionally: it is passed to setupfunc in every child and
    # would otherwise be undefined when BUILDDIR is not in the environment.
    selftestdir = None
    if 'BUILDDIR' in os.environ:
        selftestdir = get_test_layer()

    test_blocks = partition_tests(suite, concurrency_num)
    # Clear the tests from the original suite so it doesn't keep them alive
    suite._tests[:] = []
    totaltests = sum(len(x) for x in test_blocks)
    for process_tests in test_blocks:
        numtests = len(process_tests)
        process_suite = unittest.TestSuite(process_tests)
        # Also clear each split list so new suite has only reference
        process_tests[:] = []
        c2pread, c2pwrite = os.pipe()
        # Clear buffers before fork to avoid duplicate output
        sys.stdout.flush()
        sys.stderr.flush()
        pid = os.fork()
        if pid == 0:
            ourpid = os.getpid()
            # Bound before the try block so the error path below can always
            # refer to them, even if the failure happened very early.
            newbuilddir = None
            stream = None
            try:
                # No buffering argument: line buffering (1) is not supported
                # for binary streams and only produces a RuntimeWarning.
                stream = os.fdopen(c2pwrite, 'wb')
                os.close(c2pread)

                _, newbuilddir = suite.setupfunc("-st-" + str(ourpid), selftestdir, process_suite)

                # Leave stderr and stdout open so we can see test noise
                # Close stdin so that the child goes away if it decides to
                # read from stdin (otherwise it's a roulette to see what
                # child actually gets keystrokes for pdb etc).
                newsi = os.open(os.devnull, os.O_RDWR)
                os.dup2(newsi, sys.stdin.fileno())

                subunit_client = TestProtocolClient(stream)
                # Force buffering of stdout/stderr so the console doesn't get corrupted by test output
                # as per default in parent code
                subunit_client.buffer = True
                subunit_result = AutoTimingTestResultDecorator(subunit_client)
                unittest_result = process_suite.run(ExtraResultsEncoderTestResult(subunit_result))
                if ourpid != os.getpid():
                    os._exit(0)
                if newbuilddir and unittest_result.wasSuccessful():
                    suite.removefunc(newbuilddir)
            except:
                # Don't do anything with process children
                if ourpid != os.getpid():
                    os._exit(1)
                # Try and report traceback on stream, but exit with error
                # even if stream couldn't be created or something else
                # goes wrong.  The traceback is formatted to a string and
                # written in one go to avoid interleaving lines from
                # multiple failing children.
                try:
                    try:
                        stream.write(traceback.format_exc().encode('utf-8'))
                        stream.flush()
                    except:
                        sys.stderr.write(traceback.format_exc())
                    if newbuilddir:
                        suite.removefunc(newbuilddir)
                finally:
                    # Whatever happened during reporting or cleanup, the
                    # child must end here and never return into the code
                    # the parent process is running.
                    os._exit(1)
            # Same guarantee on the success path: a failing flush must not
            # let the child carry on as if it were the parent.
            try:
                stream.flush()
            except:
                os._exit(1)
            os._exit(0)
        else:
            os.close(c2pwrite)
            stream = os.fdopen(c2pread, 'rb')
            test = ProtocolTestCase(stream)
            result.append((test, numtests))
    return result, totaltests
312
def partition_tests(suite, count):
    """Divide the tests of suite into at most count lists.

    Tests of the same class (keyed by module and class name) always end up
    in the same list; the classes are handed out to the lists in turn.
    Lists that received no tests are left out of the returned list.
    """
    # Keep tests from the same class together but allow tests from modules
    # to go to different processes to aid parallelisation.
    grouped = {}
    for case in iterate_tests(suite):
        key = ".".join((case.__module__, case.__class__.__name__))
        grouped.setdefault(key, []).append(case)

    # Simply divide the test blocks between the available processes
    buckets = [[] for _ in range(count)]
    for bucket, members in zip(cycle(buckets), grouped.values()):
        bucket.extend(members)

    # No point in empty threads so drop them
    return list(filter(None, buckets))
330
331