xref: /openbmc/openbmc/poky/meta/lib/oeqa/core/utils/concurrencytest.py (revision 2013739591dc50e6d01836d0017e7e5a02225709)
1#!/usr/bin/env python3
2#
3# Copyright OpenEmbedded Contributors
4#
5# SPDX-License-Identifier: GPL-2.0-or-later
6#
7# Modified for use in OE by Richard Purdie, 2018
8#
9# Modified by: Corey Goldberg, 2013
10#   License: GPLv2+
11#
12# Original code from:
13#   Bazaar (bzrlib.tests.__init__.py, v2.6, copied Jun 01 2013)
14#   Copyright (C) 2005-2011 Canonical Ltd
15#   License: GPLv2+
16
17import os
18import sys
19import traceback
20import unittest
21import subprocess
22import testtools
23import threading
24import time
25import io
26import json
27import subunit
28
29from queue import Queue
30from itertools import cycle
31from subunit import ProtocolTestCase, TestProtocolClient
32from subunit.test_results import AutoTimingTestResultDecorator
33from testtools import ThreadsafeForwardingResult, iterate_tests
34from testtools.content import Content
35from testtools.content_type import ContentType
36from oeqa.utils.commands import get_test_layer
37
38import bb.utils
39import oe.path
40
# Public API of this module. The name must be spelt exactly "__all__" for
# Python to honour it in "from ... import *".
__all__ = [
    'ConcurrentTestSuite',
    'fork_for_tests',
    'partition_tests',
]
46
47#
48# Patch the version from testtools to allow access to _test_start and allow
49# computation of timing information and threading progress
50#
class BBThreadsafeForwardingResult(ThreadsafeForwardingResult):
    """ThreadsafeForwardingResult that also records timing and progress.

    One instance forwards the outcomes replayed from one worker process.
    Before each outcome is forwarded it records, on the shared target, the
    test's start time, appends the test id to this worker's progress list
    and stores a one-line progress summary. It then exposes the output
    collected from the worker as the final result's buffered stderr.

    :param target: result that receives forwarded outcomes. It must expose
        ``starttime``, ``threadprogress`` and ``progressinfo`` mappings and a
        ``failed_tests`` counter (directly or via a proxy).
    :param semaphore: semaphore shared by all forwarders of one run.
    :param threadnum: index of this worker; key into ``threadprogress``.
    :param totalinprocess: number of tests assigned to this worker.
    :param totaltests: number of tests across all workers.
    :param output: bytes buffer (``getvalue()`` returns bytes) holding the
        worker's captured output.
    :param finalresult: the result whose ``_stdout_buffer``/``_stderr_buffer``
        are replaced before each forward.
    """

    def __init__(self, target, semaphore, threadnum, totalinprocess, totaltests, output, finalresult):
        super(BBThreadsafeForwardingResult, self).__init__(target, semaphore)
        self.threadnum = threadnum
        self.totalinprocess = totalinprocess
        self.totaltests = totaltests
        self.buffer = True
        self.outputbuf = output
        self.finalresult = finalresult
        self.finalresult.buffer = True
        # Kept under its own name for the failed_tests lookup below; the
        # self.result used below is presumably the same target, stored by the
        # base class.
        self.target = target

    def _add_result_with_semaphore(self, method, test, *args, **kwargs):
        """Record progress for *test*, then forward the outcome via the base class."""
        self.semaphore.acquire()
        try:
            # Only tests with a recorded start time get timing/progress info.
            if self._test_start:
                testid = test.id()
                start = self._test_start.timestamp()
                progress = self.result.threadprogress[self.threadnum]
                self.result.starttime[testid] = start
                progress.append(testid)
                totalprogress = sum(len(x) for x in self.result.threadprogress.values())
                # "<worker>: <done>/<assigned> <done overall>/<total>
                #  (<elapsed>s) (<n> failed) (<test id>)"
                self.result.progressinfo[testid] = "%s: %s/%s %s/%s (%ss) (%s failed) (%s)" % (
                    self.threadnum,
                    len(progress),
                    self.totalinprocess,
                    totalprogress,
                    self.totaltests,
                    "{0:.2f}".format(time.time() - start),
                    self.target.failed_tests,
                    testid)
        finally:
            self.semaphore.release()
        # Expose everything collected in outputbuf so far as the final
        # result's buffered stderr. Decode leniently: this is arbitrary test
        # noise and need not be valid UTF-8, and a strict decode would raise
        # UnicodeDecodeError out of the forwarding call.
        # NOTE(review): these assignments happen outside the semaphore while
        # finalresult is shared between forwarders, so concurrent workers can
        # race on them. The base call below presumably acquires the semaphore
        # itself, so it cannot simply be wrapped in it — confirm before fixing.
        captured = self.outputbuf.getvalue().decode("utf-8", errors="replace")
        self.finalresult._stderr_buffer = io.StringIO(initial_value=captured)
        self.finalresult._stdout_buffer = io.StringIO()
        super(BBThreadsafeForwardingResult, self)._add_result_with_semaphore(method, test, *args, **kwargs)
