xref: /openbmc/openbmc-build-scripts/tools/config-clang-tidy (revision b8ce3818c2c73196857daf9c0e6e3ee18938f6cc)
1#!/usr/bin/env python3
import argparse
import os
import re
import sys
from typing import Any, Dict, List, Tuple

import yaml
8
9
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand.

    Each subcommand registers its handler via ``set_defaults(func=...)``;
    the handler is then called with the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): --commit is parsed, but no subcommand in this file reads
    # args.commit yet.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the subparsers a ``dest`` so argparse can name the missing
    # subcommand in its usage error; without one, some Python versions raise
    # TypeError instead when no subcommand is given (bpo-29298).
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a check in the repository clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a check in the repository clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)
49
50
def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge the reference config into the repository config, in place.

    For each top-level key present in either config:

    * if a ``Key_<name>`` class with a ``merge`` method exists (``Checks``,
      ``CheckOptions``), that method combines the two values;
    * otherwise the repository value is kept whenever it is set (including
      explicit falsy values such as ``false`` or ``''``), and the reference
      value is used only when the repository does not set the key.

    The result is written back to the repository config path, with
    ``Checks`` and ``CheckOptions`` first and all other keys sorted.
    """
    repo_path, repo_config = load_config(args.repo)
    _, ref_config = load_config(args.reference)

    special_keys = ["Checks", "CheckOptions"]
    all_keys = set(repo_config) | set(ref_config)

    # Special keys first (in their defined order), then the rest sorted.
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        key_class = globals().get(f"Key_{key}")
        if key_class and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # A falsy-but-explicit repository value is a deliberate choice;
            # it must not be overridden by the reference.
            result[key] = repo_value
        else:
            result[key] = ref_value

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(result))
80
81
def subcmd_enable(args: argparse.Namespace) -> None:
    """Enable ``args.check`` in the repository's clang-tidy config.

    When the config has no ``Checks`` key (including when no config file
    exists), a warning is printed to stderr and the file is left untouched.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" not in repo_config:
        # Rewriting anyway would drop the file's "#" comment lines (they are
        # stripped on load) or create a "{}" config where none existed.
        print(
            f"warning: {repo_path} has no 'Checks' key; "
            f"not enabling {args.check}",
            file=sys.stderr,
        )
        return

    repo_config["Checks"] = Key_Checks.enable(
        repo_config["Checks"], args.check
    )

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(repo_config))
94
95
def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable (or, with ``--drop``, remove) ``args.check`` in the repo config.

    When the config has no ``Checks`` key (including when no config file
    exists), a warning is printed to stderr and the file is left untouched.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" not in repo_config:
        # Rewriting anyway would drop the file's "#" comment lines (they are
        # stripped on load) or create a "{}" config where none existed.
        print(
            f"warning: {repo_path} has no 'Checks' key; "
            f"not disabling {args.check}",
            file=sys.stderr,
        )
        return

    repo_config["Checks"] = Key_Checks.disable(
        repo_config["Checks"], args.check, args.drop
    )

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(repo_config))
108
109
class Key_Checks:
    """Helpers for the clang-tidy ``Checks`` key.

    Internally the value is a ``{check_name: enabled}`` mapping.  It is
    always written back as ``-*`` (disable everything) followed by one entry
    per line, sorted by name, with disabled checks prefixed by ``-``.
    """

    # Entries may be separated by commas, whitespace/newlines, or both.
    _SEPARATOR = re.compile(r"[\s,]+")

    @staticmethod
    def merge(repo: str, ref: str) -> str:
        """Merge the reference ``Checks`` value into the repository's.

        Every check the repository lists keeps its state; checks that only
        appear in the reference are added in the disabled state.
        """
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` enabled (added if absent)."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed if ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: Any) -> Dict[str, bool]:
        """Parse a ``Checks`` value into ``{check_name: enabled}``.

        ``s`` may be a comma- and/or whitespace-separated string (e.g.
        ``"-*,modernize-*,-foo"``), a list of entries, or empty/``None``.
        Wildcard disables (``-*`` and ``-category-*``) are dropped because
        :meth:`_join` always emits a leading ``-*``.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        if isinstance(s, list):
            # NOTE(review): newer clang-tidy versions also accept Checks as a
            # YAML list; normalise it to the string form.
            s = ",".join(str(entry) for entry in s)
        for item in Key_Checks._SEPARATOR.split(s):
            # Empty fragments come from leading/trailing separators.
            if not item:
                continue
            # Global ("-*") and category ("-foo-*") wildcard disables are
            # redundant with the "-*" that _join always emits.
            if item.startswith("-") and "*" in item:
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Render ``{check_name: enabled}`` as a newline-separated Checks value."""
        entries = ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
        return ",\n".join(entries) + "\n"
168
169
class Key_CheckOptions:
    """Helpers for the clang-tidy ``CheckOptions`` key."""

    @staticmethod
    def _pairs(options: Any) -> List[Tuple[Any, Any]]:
        """Return ``options`` as an ordered list of ``(key, value)`` pairs.

        Accepts the list-of-``{"key": ..., "value": ...}`` form (the form
        this tool writes), a plain ``{key: value}`` mapping, or an empty
        value / ``None``.
        """
        if not options:
            return []
        if isinstance(options, dict):
            return list(options.items())
        return [(item["key"], item["value"]) for item in options]

    @staticmethod
    def merge(repo: Any, ref: Any) -> List[Dict[str, Any]]:
        """Merge reference check options into the repository's.

        Repository options always win; reference options are added only for
        keys the repository does not set.  (Within the repository the last
        duplicate wins; within the reference the first one does.)  The
        result is a list of ``{"key": ..., "value": ...}`` dicts sorted by key.
        """
        merged: Dict[Any, Any] = dict(Key_CheckOptions._pairs(repo))
        for key, value in Key_CheckOptions._pairs(ref):
            merged.setdefault(key, value)

        return [{"key": k, "value": v} for k, v in sorted(merged.items())]
186
187
def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Locate and parse a clang-tidy config.

    Args:
        path: A clang-tidy config file, or a directory expected to hold a
            ``.clang-tidy`` file.  ``.clang-tidy`` is appended when ``path``
            is an existing directory or does not contain ``clang-tidy``.

    Returns:
        ``(config_path, config)``.  ``config`` is ``{}`` when the file does
        not exist or contains no YAML document.

    Raises:
        ValueError: if the file's top-level YAML value is not a mapping.
    """
    # Check isdir first so a directory such as "clang-tidy-configs/" is not
    # mistaken for the config file itself.
    if os.path.isdir(path) or "clang-tidy" not in path:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r") as f:
        # Drop whole-line "#" comments.  Each line keeps its own newline, so
        # joining with "\n" doubles every line break.
        # NOTE(review): this looks deliberate, so that multi-line quoted
        # scalars written by format_yaml_output (which strips blank lines)
        # read back with literal newlines; confirm before changing it.
        data = "\n".join(line for line in f if not line.startswith("#"))

    config = yaml.safe_load(data)
    if config is None:
        # Empty or comment-only file.
        config = {}
    if not isinstance(config, dict):
        raise ValueError(f"{path}: expected a YAML mapping at the top level")
    return (path, config)
198
199
def format_yaml_output(data: Dict[str, Any]) -> str:
    """Render ``data`` as YAML with a tidier layout.

    ``yaml.dump`` is called with the keys in insertion order and a 4-space
    indent.  All empty lines it produces are removed.  Then one blank line
    is inserted before every line that starts with an alphanumeric key
    followed by ``:`` (i.e. a top-level key), except the very first line.
    The result always ends with a newline.

    NOTE: blank lines inside a multi-line quoted YAML scalar are
    significant; dropping them folds those line breaks into spaces.
    """
    rendered = yaml.dump(data, sort_keys=False, indent=4)
    top_level_key = re.compile("[a-zA-Z0-9]+:")

    output: List[str] = []
    for text in filter(None, rendered.split("\n")):
        needs_separator = bool(output) and top_level_key.match(text)
        if needs_separator:
            output.append("")
        output.append(text)

    # A trailing empty element makes the join end with "\n".
    return "\n".join(output + [""])
218
219
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()
222