xref: /openbmc/openbmc/poky/meta/lib/oeqa/core/loader.py (revision c7f50e73)
1#
2# Copyright (C) 2016 Intel Corporation
3#
4# SPDX-License-Identifier: MIT
5#
6
7import os
8import re
9import sys
10import unittest
11import inspect
12
13from oeqa.core.utils.path import findFile
14from oeqa.core.utils.test import getSuiteModules, getCaseID
15
16from oeqa.core.exception import OEQATestNotFound
17from oeqa.core.case import OETestCase
18from oeqa.core.decorator import decoratorClasses, OETestDecorator, \
19        OETestDiscover
20
# When loading tests, the unittest framework stores any exceptions and
# displays them only when the run method is called.
#
# For our purposes, it is better to raise the exceptions in the loading
# step rather than waiting to run the test suite.
#
# Generate the function definition because its signature differs across
# Python versions: Python >= 3.4.4 uses three parameters instead of four, but
# for example Python 3.5.3 uses four parameters, so the change isn't
# incremental.
#
# The replacement is built with exec() so that it accepts exactly the
# parameter names of the running interpreter's
# unittest.loader._make_failed_test; its body simply raises the 'exception'
# argument. The generated function is then installed in place of the
# unittest one.
_failed_test_args = inspect.getfullargspec(unittest.loader._make_failed_test).args
exec("""def _make_failed_test(%s): raise exception""" % ', '.join(_failed_test_args))
unittest.loader._make_failed_test = _make_failed_test
33
def _find_duplicated_modules(suite, directory):
    """
        Raise ImportError if a module already present in suite also
        exists as a '<module>.py' file under directory.

        suite: test suite whose module names are obtained through
               getSuiteModules().
        directory: directory searched with findFile().
    """
    for name in getSuiteModules(suite):
        location = findFile('%s.py' % name, directory)
        if not location:
            continue
        raise ImportError("Duplicated %s module found in %s" % (name, location))