85
class ProxyTestResult:
    """Minimal TestResult proxy that routes every add* outcome through _addResult.

    Subclasses override ``_addResult`` to rewrite an outcome's arguments
    before it reaches the wrapped result. Attributes not defined on the
    proxy are looked up on the wrapped result.

    Errors and failures are counted in ``failed_tests``. ``wasSuccessful()``
    reports whether that count is still zero; expected failures and
    unexpected successes are not counted.

    :param target: the result object outcomes are forwarded to.
    """

    def __init__(self, target):
        self.result = target
        # Number of addError/addFailure calls seen so far.
        self.failed_tests = 0

    def _addResult(self, method, test, *args, exception = False, **kwargs):
        """Forward one outcome to *method*.

        *exception* is True for outcomes that carry an error (addError,
        addFailure, addExpectedFailure); this base version ignores it.
        """
        return method(test, *args, **kwargs)

    def addError(self, test, err = None, **kwargs):
        self.failed_tests += 1
        self._addResult(self.result.addError, test, err, exception = True, **kwargs)

    def addFailure(self, test, err = None, **kwargs):
        self.failed_tests += 1
        self._addResult(self.result.addFailure, test, err, exception = True, **kwargs)

    def addSuccess(self, test, **kwargs):
        self._addResult(self.result.addSuccess, test, **kwargs)

    def addExpectedFailure(self, test, err = None, **kwargs):
        self._addResult(self.result.addExpectedFailure, test, err, exception = True, **kwargs)

    def addUnexpectedSuccess(self, test, **kwargs):
        self._addResult(self.result.addUnexpectedSuccess, test, **kwargs)

    def wasSuccessful(self):
        return self.failed_tests == 0

    def __getattr__(self, attr):
        # Only called for names not found on the proxy itself. Fetch the
        # wrapped result straight from the instance dict. If __init__ never
        # ran (e.g. an instance created via __new__ by copy/pickle), reading
        # self.result would re-enter __getattr__ and recurse until
        # RecursionError instead of raising AttributeError.
        try:
            result = self.__dict__["result"]
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(result, attr)
117
class ExtraResultsDecoderTestResult(ProxyTestResult):
    """Proxy result that decodes a JSON-encoded "extraresults" detail.

    If an outcome's ``details`` contain an "extraresults" entry that is
    still a testtools ``Content`` object, its byte chunks are joined,
    decoded and parsed as JSON. The wrapped result receives the parsed
    object in place of the Content. The caller's kwargs and details
    mappings are copied, not modified. Every other outcome is forwarded
    unchanged.
    """

    def _addResult(self, method, test, *args, exception = False, **kwargs):
        if "details" not in kwargs or "extraresults" not in kwargs["details"]:
            return method(test, *args, **kwargs)

        encoded = kwargs["details"]["extraresults"]
        if not isinstance(encoded, Content):
            return method(test, *args, **kwargs)

        payload = b"".join(encoded.iter_bytes())
        details = kwargs["details"].copy()
        details["extraresults"] = json.loads(payload.decode())
        return method(test, *args, **dict(kwargs, details=details))
