1#
2# Copyright (C) 2013 Intel Corporation
3#
4# SPDX-License-Identifier: MIT
5#
6
# Custom decorators and helpers for unittest-based tests: test case ids
# (testcase), result logging (LogResults), SIGALRM-based timeouts (timeout,
# timeout_handler) and attribute tags (tag, gettag, getAllTags).
10
11import os
12import logging
13import sys
14import unittest
15import threading
16import signal
17from functools import wraps
18
class testcase(object):
    """Decorator that attaches an external test-case id to a test method.

    Usage::

        @testcase(1234)
        def test_something(self):
            ...

    The returned wrapper forwards every call unchanged and exposes the id as
    its ``test_case`` attribute (LogResults reports that id instead of the
    method name when it is present).
    """

    def __init__(self, test_case):
        self.test_case = test_case

    def __call__(self, func):
        case_id = self.test_case

        @wraps(func)
        def wrapped_f(*args, **kwargs):
            return func(*args, **kwargs)

        # Explicit copy kept on purpose: unlike functools.wraps it fails loudly
        # for callables that have no __name__.
        wrapped_f.__name__ = func.__name__
        wrapped_f.test_case = case_id
        return wrapped_f
30
class NoParsingFilter(logging.Filter):
    """Logging filter that passes only records at level 100.

    Level 100 is the custom 'RESULTS' level that LogResults registers; every
    record at any other level is dropped.
    """

    def filter(self, record):
        results_level = 100
        return results_level == record.levelno
34
35import inspect
36
def LogResults(original_class):
    """Class decorator that logs the outcome of every test to a results file.

    ``original_class.run`` is replaced by a wrapper that runs the test and
    then looks it up in the unittest result object. For each matching entry
    it writes a line at a custom ``RESULTS`` log level (100):

    - ERROR or FAILED, followed by a second line with the message;
    - SKIPPED;
    - PASSED when none of the above apply.

    Lines go to ``results-<script>.<UTC timestamp>.log`` in the current
    directory. The name is fixed when the decorator is applied. After every
    test, ``results-<script>.log`` is re-created as a symlink to that file.
    When the ``bb`` module is importable, the symlink update is serialised
    with ``bb.utils`` lock files.

    :param original_class: ``unittest.TestCase`` subclass; modified in place.
    :returns: ``original_class``.
    """
    orig_method = original_class.run

    from time import strftime, gmtime
    caller = os.path.basename(sys.argv[0])
    timestamp = strftime('%Y%m%d%H%M%S', gmtime())
    logfile = os.path.join(os.getcwd(), 'results-' + caller + '.' + timestamp + '.log')
    linkfile = os.path.join(os.getcwd(), 'results-' + caller + '.log')

    # Custom level that marks result lines; NoParsingFilter only lets records
    # of exactly this level through.
    custom_log_level = 100

    def results(logger, message, *args, **kws):
        """Installed as Logger.results(): log *message* at the RESULTS level."""
        if logger.isEnabledFor(custom_log_level):
            logger.log(custom_log_level, message, *args, **kws)

    def _setup_logging():
        """Register the RESULTS level, configure the root logger, return our logger."""
        logging.addLevelName(custom_log_level, 'RESULTS')
        logging.Logger.results = results
        # basicConfig() does nothing once the root logger already has handlers,
        # so only the first test actually opens *logfile*.
        logging.basicConfig(filename=logfile,
                            filemode='w',
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                            datefmt='%H:%M:%S',
                            level=custom_log_level)
        # This runs once per test: attach the filter only where it is missing
        # instead of stacking a fresh instance onto every handler each time.
        for handler in logging.root.handlers:
            if not any(isinstance(f, NoParsingFilter) for f in handler.filters):
                handler.addFilter(NoParsingFilter())
        return logging.getLogger(caller)

    def _log_outcome(local_log, label, tcid, result):
        """Log the outcome(s) recorded in *result* for the test whose id() is *tcid*."""
        passed = True
        for (test, msg) in result.errors:
            if tcid == test.id():
                local_log.results(label + ": ERROR")
                local_log.results(label + ":\n" + msg)
                passed = False
        for (test, msg) in result.failures:
            if tcid == test.id():
                local_log.results(label + ": FAILED")
                local_log.results(label + ":\n" + msg)
                passed = False
        for (test, msg) in result.skipped:
            if tcid == test.id():
                local_log.results(label + ": SKIPPED")
                passed = False
        if passed:
            local_log.results(label + ": PASSED")

    def _update_symlink():
        """Point *linkfile* at *logfile*, replacing any existing link."""
        # XXX: the exists/remove/symlink sequence is racy between concurrent
        # runs; serialise it with a bb lock when bb is available. A unique
        # link file name would be the better solution.
        try:
            import bb
        except ImportError:
            bb = None

        lock = None
        if bb is not None:
            lock = bb.utils.lockfile(linkfile + '.lock', block=True)
        try:
            if os.path.lexists(linkfile):
                os.remove(linkfile)
            os.symlink(logfile, linkfile)
        finally:
            # Release the lock even if removing/creating the link failed.
            if lock is not None:
                bb.utils.unlockfile(lock)

    # Replacement for unittest.TestCase.run that adds per-test result logging.
    def run(self, result=None, *args, **kws):
        ret = orig_method(self, result, *args, **kws)
        # TestCase.run() creates and returns its own result object when none
        # is passed in; inspect that one in that case.
        if result is None:
            result = ret
        if result is None:
            return ret

        test_method = getattr(self, self._testMethodName)
        # Use the @testcase id when the method is decorated, else its name.
        test_case = getattr(test_method, 'test_case', self._testMethodName)

        local_log = _setup_logging()
        _log_outcome(local_log, "Testcase " + str(test_case), self.id(), result)
        _update_symlink()
        # Preserve TestCase.run()'s contract of returning the result object.
        return ret

    original_class.run = run

    return original_class
