xref: /openbmc/openbmc-build-scripts/tools/config-clang-tidy (revision 037f933db93a5d8400266296704bd2eac6d69d4d)
1#!/usr/bin/env python3
2import argparse
3import os
4import re
5from typing import Any, Dict, List, Tuple
6
7import yaml
8
9
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand.

    Each subparser stores its handler in ``func``; the parsed namespace is
    passed straight to that handler.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): --commit is parsed but no handler in this file reads
    # args.commit — confirm whether committing is still intended.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the subparsers a dest so argparse can name the missing subcommand
    # in its usage error; without one, some Python versions raise TypeError
    # when no subcommand is supplied.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    # "format" is a merge without a --reference: subcmd_merge falls back to
    # an empty reference config when args has no "reference" attribute.
    parser_format = subparsers.add_parser(
        "format", help="Format a clang-tidy config"
    )
    parser_format.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in a clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in a clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)
54
55
def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge a reference config into the repo's .clang-tidy and rewrite it.

    Without ``--reference`` (the "format" subcommand) the reference is empty,
    so this only normalizes and re-serializes the repository config.

    Keys that have a ``Key_<name>`` helper class with a ``merge`` method are
    merged by that helper. For any other key, the repository value wins when
    it is set (not None) and the reference value is used otherwise. Output
    order is ``Checks``, ``CheckOptions`` (when present), then the remaining
    keys alphabetically.
    """
    repo_path, repo_config = load_config(args.repo)
    _, ref_config = (
        load_config(args.reference) if "reference" in args else ("", {})
    )

    special_keys = ["Checks", "CheckOptions"]
    all_keys = set(repo_config) | set(ref_config)

    # Special keys first (in their defined order), then everything else
    # sorted alphabetically.
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        key_class = globals().get(f"Key_{key}")
        if key_class and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # Compare against None rather than truthiness so explicit falsy
            # repo values (false, 0, '') are kept instead of being replaced
            # by the reference value or turned into null.
            result[key] = repo_value
        else:
            result[key] = ref_value

    with open(repo_path, "w") as f:
        f.write(format_yaml_output(result))
87
88
def subcmd_enable(args: argparse.Namespace) -> None:
    """Enable ``args.check`` in the repository config's ``Checks`` entry.

    If the config has no ``Checks`` key it is left as loaded. In either case
    the config is re-serialized and written back to its path.
    """
    config_path, config = load_config(args.repo)

    try:
        current_checks = config["Checks"]
    except KeyError:
        pass  # Nothing to update; still rewrite the (reformatted) file.
    else:
        config["Checks"] = Key_Checks.enable(current_checks, args.check)

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
101
102
def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable ``args.check`` in the repository config (remove with --drop).

    ``Checks`` and then ``CheckOptions`` are each passed to their helper's
    ``disable`` when present; the config is then re-serialized and written
    back to its path.
    """
    config_path, config = load_config(args.repo)

    # Order matters only for readability; each key is updated independently.
    for key, helper in (("Checks", Key_Checks), ("CheckOptions", Key_CheckOptions)):
        if key in config:
            config[key] = helper.disable(config[key], args.check, args.drop)

    with open(config_path, "w") as out:
        out.write(format_yaml_output(config))