131
class ExtraResultsEncoderTestResult(ProxyTestResult):
    """Proxy result that ships a test's ``extraresults`` as a JSON detail.

    If the test object has an ``extraresults`` attribute, it is attached to
    the outcome's details under "extraresults" as an ``application/json``
    ``Content``. ``json.dumps`` runs whenever Content invokes the supplied
    callable.

    When details are in use and an error-type outcome also carries an
    exception positionally, that exception is moved into the details as a
    "traceback" entry, because testtools does not handle ``err`` and
    ``details`` together. The caller's details mapping is copied, never
    modified.
    """

    def _addResult(self, method, test, *args, exception = False, **kwargs):
        details = kwargs.get("details")
        if details is not None:
            # The details mapping belongs to the caller; work on a copy so the
            # entries added below never leak back into it.
            details = details.copy()

        if hasattr(test, "extraresults"):
            def extras():
                return [json.dumps(test.extraresults).encode()]
            if details is None:
                details = {}
            details["extraresults"] = Content(ContentType("application", "json", {'charset': 'utf8'}), extras)

        if details is not None:
            kwargs["details"] = details
            # if using details, need to encode any exceptions into the details obj,
            # testtools does not handle "err" and "details" together.
            if exception and args and args[0] is not None:
                details["traceback"] = testtools.content.TracebackContent(args[0], test)
                args = ()
        return method(test, *args, **kwargs)
148
149#
150# We have to patch subunit since it doesn't understand how to handle addError
151# outside of a running test case. This can happen if classSetUp() fails
152# for a class of tests. This unfortunately has horrible internal knowledge.
153#
def outSideTestaddError(self, offset, line):
    """Handle an 'error:' directive read while no test is running.

    This happens when a setUpClass() fails for a class of tests. The
    directive names the test; the last byte of *line* (presumably its
    newline) is dropped. A RemotedTestCase is created for that test and
    recorded as the parser's current test. The parser is switched into its
    error-details state, that state is put into simple mode, and the line
    is re-fed to the parser.
    """
    name = line[offset:-1].decode('utf8')
    parser = self.parser
    parser._current_test = subunit.RemotedTestCase(name)
    parser.current_test_description = name
    parser._state = parser._reading_error_details
    parser._reading_error_details.set_simple()
    parser.subunitLineReceived(line)


subunit._OutSideTest.addError = outSideTestaddError
164
165# Like outSideTestaddError above, we need an equivalent for skips
166# happening at the setUpClass() level, otherwise we will see "UNKNOWN"
167# as a result for concurrent tests
168#
def outSideTestaddSkip(self, offset, line):
    """Handle a 'skip:' directive read while no test is running.

    This mirrors outSideTestaddError for skips raised at setUpClass() level,
    which would otherwise be reported as "UNKNOWN". The named test (*line*
    minus its last byte, presumably the newline) becomes the parser's
    current test. The parser is switched into its skip-details state in
    simple mode, and the line is re-fed to the parser.
    """
    name = line[offset:-1].decode('utf8')
    parser = self.parser
    parser._current_test = subunit.RemotedTestCase(name)
    parser.current_test_description = name
    parser._state = parser._reading_skip_details
    parser._reading_skip_details.set_simple()
    parser.subunitLineReceived(line)


subunit._OutSideTest.addSkip = outSideTestaddSkip
179
180#
181# A dummy structure to add to io.StringIO so that the .buffer object
182# is available and accepts writes. This allows unittest with buffer=True
183# to interact ok with subunit which wants to access sys.stdout.buffer.
184#
class dummybuf:
    """Bytes-accepting stand-in for a text stream's ``.buffer`` attribute.

    Each write is decoded as UTF-8 and passed on to the wrapped text stream.
    This lets code that writes bytes to ``sys.stdout.buffer`` work when
    sys.stdout is an io.StringIO.
    """

    def __init__(self, parent):
        # Text stream that receives the decoded data.
        self.p = parent

    def write(self, data):
        text = data.decode("utf-8")
        self.p.write(text)