132
class TimeOut(BaseException):
    """Raised by the ``timeout`` / ``timeout_handler`` decorators on expiry.

    Because it derives from BaseException rather than Exception, plain
    ``except Exception`` blocks in the code under test do not catch it.
    """
135
def timeout(seconds):
    """Decorator factory: abort the wrapped call with TimeOut after *seconds*.

    Implemented with SIGALRM, so it only takes effect where ``signal.alarm``
    exists; elsewhere the function is returned undecorated.

    The wrapper arms its own alarm, replacing any pending one. When the call
    finishes, it disarms the alarm and restores the previous SIGALRM handler.

    Because ``signal.signal`` is used, the decorated function must be called
    from the main thread.
    """
    def decorator(fn):
        if not hasattr(signal, 'alarm'):
            return fn

        @wraps(fn)
        def wrapped_f(*args, **kw):
            guard_frame = sys._getframe()

            def on_alarm(signum, interrupted_frame):
                # Stay quiet if the alarm lands while this wrapper's own frame
                # is executing (e.g. already in the cleanup below).
                if interrupted_frame is guard_frame:
                    return
                raise TimeOut('%s seconds' % seconds)

            saved_handler = signal.signal(signal.SIGALRM, on_alarm)
            try:
                signal.alarm(seconds)
                return fn(*args, **kw)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, saved_handler)

        return wrapped_f
    return decorator
156
# Prefix of the attribute names under which tag() stores tags.
__tag_prefix = "tag__"


def tag(*args, **kwargs):
    """Decorator that attaches tag attributes to a class or function.

    Tags are intended for the Attribute (-a) plugin:

    - each positional name becomes a tag with value ``True``;
    - each keyword argument becomes a tag with the given value, and a keyword
      overrides a positional name with the same key.

    A tag ``name`` is stored as the attribute ``tag__<name>``. The decorated
    object itself is returned.
    """
    def wrap_ob(ob):
        pairs = [(flag, True) for flag in args]
        pairs.extend(kwargs.items())
        for attr_name, attr_value in pairs:
            setattr(ob, __tag_prefix + attr_name, attr_value)
        return ob
    return wrap_ob
169
def gettag(obj, key, default=None):
    """Return the value of tag *key* on *obj*, or *default* if it is not set.

    For a ``unittest.TestCase`` instance, the test method currently selected
    (``_testMethodName``) is checked first, then the instance and its class.
    A method-level tag therefore overrides a class-level one.
    """
    attr_name = __tag_prefix + key
    if isinstance(obj, unittest.TestCase):
        test_method = getattr(obj, obj._testMethodName)
        class_level = getattr(obj, attr_name, default)
        return getattr(test_method, attr_name, class_level)
    return getattr(obj, attr_name, default)
177
def getAllTags(obj):
    """Return a dict mapping tag name to value for every tag found on *obj*.

    For a ``unittest.TestCase`` instance, the tags of the selected test
    method are merged over those of the instance/class, so method-level
    values win.
    """
    def _tags_of(target):
        prefix_len = len(__tag_prefix)
        found = {}
        for attr_name in dir(target):
            if attr_name.startswith(__tag_prefix):
                found[attr_name[prefix_len:]] = getattr(target, attr_name)
        return found

    if not isinstance(obj, unittest.TestCase):
        return _tags_of(obj)
    test_method = getattr(obj, obj._testMethodName)
    merged = _tags_of(obj)
    merged.update(_tags_of(test_method))
    return merged
188
def timeout_handler(seconds):
    """Decorator factory for methods: on expiry, restart ``self.target``, then raise TimeOut.

    Works like ``timeout`` but wraps a method. When the alarm fires, it first
    attempts ``self.target.restart()`` (the instance must expose ``target``;
    presumably the device under test — confirm with callers) and then raises
    TimeOut. If the restart itself fails with an ordinary exception, that
    exception is chained as the TimeOut's ``__cause__`` rather than swallowed.
    Interrupts such as KeyboardInterrupt raised during the restart propagate
    unchanged.

    Only active where ``signal.alarm`` exists; elsewhere the method is
    returned undecorated. It must be called from the main thread because it
    installs a SIGALRM handler.
    """
    def decorator(fn):
        if not hasattr(signal, 'alarm'):
            return fn

        @wraps(fn)
        def wrapped_f(self, *args, **kw):
            current_frame = sys._getframe()

            def raiseTimeOut(signum, frame):
                # Do not raise if the alarm lands in this wrapper's own frame.
                if frame is current_frame:
                    return
                try:
                    self.target.restart()
                except Exception as err:
                    raise TimeOut('%s seconds' % seconds) from err
                raise TimeOut('%s seconds' % seconds)

            prev_handler = signal.signal(signal.SIGALRM, raiseTimeOut)
            try:
                signal.alarm(seconds)
                return fn(self, *args, **kw)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, prev_handler)
        return wrapped_f
    return decorator
213