120
121
class Key_Checks:
    """Helpers for the clang-tidy ``Checks`` key.

    A Checks string is parsed into a ``{check_name: enabled}`` mapping and is
    always written back as a leading ``-*`` followed by one entry per line,
    sorted by name (``name`` when enabled, ``-name`` when disabled).
    """

    @staticmethod
    def merge(repo: str, ref: str) -> str:
        """Merge reference Checks into the repository's Checks.

        Entries from ``repo`` keep their enabled/disabled state; entries that
        only appear in ``ref`` are added as disabled.
        """
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` enabled (added if missing)."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed when ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: str) -> Dict[str, bool]:
        """Parse a Checks string into ``{check_name: enabled}``.

        ``s`` may be empty or None. Entries may be separated by commas,
        whitespace (including newlines), or both, so both ``"-*,foo"`` and
        ``"-*,\\n  foo"`` are accepted.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        for item in re.split(r"[\s,]+", s):
            if not item:
                continue
            # Ignore global wildcard because we handle that specifically.
            if item.startswith("-*"):
                continue
            # Drop category wildcard disables since we already use a global wildcard.
            if item.startswith("-") and "*" in item:
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialize a check mapping: ``-*`` first, then sorted entries.

        Entries are separated by ``",\\n"`` and the result ends with a
        newline.
        """
        return (
            ",\n".join(
                ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
            )
            + "\n"
        )
180
181
class Key_CheckOptions:
    """Helpers for the clang-tidy ``CheckOptions`` key.

    The value is a list of ``{"key": ..., "value": ...}`` mappings. Results
    are re-emitted sorted by key; duplicate keys in the input collapse to a
    single entry.
    """

    @staticmethod
    def merge(
        repo: List[Dict[str, str]], ref: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Add options from ``ref`` whose keys are not already set in ``repo``.

        Existing repository values always win; either argument may be None.
        """
        merged = Key_CheckOptions._unroll(repo)
        for item in ref or []:
            merged.setdefault(item["key"], item["value"])
        return Key_CheckOptions._roll(merged)

    @staticmethod
    def disable(
        repo: List[Dict[str, str]], option: str, drop: bool
    ) -> List[Dict[str, str]]:
        """Remove the options belonging to ``option`` when ``drop`` is set.

        ``option`` may be a full option key (``check.Option``) or a check
        name. For a check name, every ``<check>.<Option>`` key is removed,
        since option keys are prefixed with the check they configure. Without
        ``drop`` the original list is returned unchanged.
        """
        if not drop:
            return repo

        prefix = f"{option}."
        kept = {
            k: v
            for k, v in Key_CheckOptions._unroll(repo).items()
            if k != option and not k.startswith(prefix)
        }
        return Key_CheckOptions._roll(kept)

    @staticmethod
    def _unroll(repo: List[Dict[str, str]]) -> Dict[str, str]:
        """Convert the key/value list into a dict (later duplicates win)."""
        unrolled_repo: Dict[str, str] = {}
        for item in repo or []:
            unrolled_repo[item["key"]] = item["value"]
        return unrolled_repo

    @staticmethod
    def _roll(data: Dict[str, str]) -> List[Dict[str, str]]:
        """Convert a dict back into a key/value list sorted by key."""
        return [{"key": k, "value": v} for k, v in sorted(data.items())]
219
220
def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Locate and parse a clang-tidy config.

    Args:
        path: A directory (the config is then ``<path>/.clang-tidy``) or a
            path to a clang-tidy config file. A path that does not exist is
            treated as a file path only if it contains "clang-tidy".

    Returns:
        ``(resolved_path, config)``. ``config`` is ``{}`` when the file does
        not exist or is empty.
    """
    if os.path.isdir(path) or (
        not os.path.isfile(path) and "clang-tidy" not in path
    ):
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r", encoding="utf-8") as f:
        # Whole-line '#' comments are dropped before parsing. In YAML a '#'
        # line inside a multi-line quoted value is content, not a comment,
        # so (presumably) this lets commented-out entries be ignored.
        # Lines keep their own '\n', so join with "" to avoid doubling them.
        data = "".join(line for line in f if not line.startswith("#"))

    # safe_load returns None for an empty document; normalize to {}.
    return (path, yaml.safe_load(data) or {})
231
232
def format_yaml_output(data: Dict[str, Any]) -> str:
    """Serialize ``data`` to YAML with blank lines between top-level keys.

    The mapping is dumped in insertion order with a 4-space indent. Empty
    lines from the dump are discarded. A single blank line is inserted before
    each line that starts with an alphanumeric key followed by ':', except
    when it is the first output line. The text ends with a newline.
    """
    top_level_key = re.compile("[a-zA-Z0-9]+:")
    dumped = yaml.dump(data, sort_keys=False, indent=4)

    out: List[str] = []
    for text in dumped.split("\n"):
        if text == "":
            continue
        if out and top_level_key.match(text):
            out.append("")
        out.append(text)

    # Trailing empty element yields the final newline on join.
    return "\n".join([*out, ""])
251
252
if __name__ == "__main__":
    # Run the CLI when executed as a script. main() returns None, so the
    # process exits with status 0 after a successful run.
    raise SystemExit(main())
255