190
191#
# Taken from testtools.ConcurrentTestSuite but modified for OE use
193#
class ConcurrentTestSuite(unittest.TestSuite):
    """TestSuite whose run() spreads its tests over forked worker processes.

    The wrapped suite is split by fork_for_tests() into at most *processes*
    workers. One reader thread per worker replays that worker's subunit
    stream into the caller's result.

    :param suite: the tests to run.
    :param processes: maximum number of worker processes.
    :param setupfunc: called in each worker as
        ``setupfunc(suffix, selftestdir, process_suite)``; must return
        ``(builddir, newbuilddir)``.
    :param removefunc: called in a worker with ``newbuilddir`` to clean it up.
    :param bb_vars: BitBake variables; ``bb_vars['BBLAYERS']`` is used to
        locate the test layer.
    """

    def __init__(self, suite, processes, setupfunc, removefunc, bb_vars):
        super().__init__([suite])
        self.processes = processes
        self.setupfunc = setupfunc
        self.removefunc = removefunc
        self.bb_vars = bb_vars

    def run(self, result):
        """Run every test in worker processes, forwarding outcomes to *result*.

        *result* is expected to already provide ``starttime`` and
        ``progressinfo`` mappings; ``threadprogress`` is (re)created here.
        """
        testservers, totaltests = fork_for_tests(self.processes, self)
        try:
            readers = {}
            finished = Queue()
            shared_semaphore = threading.Semaphore(1)
            result.threadprogress = {}
            for threadnum, (server, server_tests, output) in enumerate(testservers):
                result.threadprogress[threadnum] = []
                forwarder = BBThreadsafeForwardingResult(
                    ExtraResultsDecoderTestResult(result),
                    shared_semaphore, threadnum, server_tests, totaltests, output, result)
                reader = threading.Thread(
                    target=self._run_test, args=(server, forwarder, finished))
                readers[server] = (reader, forwarder)
                reader.start()
            # Each reader puts its server on the queue when done; join them
            # in completion order.
            while readers:
                done = finished.get()
                readers[done][0].join()
                del readers[done]
        except BaseException:
            # Interrupted (KeyboardInterrupt included): stop every forwarder
            # whose reader has not been joined yet, then propagate.
            for _reader, forwarder in readers.values():
                forwarder.stop()
            raise
        finally:
            for server, _count, _output in testservers:
                server._stream.close()

    def _run_test(self, testserver, process_result, queue):
        """Reader-thread body: replay one worker's stream into *process_result*.

        If the replay itself raises, the failure is reported against a
        synthetic "broken-runner" test. *testserver* is always put on
        *queue* afterwards so run() can join this thread.
        """
        try:
            testserver.run(process_result)
        except Exception:
            # The run logic itself failed
            broken = testtools.ErrorHolder("broken-runner", error=sys.exc_info())
            broken.run(process_result)
        finally:
            queue.put(testserver)
243
def _run_forked_child(suite, process_suite, selftestdir, c2pread, c2pwrite):
    """Run *process_suite* inside a forked worker, then exit; never returns.

    Results go out as a subunit stream on the *c2pwrite* pipe, and the
    worker's stdout/stderr are redirected onto the same pipe. On any error
    the traceback is written to the pipe (or to stderr if that fails) and
    the worker exits with status 1.

    The exit sits in an outer ``finally`` so that no exception — including
    one raised by the cleanup code or by flushing a broken pipe — can
    return control into the parent's code path inside the child.
    """
    ourpid = os.getpid()
    exitcode = 1
    stream = None
    newbuilddir = None
    try:
        try:
            stream = os.fdopen(c2pwrite, 'wb')
            os.close(c2pread)

            _, newbuilddir = suite.setupfunc("-st-" + str(ourpid), selftestdir, process_suite)

            # Leave stderr and stdout open so we can see test noise.
            # Point stdin at /dev/null so that a child that decides to read
            # from stdin goes away; otherwise it's a roulette to see which
            # child actually gets keystrokes for pdb etc.
            newsi = os.open(os.devnull, os.O_RDWR)
            os.dup2(newsi, sys.stdin.fileno())

            # Send stdout/stderr over the stream
            os.dup2(c2pwrite, sys.stdout.fileno())
            os.dup2(c2pwrite, sys.stderr.fileno())

            subunit_client = TestProtocolClient(stream)
            subunit_result = AutoTimingTestResultDecorator(subunit_client)
            unittest_result = process_suite.run(ExtraResultsEncoderTestResult(subunit_result))
            if ourpid != os.getpid():
                # A test forked and this is its child coming back here.
                os._exit(0)
            if newbuilddir and unittest_result.wasSuccessful():
                suite.removefunc(newbuilddir)
            exitcode = 0
        except BaseException:
            # Don't do anything with process children
            if ourpid != os.getpid():
                os._exit(1)
            # Try and report the traceback on the stream, falling back to
            # stderr if the stream is unusable. The traceback is formatted to
            # a string and written in one go to avoid interleaving lines from
            # multiple failing children.
            try:
                stream.write(traceback.format_exc().encode('utf-8'))
            except BaseException:
                sys.stderr.write(traceback.format_exc())
            finally:
                if newbuilddir:
                    suite.removefunc(newbuilddir)
    finally:
        # Best effort: push out anything still buffered (the subunit stream
        # and print() output held in sys.stdout/sys.stderr), since os._exit()
        # skips the interpreter's normal flushing. Errors are deliberately
        # ignored here: the child must exit no matter what.
        for f in (stream, sys.stdout, sys.stderr):
            try:
                if f is not None:
                    f.flush()
            except BaseException:
                pass
        os._exit(exitcode)


