#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

"""
tdc.py - Linux tc (Traffic Control) unit test driver

Copyright (C) 2017 Lucas Bates <lucasb@mojatatu.com>
"""

import re
import os
import sys
import argparse
import importlib
import json
import subprocess
import time
import traceback
from collections import OrderedDict
from string import Template

from tdc_config import *
from tdc_helper import *

import TdcPlugin
from TdcResults import *

class PluginDependencyException(Exception):
    def __init__(self, missing_pg):
        self.missing_pg = missing_pg

class PluginMgrTestFail(Exception):
    def __init__(self, stage, output, message):
        self.stage = stage
        self.output = output
        self.message = message

class PluginMgr:
    def __init__(self, argparser):
        super().__init__()
        self.plugins = {}
        self.plugin_instances = []
        self.failed_plugins = {}
        self.argparser = argparser

        # TODO: put plugins in order
        plugindir = os.getenv('TDC_PLUGIN_DIR', './plugins')
        for dirpath, dirnames, filenames in os.walk(plugindir):
            for fn in filenames:
                if (fn.endswith('.py') and
                    fn != '__init__.py' and
                    not fn.startswith('#') and
                    not fn.startswith('.#')):
                    mn = fn[0:-3]
                    foo = importlib.import_module('plugins.' + mn)
                    self.plugins[mn] = foo
                    self.plugin_instances.append(foo.SubPlugin())

    def load_plugin(self, pgdir, pgname):
        pgname = pgname[0:-3]
        foo = importlib.import_module('{}.{}'.format(pgdir, pgname))
        self.plugins[pgname] = foo
        self.plugin_instances.append(foo.SubPlugin())
        self.plugin_instances[-1].check_args(self.args, None)

    def get_required_plugins(self, testlist):
        '''
        Get all required plugins from the list of test cases and return
        all unique items.
        '''
        reqs = []
        for t in testlist:
            try:
                if 'requires' in t['plugins']:
                    if isinstance(t['plugins']['requires'], list):
                        reqs.extend(t['plugins']['requires'])
                    else:
                        reqs.append(t['plugins']['requires'])
            except KeyError:
                continue
        reqs = get_unique_item(reqs)
        return reqs

    def load_required_plugins(self, reqs, parser, args, remaining):
        '''
        Get all required plugins from the list of test cases and load any
        plugin that is not already enabled.
        '''
        pgd = ['plugin-lib', 'plugin-lib-custom']
        pnf = []

        for r in reqs:
            if r not in self.plugins:
                fname = '{}.py'.format(r)
                source_path = []
                for d in pgd:
                    pgpath = '{}/{}'.format(d, fname)
                    if os.path.isfile(pgpath):
                        source_path.append(pgpath)
                if len(source_path) == 0:
                    print('ERROR: unable to find required plugin {}'.format(r))
                    pnf.append(fname)
                    continue
                elif len(source_path) > 1:
                    print('WARNING: multiple copies of plugin {} found, '
                          'using version found at {}'.format(r, source_path[0]))
                pgdir = source_path[0]
                pgdir = pgdir.split('/')[0]
                self.load_plugin(pgdir, fname)
        if len(pnf) > 0:
            raise PluginDependencyException(pnf)

        parser = self.call_add_args(parser)
        (args, remaining) = parser.parse_known_args(args=remaining, namespace=args)
        return args

    def call_pre_suite(self, testcount, testidlist):
        for pgn_inst in self.plugin_instances:
            pgn_inst.pre_suite(testcount, testidlist)

    def call_post_suite(self, index):
        for pgn_inst in reversed(self.plugin_instances):
            pgn_inst.post_suite(index)

    def call_pre_case(self, caseinfo, *, test_skip=False):
        for pgn_inst in self.plugin_instances:
            try:
                pgn_inst.pre_case(caseinfo, test_skip)
            except Exception as ee:
                print('exception {} in call to pre_case for {} plugin'.
                      format(ee, pgn_inst.__class__))
                print('testid is {}'.format(caseinfo['id']))
                raise

    def call_post_case(self):
        for pgn_inst in reversed(self.plugin_instances):
            pgn_inst.post_case()

    def call_pre_execute(self):
        for pgn_inst in self.plugin_instances:
            pgn_inst.pre_execute()

    def call_post_execute(self):
        for pgn_inst in reversed(self.plugin_instances):
            pgn_inst.post_execute()

    def call_add_args(self, parser):
        for pgn_inst in self.plugin_instances:
            parser = pgn_inst.add_args(parser)
        return parser

    def call_check_args(self, args, remaining):
        for pgn_inst in self.plugin_instances:
            pgn_inst.check_args(args, remaining)

    def call_adjust_command(self, stage, command):
        for pgn_inst in self.plugin_instances:
            command = pgn_inst.adjust_command(stage, command)
        return command

    def set_args(self, args):
        self.args = args

    def _make_argparser(self, args):
        self.argparser = argparse.ArgumentParser(
            description='Linux TC unit tests')

def replace_keywords(cmd):
    """
    For a given executable command, substitute any known
    variables contained within NAMES with the correct values.
    """
    tcmd = Template(cmd)
    subcmd = tcmd.safe_substitute(NAMES)
    return subcmd
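
# A quick illustration of replace_keywords() (comments only; the values
# below are made up for this example, the real keys live in NAMES in
# tdc_config.py): with NAMES containing {'TC': '/sbin/tc', 'DEV1': 'v0p1'},
# the command
#     "$TC qdisc add dev $DEV1 ingress"
# becomes
#     "/sbin/tc qdisc add dev v0p1 ingress"
# Template.safe_substitute() leaves unknown placeholders (e.g. "$NOTSET")
# in place rather than raising KeyError, so partially-templated commands
# still execute.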


def exec_cmd(args, pm, stage, command):
    """
    Perform any required modifications on an executable command, then run
    it in a subprocess and return the results.
    """
    if len(command.strip()) == 0:
        return None, None
    if '$' in command:
        command = replace_keywords(command)

    command = pm.call_adjust_command(stage, command)
    if args.verbose > 0:
        print('command "{}"'.format(command))
    proc = subprocess.Popen(command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=ENVIR)

    try:
        (rawout, serr) = proc.communicate(timeout=NAMES['TIMEOUT'])
        if proc.returncode != 0 and len(serr) > 0:
            foutput = serr.decode("utf-8", errors="ignore")
        else:
            foutput = rawout.decode("utf-8", errors="ignore")
    except subprocess.TimeoutExpired:
        foutput = "Command \"{}\" timed out\n".format(command)
        proc.returncode = 255

    proc.stdout.close()
    proc.stderr.close()
    return proc, foutput
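
# How a command flows through exec_cmd(), sketched for one hypothetical
# setup command (the plugin rewriting described is what the
# adjust_command() hooks are for; e.g. the bundled nsPlugin wraps
# commands so they run inside a network namespace):
#     (proc, output) = exec_cmd(args, pm, 'setup',
#                               '$TC qdisc add dev $DEV1 ingress')
# 1. '$'-placeholders are substituted from NAMES.
# 2. Each plugin may rewrite the command for the given stage.
# 3. The command runs under a shell; callers inspect proc.returncode and
#    the decoded stdout/stderr text returned as the second value.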


def prepare_env(args, pm, stage, prefix, cmdlist, output=None):
    """
    Execute the setup/teardown commands for a test case.
    Optionally terminate test execution if the command fails.
    """
    if args.verbose > 0:
        print('{}'.format(prefix))
    for cmdinfo in cmdlist:
        if isinstance(cmdinfo, list):
            exit_codes = cmdinfo[1:]
            cmd = cmdinfo[0]
        else:
            exit_codes = [0]
            cmd = cmdinfo

        if not cmd:
            continue

        (proc, foutput) = exec_cmd(args, pm, stage, cmd)

        if proc and (proc.returncode not in exit_codes):
            print('', file=sys.stderr)
            print("{} *** Could not execute: \"{}\"".format(prefix, cmd),
                  file=sys.stderr)
            print("\n{} *** Error message: \"{}\"".format(prefix, foutput),
                  file=sys.stderr)
            print("returncode {}; expected {}".format(proc.returncode,
                                                      exit_codes),
                  file=sys.stderr)
            print("\n{} *** Aborting test run.".format(prefix), file=sys.stderr)
            print("\n\n{} *** stdout ***".format(proc.stdout), file=sys.stderr)
            print("\n\n{} *** stderr ***".format(proc.stderr), file=sys.stderr)
            raise PluginMgrTestFail(
                stage, output,
                '"{}" did not complete successfully'.format(prefix))
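
# The setup/teardown lists passed as cmdlist accept two entry forms
# (values below are illustrative, not taken from a real test case):
#     "setup": [
#         "$TC qdisc add dev $DEV1 ingress",
#         ["$TC actions flush action gact", 0, 1, 255]
#     ]
# A plain string must exit 0; the list form names the command first,
# followed by every exit code that should be tolerated.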

def verify_by_json(procout, res, tidx, args, pm):
    try:
        outputJSON = json.loads(procout)
    except json.JSONDecodeError:
        res.set_result(ResultState.fail)
        res.set_failmsg('Cannot decode verify command\'s output. Is it JSON?')
        return res

    matchJSON = json.loads(json.dumps(tidx['matchJSON']))

    if type(outputJSON) != type(matchJSON):
        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
        failmsg = failmsg.format(type(outputJSON).__name__, type(matchJSON).__name__)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    if len(matchJSON) > len(outputJSON):
        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test's output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
        failmsg = failmsg.format(len(outputJSON), outputJSON, len(matchJSON), matchJSON)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res
    res = find_in_json(res, outputJSON, matchJSON, 0)

    return res
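
# Illustrative matchJSON assertion: if the verify command prints
#     [{"kind": "gact", "index": 1, "installed": 12}]
# a test case can check just the fields it cares about with
#     "matchJSON": [{"kind": "gact", "index": 1}]
# Keys present in the output but absent from matchJSON are ignored;
# the recursive comparison is implemented by find_in_json() below.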

def find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if res.get_result() == ResultState.fail:
        return res

    if type(matchJSONVal) == list:
        res = find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey)

    elif type(matchJSONVal) == dict:
        res = find_in_json_dict(res, outputJSONVal, matchJSONVal)
    else:
        res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)

    if res.get_result() != ResultState.fail:
        res.set_result(ResultState.success)
        return res

    return res

def find_in_json_list(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if type(matchJSONVal) != type(outputJSONVal):
        failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
        failmsg = failmsg.format(outputJSONVal, matchJSONVal)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    if len(matchJSONVal) > len(outputJSONVal):
        failmsg = "Your matchJSON value is an array, and it contains more elements than the command under test's output:\ncommand output (length: {}):\n{}\nmatchJSON value (length: {}):\n{}"
        failmsg = failmsg.format(len(outputJSONVal), outputJSONVal, len(matchJSONVal), matchJSONVal)
        res.set_result(ResultState.fail)
        res.set_failmsg(failmsg)
        return res

    for matchJSONIdx, matchJSONVal in enumerate(matchJSONVal):
        res = find_in_json(res, outputJSONVal[matchJSONIdx], matchJSONVal,
                           matchJSONKey)
    return res

def find_in_json_dict(res, outputJSONVal, matchJSONVal):
    for matchJSONKey, matchJSONVal in matchJSONVal.items():
        if type(outputJSONVal) == dict:
            if matchJSONKey not in outputJSONVal:
                failmsg = 'Key not found in json output: {}: {}\nMatching against output: {}'
                failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal)
                res.set_result(ResultState.fail)
                res.set_failmsg(failmsg)
                return res

        else:
            failmsg = 'Original output and matchJSON value are not the same type: output: {} != matchJSON: {}'
            failmsg = failmsg.format(type(outputJSONVal).__name__, type(matchJSONVal).__name__)
            res.set_result(ResultState.fail)
            res.set_failmsg(failmsg)
            return res

        if type(outputJSONVal) == dict and (type(outputJSONVal[matchJSONKey]) == dict or
                type(outputJSONVal[matchJSONKey]) == list):
            if len(matchJSONVal) > 0:
                res = find_in_json(res, outputJSONVal[matchJSONKey], matchJSONVal, matchJSONKey)
            # handling corner case where matchJSONVal == [] or matchJSONVal == {}
            else:
                res = find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey)
        else:
            res = find_in_json(res, outputJSONVal, matchJSONVal, matchJSONKey)
    return res

def find_in_json_other(res, outputJSONVal, matchJSONVal, matchJSONKey=None):
    if matchJSONKey in outputJSONVal:
        if matchJSONVal != outputJSONVal[matchJSONKey]:
            failmsg = 'Value doesn\'t match: {}: {} != {}\nMatching against output: {}'
            failmsg = failmsg.format(matchJSONKey, matchJSONVal, outputJSONVal[matchJSONKey], outputJSONVal)
            res.set_result(ResultState.fail)
            res.set_failmsg(failmsg)
            return res

    return res
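
# Worked example of the recursive matcher (illustrative data):
#     output = {"a": 1, "b": [{"c": 2}, {"c": 3}]}
#     match  = {"b": [{"c": 2}]}
# find_in_json() takes the dict branch for "b", then the list branch
# pairs match elements with output elements by index, so {"c": 2} is
# compared against output["b"][0] only and the result is success.
# Note that list matching is positional: a match element is never
# searched for elsewhere in the output list.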

def run_one_test(pm, args, index, tidx):
    global NAMES
    result = True
    tresult = ""
    tap = ""
    res = TestResult(tidx['id'], tidx['name'])
    if args.verbose > 0:
        print("\t====================\n=====> ", end="")
    print("Test " + tidx["id"] + ": " + tidx["name"])

    if 'skip' in tidx:
        if tidx['skip'] == 'yes':
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg('Test case designated as skipped.')
            pm.call_pre_case(tidx, test_skip=True)
            pm.call_post_execute()
            return res

    if 'dependsOn' in tidx:
        if args.verbose > 0:
            print('probe command for test skip')
        (p, procout) = exec_cmd(args, pm, 'execute', tidx['dependsOn'])
        if p:
            if p.returncode != 0:
                res = TestResult(tidx['id'], tidx['name'])
                res.set_result(ResultState.skip)
                res.set_errormsg('probe command: test skipped.')
                pm.call_pre_case(tidx, test_skip=True)
                pm.call_post_execute()
                return res

    # populate NAMES with TESTID for this test
    NAMES['TESTID'] = tidx['id']

    pm.call_pre_case(tidx)
    prepare_env(args, pm, 'setup', "-----> prepare stage", tidx["setup"])

    if args.verbose > 0:
        print('-----> execute stage')
    pm.call_pre_execute()
    (p, procout) = exec_cmd(args, pm, 'execute', tidx["cmdUnderTest"])
    if p:
        exit_code = p.returncode
    else:
        exit_code = None

    pm.call_post_execute()

    if exit_code is None or exit_code != int(tidx["expExitCode"]):
        print("exit: {!r}".format(exit_code))
        print("exit: {}".format(int(tidx["expExitCode"])))
        #print("exit: {!r} {}".format(exit_code, int(tidx["expExitCode"])))
        res.set_result(ResultState.fail)
        res.set_failmsg('Command exited with {}, expected {}\n{}'.format(exit_code, tidx["expExitCode"], procout))
        print(procout)
    else:
        if args.verbose > 0:
            print('-----> verify stage')
        (p, procout) = exec_cmd(args, pm, 'verify', tidx["verifyCmd"])
        if procout:
            if 'matchJSON' in tidx:
                res = verify_by_json(procout, res, tidx, args, pm)
            elif 'matchPattern' in tidx:
                match_pattern = re.compile(
                    str(tidx["matchPattern"]), re.DOTALL | re.MULTILINE)
                match_index = re.findall(match_pattern, procout)
                if len(match_index) != int(tidx["matchCount"]):
                    res.set_result(ResultState.fail)
                    res.set_failmsg('Could not match regex pattern. Verify command output:\n{}'.format(procout))
                else:
                    res.set_result(ResultState.success)
            else:
                res.set_result(ResultState.fail)
                res.set_failmsg('Must specify a match option: matchJSON or matchPattern\n{}'.format(procout))
        elif int(tidx["matchCount"]) != 0:
            res.set_result(ResultState.fail)
            res.set_failmsg('No output generated by verify command.')
        else:
            res.set_result(ResultState.success)

    prepare_env(args, pm, 'teardown', '-----> teardown stage', tidx['teardown'], procout)
    pm.call_post_case()

    # remove TESTID from NAMES
    del NAMES['TESTID']
    return res

def test_runner(pm, args, filtered_tests):
    """
    Driver function for the unit tests.

    Prints information about the tests being run, executes the setup and
    teardown commands and the command under test itself. Also determines
    success/failure based on the information in the test case and generates
    TAP output accordingly.
    """
    testlist = filtered_tests
    tcount = len(testlist)
    index = 1
    tap = ''
    badtest = None
    stage = None
    emergency_exit = False
    emergency_exit_message = ''

    tsr = TestSuiteReport()

    try:
        pm.call_pre_suite(tcount, [tidx['id'] for tidx in testlist])
    except Exception as ee:
        ex_type, ex, ex_tb = sys.exc_info()
        print('Exception {} {} (caught in pre_suite).'.
              format(ex_type, ex))
        traceback.print_tb(ex_tb)
        emergency_exit_message = 'EMERGENCY EXIT, call_pre_suite failed with exception {} {}\n'.format(ex_type, ex)
        emergency_exit = True
        stage = 'pre-SUITE'

    if emergency_exit:
        pm.call_post_suite(index)
        return emergency_exit_message
    if args.verbose > 1:
        print('give test rig 2 seconds to stabilize')
    time.sleep(2)
    for tidx in testlist:
        if "flower" in tidx["category"] and args.device is None:
            errmsg = "Tests using the DEV2 variable must define the name of a "
            errmsg += "physical NIC with the -d option when running tdc.\n"
            errmsg += "Test has been skipped."
            if args.verbose > 1:
                print(errmsg)
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg(errmsg)
            tsr.add_resultdata(res)
            index += 1
            continue
        try:
            badtest = tidx  # in case it goes bad
            res = run_one_test(pm, args, index, tidx)
            tsr.add_resultdata(res)
        except PluginMgrTestFail as pmtf:
            ex_type, ex, ex_tb = sys.exc_info()
            stage = pmtf.stage
            message = pmtf.message
            output = pmtf.output
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            res.set_errormsg(pmtf.message)
            res.set_failmsg(pmtf.output)
            tsr.add_resultdata(res)
            index += 1
            print(message)
            print('Exception {} {} (caught in test_runner, running test {} {} {} stage {})'.
                  format(ex_type, ex, index, tidx['id'], tidx['name'], stage))
            print('---------------')
            print('traceback')
            traceback.print_tb(ex_tb)
            print('---------------')
            if stage == 'teardown':
                print('accumulated output for this test:')
                if pmtf.output:
                    print(pmtf.output)
            print('---------------')
            break
        index += 1

    # if we failed in setup or teardown,
    # fill in the remaining tests with ok-skipped
    count = index

    if tcount + 1 != count:
        for tidx in testlist[count - 1:]:
            res = TestResult(tidx['id'], tidx['name'])
            res.set_result(ResultState.skip)
            msg = 'skipped - previous {} failed {} {}'.format(stage,
                index, badtest.get('id', '--Unknown--'))
            res.set_errormsg(msg)
            tsr.add_resultdata(res)
            count += 1

    if args.pause:
        print('Want to pause\nPress enter to continue ...')
        if input():
            print('got something on stdin')

    pm.call_post_suite(index)

    return tsr

def has_blank_ids(idlist):
    """
    Search the list for empty ID fields and return true/false accordingly.
    """
    return not all(idlist)


def load_from_file(filename):
    """
    Open the JSON file containing the test cases and return them
    as a list of ordered dictionary objects.
    """
    try:
        with open(filename) as test_data:
            testlist = json.load(test_data, object_pairs_hook=OrderedDict)
    except json.JSONDecodeError as jde:
        print('IGNORING test case file {}\n\tBECAUSE:  {}'.format(filename, jde))
        testlist = list()
    else:
        idlist = get_id_list(testlist)
        if has_blank_ids(idlist):
            for k in testlist:
                k['filename'] = filename
    return testlist
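
# Shape of a test case as loaded from a tc-tests JSON file (illustrative
# values; the field names are the ones consumed by run_one_test() and
# prepare_env() above):
# [
#     {
#         "id": "e9a3",
#         "name": "Add gact pass action",
#         "category": ["actions", "gact"],
#         "setup": ["$TC actions flush action gact"],
#         "cmdUnderTest": "$TC actions add action pass index 8",
#         "expExitCode": "0",
#         "verifyCmd": "$TC actions list action gact",
#         "matchPattern": "action order [0-9]*: gact action pass.*index 8",
#         "matchCount": "1",
#         "teardown": ["$TC actions flush action gact"]
#     }
# ]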


def args_parse():
    """
    Create the argument parser.
    """
    parser = argparse.ArgumentParser(description='Linux TC unit tests')
    return parser


def set_args(parser):
    """
    Set the command line arguments for tdc.
    """
    parser.add_argument(
        '--outfile', type=str,
        help='Path to the file in which results should be saved. ' +
        'Default target is the current directory.')
    parser.add_argument(
        '-p', '--path', type=str,
        help='The full path to the tc executable to use')
    sg = parser.add_argument_group(
        'selection', 'select which test cases: ' +
        'files plus directories; filtered by categories plus testids')
    ag = parser.add_argument_group(
        'action', 'select action to perform on selected test cases')

    sg.add_argument(
        '-D', '--directory', nargs='+', metavar='DIR',
        help='Collect tests from the specified directory(ies) ' +
        '(default [tc-tests])')
    sg.add_argument(
        '-f', '--file', nargs='+', metavar='FILE',
        help='Run tests from the specified file(s)')
    sg.add_argument(
        '-c', '--category', nargs='*', metavar='CATG', default=['+c'],
        help='Run tests only from the specified category/ies, ' +
        'or if no category/ies is/are specified, list known categories.')
    sg.add_argument(
        '-e', '--execute', nargs='+', metavar='ID',
        help='Execute the specified test cases with specified IDs')
    ag.add_argument(
        '-l', '--list', action='store_true',
        help='List all test cases, or those only within the specified category')
    ag.add_argument(
        '-s', '--show', action='store_true', dest='showID',
        help='Display the selected test cases')
    ag.add_argument(
        '-i', '--id', action='store_true', dest='gen_id',
        help='Generate ID numbers for new test cases')
    parser.add_argument(
        '-v', '--verbose', action='count', default=0,
        help='Show the commands that are being run')
    parser.add_argument(
        '--format', default='tap', const='tap', nargs='?',
        choices=['none', 'xunit', 'tap'],
        help='Specify the format for test results. (Default: TAP)')
    parser.add_argument('-d', '--device',
                        help='Execute test cases that use a physical device, ' +
                        'where DEVICE is its name. (If not defined, tests ' +
                        'that require a physical device will be skipped)')
    parser.add_argument(
        '-P', '--pause', action='store_true',
        help='Pause execution just before post-suite stage')
    return parser
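
# Typical invocations (illustrative; the test ID and file name below are
# examples, not guaranteed to exist in the tree):
#     ./tdc.py -c                                      # list known categories
#     ./tdc.py -c actions gact -v                      # run two categories, verbosely
#     ./tdc.py -f tc-tests/actions/gact.json -e e9a3   # one test from one file
#     ./tdc.py --format xunit --outfile results.xml    # save results as xunit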


def check_default_settings(args, remaining, pm):
    """
    Process any arguments overriding the default settings,
    and ensure the settings are correct.
    """
    # Allow for overriding specific settings
    global NAMES

    if args.path is not None:
        NAMES['TC'] = args.path
    if args.device is not None:
        NAMES['DEV2'] = args.device
    if 'TIMEOUT' not in NAMES:
        NAMES['TIMEOUT'] = None
    if not os.path.isfile(NAMES['TC']):
        print("The specified tc path " + NAMES['TC'] + " does not exist.")
        exit(1)

    pm.call_check_args(args, remaining)


def get_id_list(alltests):
    """
    Generate a list of all IDs in the test cases.
    """
    return [x["id"] for x in alltests]


def check_case_id(alltests):
    """
    Check for duplicate test case IDs.
    """
    idl = get_id_list(alltests)
    return [x for x in idl if idl.count(x) > 1]


def does_id_exist(alltests, newid):
    """
    Check if a given ID already exists in the list of test cases.
    """
    idl = get_id_list(alltests)
    return any(newid == x for x in idl)


def generate_case_ids(alltests):
    """
    If a test case has a blank ID field, generate a random hex ID for it
    and then write the test cases back to disk.
    """
    import random
    for c in alltests:
        if c["id"] == "":
            while True:
                newid = '{:04x}'.format(random.randrange(16**4))
                if does_id_exist(alltests, newid):
                    continue
                else:
                    c['id'] = newid
                    break

    ufilename = []
    for c in alltests:
        if 'filename' in c:
            ufilename.append(c['filename'])
    ufilename = get_unique_item(ufilename)
    for f in ufilename:
        testlist = []
        for t in alltests:
            if 'filename' in t:
                if t['filename'] == f:
                    del t['filename']
                    testlist.append(t)
        with open(f, "w") as outfile:
            json.dump(testlist, outfile, indent=4)
            outfile.write("\n")

def filter_tests_by_id(args, testlist):
    '''
    Remove tests from testlist that are not in the named id list.
    If the id list is empty, return an empty list.
    '''
    newlist = list()
    if testlist and args.execute:
        target_ids = args.execute

        if isinstance(target_ids, list) and (len(target_ids) > 0):
            newlist = list(filter(lambda x: x['id'] in target_ids, testlist))
    return newlist

def filter_tests_by_category(args, testlist):
    '''
    Remove tests from testlist that are not in a named category.
    '''
    answer = list()
    if args.category and testlist:
        test_ids = list()
        for catg in set(args.category):
            if catg == '+c':
                continue
            print('considering category {}'.format(catg))
            for tc in testlist:
                if catg in tc['category'] and tc['id'] not in test_ids:
                    answer.append(tc)
                    test_ids.append(tc['id'])

    return answer


def get_test_cases(args):
    """
    If a test case file is specified, retrieve tests from that file.
    Otherwise, glob for all json files in subdirectories and load from
    each one.
    Also, if requested, filter by category, and add tests matching
    certain ids.
    """
    import fnmatch

    flist = []
    testdirs = ['tc-tests']

    if args.file:
        # at least one file was specified - remove the default directory
        testdirs = []

        for ff in args.file:
            if not os.path.isfile(ff):
                print("IGNORING file " + ff + "\n\tBECAUSE it does not exist.")
            else:
                flist.append(os.path.abspath(ff))

    if args.directory:
        testdirs = args.directory

    for testdir in testdirs:
        for root, dirnames, filenames in os.walk(testdir):
            for filename in fnmatch.filter(filenames, '*.json'):
                candidate = os.path.abspath(os.path.join(root, filename))
                if candidate not in testdirs:
                    flist.append(candidate)

    alltestcases = list()
    for casefile in flist:
        alltestcases = alltestcases + load_from_file(casefile)

    allcatlist = get_test_categories(alltestcases)
    allidlist = get_id_list(alltestcases)

    testcases_by_cats = get_categorized_testlist(alltestcases, allcatlist)
    idtestcases = filter_tests_by_id(args, alltestcases)
    cattestcases = filter_tests_by_category(args, alltestcases)

    cat_ids = [x['id'] for x in cattestcases]
    if args.execute:
        if args.category:
            alltestcases = cattestcases + [x for x in idtestcases if x['id'] not in cat_ids]
        else:
            alltestcases = idtestcases
    else:
        if cat_ids:
            alltestcases = cattestcases
        else:
            # just accept the existing value of alltestcases,
            # which has been filtered by file/directory
            pass

    return allcatlist, allidlist, testcases_by_cats, alltestcases


def set_operation_mode(pm, parser, args, remaining):
    """
    Load the test case data and process remaining arguments to determine
    what the script should do for this run, and call the appropriate
    function.
    """
    ucat, idlist, testcases, alltests = get_test_cases(args)

    if args.gen_id:
        if has_blank_ids(idlist):
            generate_case_ids(alltests)
        else:
            print("No empty ID fields found in test files.")
        exit(0)

    duplicate_ids = check_case_id(alltests)
    if len(duplicate_ids) > 0:
        print("The following test case IDs are not unique:")
        print(str(set(duplicate_ids)))
        print("Please correct them before continuing.")
        exit(1)

    if args.showID:
        for atest in alltests:
            print_test_case(atest)
        exit(0)

    if isinstance(args.category, list) and (len(args.category) == 0):
        print("Available categories:")
        print_sll(ucat)
        exit(0)

    if args.list:
        list_test_cases(alltests)
        exit(0)

    exit_code = 0 # KSFT_PASS
    if len(alltests):
        req_plugins = pm.get_required_plugins(alltests)
        try:
            args = pm.load_required_plugins(req_plugins, parser, args, remaining)
        except PluginDependencyException as pde:
            print('The following plugins were not found:')
            print('{}'.format(pde.missing_pg))
        catresults = test_runner(pm, args, alltests)
        if catresults.count_failures() != 0:
            exit_code = 1 # KSFT_FAIL
        if args.format == 'none':
            print('Test results output suppression requested\n')
        else:
            print('\nAll test results: \n')
            if args.format == 'xunit':
                suffix = 'xml'
                res = catresults.format_xunit()
            elif args.format == 'tap':
                suffix = 'tap'
                res = catresults.format_tap()
            print(res)
            print('\n\n')
            if not args.outfile:
                fname = 'test-results.{}'.format(suffix)
            else:
                fname = args.outfile
            with open(fname, 'w') as fh:
                fh.write(res)
            if os.getenv('SUDO_UID') is not None:
                os.chown(fname, uid=int(os.getenv('SUDO_UID')),
                         gid=int(os.getenv('SUDO_GID')))
    else:
        print('No tests found\n')
        exit_code = 4 # KSFT_SKIP
    exit(exit_code)

def main():
    """
    Start of execution; set up argument parser and get the arguments,
    and start operations.
    """
    parser = args_parse()
    parser = set_args(parser)
    pm = PluginMgr(parser)
    parser = pm.call_add_args(parser)
    (args, remaining) = parser.parse_known_args()
    args.NAMES = NAMES
    pm.set_args(args)
    check_default_settings(args, remaining, pm)
    if args.verbose > 2:
        print('args is {}'.format(args))

    set_operation_mode(pm, parser, args, remaining)

if __name__ == "__main__":
    main()