39
40def _built_modules_dict(modules):
41    modules_dict = {}
42
43    if modules == None:
44        return modules_dict
45
46    for module in modules:
47        # Assumption: package and module names do not contain upper case
48        # characters, whereas class names do
49        m = re.match(r'^([0-9a-z_.]+)(?:\.(\w[^.]*)(?:\.([^.]+))?)?$', module, flags=re.ASCII)
50        if not m:
51            continue
52
53        module_name, class_name, test_name = m.groups()
54
55        if module_name and module_name not in modules_dict:
56            modules_dict[module_name] = {}
57        if class_name and class_name not in modules_dict[module_name]:
58            modules_dict[module_name][class_name] = []
59        if test_name and test_name not in modules_dict[module_name][class_name]:
60            modules_dict[module_name][class_name].append(test_name)
61
62    return modules_dict
63
64class OETestLoader(unittest.TestLoader):
65    caseClass = OETestCase
66
67    kwargs_names = ['testMethodPrefix', 'sortTestMethodUsing', 'suiteClass',
68            '_top_level_dir']
69
70    def __init__(self, tc, module_paths, modules, tests, modules_required,
71            *args, **kwargs):
72        self.tc = tc
73
74        self.modules = _built_modules_dict(modules)
75
76        self.tests = tests
77        self.modules_required = modules_required
78
79        self.tags_filter = kwargs.get("tags_filter", None)
80
81        if isinstance(module_paths, str):
82            module_paths = [module_paths]
83        elif not isinstance(module_paths, list):
84            raise TypeError('module_paths must be a str or a list of str')
85        self.module_paths = module_paths
86
87        for kwname in self.kwargs_names:
88            if kwname in kwargs:
89                setattr(self, kwname, kwargs[kwname])
90
91        self._patchCaseClass(self.caseClass)
92
93        super(OETestLoader, self).__init__()
94
95    def _patchCaseClass(self, testCaseClass):
96        # Adds custom attributes to the OETestCase class
97        setattr(testCaseClass, 'tc', self.tc)
98        setattr(testCaseClass, 'td', self.tc.td)
99        setattr(testCaseClass, 'logger', self.tc.logger)
100
101    def _registerTestCase(self, case):
102        case_id = case.id()
103        self.tc._registry['cases'][case_id] = case
104
105    def _handleTestCaseDecorators(self, case):
106        def _handle(obj):
107            if isinstance(obj, OETestDecorator):
108                if not obj.__class__ in decoratorClasses:
109                    raise Exception("Decorator %s isn't registered" \
110                            " in decoratorClasses." % obj.__name__)
111                obj.bind(self.tc._registry, case)
112
113        def _walk_closure(obj):
114            if hasattr(obj, '__closure__') and obj.__closure__:
115                for f in obj.__closure__:
116                    obj = f.cell_contents
117                    _handle(obj)
118                    _walk_closure(obj)
119        method = getattr(case, case._testMethodName, None)
120        _walk_closure(method)
121
122    def _filterTest(self, case):
123        """
124            Returns True if test case must be filtered, False otherwise.
125        """
126        # XXX; If the module has more than one namespace only use
127        # the first to support run the whole module specifying the
128        # <module_name>.[test_class].[test_name]
129        module_name_small = case.__module__.split('.')[0]
130        module_name = case.__module__
131
132        class_name = case.__class__.__name__
133        test_name = case._testMethodName
134
135        # 'auto' is a reserved key word to run test cases automatically
136        # warn users if their test case belong to a module named 'auto'
137        if module_name_small == "auto":
138            bb.warn("'auto' is a reserved key word for TEST_SUITES. "
139                    "But test case '%s' is detected to belong to auto module. "
140                    "Please condier using a new name for your module." % str(case))
141
142        # check if case belongs to any specified module
143        # if 'auto' is specified, such check is skipped
144        if self.modules and not 'auto' in self.modules:
145            module = None
146            try:
147                module = self.modules[module_name_small]
148            except KeyError:
149                try:
150                    module = self.modules[module_name]
151                except KeyError:
152                    return True
153
154            if module:
155                if not class_name in module:
156                    return True
157
158                if module[class_name]:
159                    if test_name not in module[class_name]:
160                        return True
161
162        # Decorator filters
163        if self.tags_filter is not None and callable(self.tags_filter):
164            alltags = set()
165            # pull tags from the case class
166            if hasattr(case, "__oeqa_testtags"):
167                for t in getattr(case, "__oeqa_testtags"):
168                    alltags.add(t)
169            # pull tags from the method itself
170            if hasattr(case, test_name):
171                method = getattr(case, test_name)
172                if hasattr(method, "__oeqa_testtags"):
173                    for t in getattr(method, "__oeqa_testtags"):
174                        alltags.add(t)
175
176            if self.tags_filter(alltags):
177                return True
178
179        return False
180
181    def _getTestCase(self, testCaseClass, tcName):
182        if not hasattr(testCaseClass, '__oeqa_loader') and \
183                issubclass(testCaseClass, OETestCase):
184            # In order to support data_vars validation
185            # monkey patch the default setUp/tearDown{Class} to use
186            # the ones provided by OETestCase
187            setattr(testCaseClass, 'setUpClassMethod',
188                    getattr(testCaseClass, 'setUpClass'))
189            setattr(testCaseClass, 'tearDownClassMethod',
190                    getattr(testCaseClass, 'tearDownClass'))
191            setattr(testCaseClass, 'setUpClass',
192                    testCaseClass._oeSetUpClass)
193            setattr(testCaseClass, 'tearDownClass',
194                    testCaseClass._oeTearDownClass)
195
196            # In order to support decorators initialization
197            # monkey patch the default setUp/tearDown to use
198            # a setUpDecorators/tearDownDecorators that methods
199            # will call setUp/tearDown original methods.
200            setattr(testCaseClass, 'setUpMethod',
201                    getattr(testCaseClass, 'setUp'))
202            setattr(testCaseClass, 'tearDownMethod',
203                    getattr(testCaseClass, 'tearDown'))
204            setattr(testCaseClass, 'setUp', testCaseClass._oeSetUp)
205            setattr(testCaseClass, 'tearDown', testCaseClass._oeTearDown)
206
207            setattr(testCaseClass, '__oeqa_loader', True)
208
209        case = testCaseClass(tcName)
210        if isinstance(case, OETestCase):
211            setattr(case, 'decorators', [])
212
213        return case
214
215    def loadTestsFromTestCase(self, testCaseClass):
216        """
217            Returns a suite of all tests cases contained in testCaseClass.
218        """
219        if issubclass(testCaseClass, unittest.suite.TestSuite):
220            raise TypeError("Test cases should not be derived from TestSuite." \
221                                " Maybe you meant to derive %s from TestCase?" \
222                                % testCaseClass.__name__)
223        if not issubclass(testCaseClass, unittest.case.TestCase):
224            raise TypeError("Test %s is not derived from %s" % \
225                    (testCaseClass.__name__, unittest.case.TestCase.__name__))
226
227        testCaseNames = self.getTestCaseNames(testCaseClass)
228        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
229            testCaseNames = ['runTest']
230
231        suite = []
232        for tcName in testCaseNames:
233            case = self._getTestCase(testCaseClass, tcName)
234            # Filer by case id
235            if not (self.tests and not 'auto' in self.tests
236                    and not getCaseID(case) in self.tests):
237                self._handleTestCaseDecorators(case)
238
239                # Filter by decorators
240                if not self._filterTest(case):
241                    self._registerTestCase(case)
242                    suite.append(case)
243
244        return self.suiteClass(suite)
245
246    def _required_modules_validation(self):
247        """
248            Search in Test context registry if a required
249            test is found, raise an exception when not found.
250        """
251
252        for module in self.modules_required:
253            found = False
254
255            # The module name is splitted to only compare the
256            # first part of a test case id.
257            comp_len = len(module.split('.'))
258            for case in self.tc._registry['cases']:
259                case_comp = '.'.join(case.split('.')[0:comp_len])
260                if module == case_comp:
261                    found = True
262                    break
263
264            if not found:
265                raise OEQATestNotFound("Not found %s in loaded test cases" % \
266                        module)
267
268    def discover(self):
269        big_suite = self.suiteClass()
270        for path in self.module_paths:
271            _find_duplicated_modules(big_suite, path)
272            suite = super(OETestLoader, self).discover(path,
273                    pattern='*.py', top_level_dir=path)
274            big_suite.addTests(suite)
275
276        cases = None
277        discover_classes = [clss for clss in decoratorClasses
278                            if issubclass(clss, OETestDiscover)]
279        for clss in discover_classes:
280            cases = clss.discover(self.tc._registry)
281
282        if self.modules_required:
283            self._required_modules_validation()
284
285        return self.suiteClass(cases) if cases else big_suite
286
287    def _filterModule(self, module):
288        if module.__name__ in sys.builtin_module_names:
289            msg = 'Tried to import %s test module but is a built-in'
290            raise ImportError(msg % module.__name__)
291
292        # XXX; If the module has more than one namespace only use
293        # the first to support run the whole module specifying the
294        # <module_name>.[test_class].[test_name]
295        module_name_small = module.__name__.split('.')[0]
296        module_name = module.__name__
297
298        # Normal test modules are loaded if no modules were specified,
299        # if module is in the specified module list or if 'auto' is in
300        # module list.
301        # Underscore modules are loaded only if specified in module list.
302        load_module = True if not module_name.startswith('_') \
303                              and (not self.modules \
304                                   or module_name in self.modules \
305                                   or module_name_small in self.modules \
306                                   or 'auto' in self.modules) \
307                           else False
308
309        load_underscore = True if module_name.startswith('_') \
310                                  and (module_name in self.modules or \
311                                  module_name_small in self.modules) \
312                               else False
313
314        return (load_module, load_underscore)
315
316
317    # XXX After Python 3.5, remove backward compatibility hacks for
318    # use_load_tests deprecation via *args and **kws.  See issue 16662.
319    if sys.version_info >= (3,5):
320        def loadTestsFromModule(self, module, *args, pattern=None, **kws):
321            """
322                Returns a suite of all tests cases contained in module.
323            """
324            load_module, load_underscore = self._filterModule(module)
325
326            if load_module or load_underscore:
327                return super(OETestLoader, self).loadTestsFromModule(
328                        module, *args, pattern=pattern, **kws)
329            else:
330                return self.suiteClass()
331    else:
332        def loadTestsFromModule(self, module, use_load_tests=True):
333            """
334                Returns a suite of all tests cases contained in module.
335            """
336            load_module, load_underscore = self._filterModule(module)
337
338            if load_module or load_underscore:
339                return super(OETestLoader, self).loadTestsFromModule(
340                        module, use_load_tests)
341            else:
342                return self.suiteClass()
343