def fork_for_tests(concurrency_num, suite):
    """Fork one worker process per test partition and return readers for them.

    *suite* must provide ``setupfunc`` and ``removefunc``, plus ``bb_vars``
    when BUILDDIR is set (as ConcurrentTestSuite does). Its tests are split
    with partition_tests() and then cleared from *suite*, so that only the
    per-worker suites keep them alive. Each worker runs in a child process
    (see _run_forked_child); the parent keeps the read end of one pipe per
    worker.

    :returns: ``(testservers, totaltests)``.
        *testservers* is a list of ``(ProtocolTestCase, numtests,
        io.BytesIO)`` tuples: the subunit reader for a worker, the number of
        tests given to it, and a buffer collecting its non-subunit output.
        *totaltests* is the number of tests across all workers.
    """
    testservers = []
    # Without BUILDDIR no test layer is looked up and setupfunc receives
    # None, instead of every worker dying with a NameError.
    selftestdir = None
    if 'BUILDDIR' in os.environ:
        selftestdir = get_test_layer(suite.bb_vars['BBLAYERS'])

    test_blocks = partition_tests(suite, concurrency_num)
    # Clear the tests from the original suite so it doesn't keep them alive
    suite._tests[:] = []
    totaltests = sum(len(x) for x in test_blocks)
    for process_tests in test_blocks:
        numtests = len(process_tests)
        process_suite = unittest.TestSuite(process_tests)
        # Also clear each split list so new suite has only reference
        process_tests[:] = []
        c2pread, c2pwrite = os.pipe()
        # Clear buffers before fork to avoid duplicate output
        sys.stdout.flush()
        sys.stderr.flush()
        pid = os.fork()
        if pid == 0:
            _run_forked_child(suite, process_suite, selftestdir, c2pread, c2pwrite)
        else:
            os.close(c2pwrite)
            stream = os.fdopen(c2pread, 'rb')
            # Collect stdout/stderr into an io buffer
            output = io.BytesIO()
            testserver = ProtocolTestCase(stream, passthrough=output)
            testservers.append((testserver, numtests, output))
    return testservers, totaltests
318
def partition_tests(suite, count):
    """Split the tests in *suite* into at most *count* lists, one per worker.

    Tests are grouped by "<module>.<class>", so every test of a class lands
    in the same partition. Different classes, even from the same module,
    can go to different processes to aid parallelisation. Groups are dealt
    round-robin in first-seen order, and empty partitions are dropped.

    A *count* below 1 is treated as 1. Otherwise ``cycle([])`` would yield
    nothing and every test would be silently dropped.

    :returns: a list of non-empty lists of test cases.
    """
    groups = {}
    for test in iterate_tests(suite):
        key = test.__module__ + "." + test.__class__.__name__
        groups.setdefault(key, []).append(test)

    # Simply divide the test blocks between the available processes
    partitions = [[] for _ in range(max(count, 1))]
    for partition, tests in zip(cycle(partitions), groups.values()):
        partition.extend(tests)

    # No point in empty threads so drop them
    return [p for p in partitions if p]
336
337