#!/usr/bin/env python3
"""Maintain a repository's ``.clang-tidy`` configuration file.

Subcommands (each rewrites the repository config in place):

* ``merge``   -- fold a reference config into the repository config.
* ``format``  -- rewrite the repository config in this tool's canonical layout.
* ``enable``  -- enable one check.
* ``disable`` -- disable one check, or remove it (and its options) with
  ``--drop``.
"""
import argparse
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

import yaml

# ``CheckOptions`` as loaded from YAML: either the list form
# ``[{"key": ..., "value": ...}, ...]`` or the mapping form ``{key: value}``.
CheckOptionsValue = Union[List[Dict[str, Any]], Dict[str, Any], None]


def main() -> None:
    """Parse the command line and run the selected subcommand."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): args.commit is never read anywhere in this script; the
    # flag is kept only so existing invocations continue to parse.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the subparser group a ``dest`` so argparse can name the missing
    # subcommand when none is given; a required group without one makes some
    # Python versions raise TypeError instead of printing a usage error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    parser_format = subparsers.add_parser(
        "format", help="Format a clang-tidy config"
    )
    # Formatting is a merge with no reference config (see subcmd_merge).
    parser_format.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in the repository clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in the repository clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)


def _write_config(path: str, config: Dict[str, Any]) -> None:
    """Write ``config`` to ``path`` in this tool's canonical YAML layout."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(format_yaml_output(config))


def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge ``args.reference`` (if present) into the repo config and rewrite it.

    Output keys are ``Checks`` and ``CheckOptions`` first (when present),
    followed by all other keys in alphabetical order. Keys listed in
    ``KEY_MERGERS`` are combined by their merger. For any other key, the
    repository value wins whenever it is set, including falsy values such as
    ``false``, ``0`` or ``''``. Otherwise the reference value is used.

    When ``args`` has no ``reference`` attribute (the ``format`` subcommand),
    the merge runs against an empty reference and simply reformats the config.

    Raises:
        SystemExit: if a reference is given but its config file does not exist.
    """
    repo_path, repo_config = load_config(args.repo)

    reference = getattr(args, "reference", None)
    ref_config: Dict[str, Any] = {}
    if reference is not None:
        ref_path, ref_config = load_config(reference)
        # A missing repo config is fine (a fresh one gets written), but a
        # missing reference is almost certainly a wrong path: fail loudly
        # instead of silently merging nothing.
        if not os.path.exists(ref_path):
            raise SystemExit(f"error: reference config not found: {ref_path}")

    all_keys = set(repo_config) | set(ref_config)
    special_keys = ["Checks", "CheckOptions"]
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        merger = KEY_MERGERS.get(key)
        if merger is not None:
            result[key] = merger.merge(repo_value, ref_value)
        elif repo_value is not None:
            # ``is not None`` (not truthiness) so that e.g. ``UseColor: false``
            # in the repo is kept rather than replaced by the reference/None.
            result[key] = repo_value
        else:
            result[key] = ref_value

    _write_config(repo_path, result)


def subcmd_enable(args: argparse.Namespace) -> None:
    """Enable ``args.check`` in the repo config's ``Checks`` and rewrite it.

    If the config has no ``Checks`` key, nothing is enabled and a warning is
    printed to stderr. The config is still rewritten in canonical layout.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" in repo_config:
        repo_config["Checks"] = Key_Checks.enable(
            repo_config["Checks"], args.check
        )
    else:
        print(
            f"warning: {repo_path} has no Checks key; {args.check} not enabled",
            file=sys.stderr,
        )

    _write_config(repo_path, repo_config)


def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable ``args.check`` in the repo config and rewrite it.

    Without ``--drop`` the check is recorded as disabled (``-check``) in
    ``Checks``. With ``--drop`` it is removed from ``Checks``, and its entries
    are removed from ``CheckOptions``.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" in repo_config:
        repo_config["Checks"] = Key_Checks.disable(
            repo_config["Checks"], args.check, args.drop
        )

    if "CheckOptions" in repo_config:
        repo_config["CheckOptions"] = Key_CheckOptions.disable(
            repo_config["CheckOptions"], args.check, args.drop
        )

    _write_config(repo_path, repo_config)


class Key_Checks:
    """Merge/edit helpers for the ``Checks`` key.

    Internally, checks are held as a ``{check_name: enabled}`` mapping. On
    output they are written as ``-*`` followed by every check, sorted by name.
    Disabled checks carry a ``-`` prefix. Entries are separated by ``,`` plus a
    newline.
    """

    # Entries in a Checks string may be separated by commas, whitespace, or
    # both (e.g. ``'-*,bugprone-*'`` or ``'-*, bugprone-*'`` or one per line).
    _SEPARATORS = re.compile(r"[,\s]+")

    @staticmethod
    def merge(repo: Optional[str], ref: Optional[str]) -> str:
        """Return ``repo``'s checks plus every check only ``ref`` names, disabled.

        The repository's enabled/disabled state always wins. Checks that appear
        only in the reference are added in disabled form.
        """
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed if ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: Optional[str]) -> Dict[str, bool]:
        """Parse a Checks string into ``{check_name: enabled}``.

        Every ``-<glob with *>`` entry is dropped, including the global ``-*``.
        ``_join`` always re-emits ``-*`` first, which makes category-wide
        disables redundant.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        for item in Key_Checks._SEPARATORS.split(s):
            # Empty pieces come from leading/trailing separators; a lone "-"
            # names no check.
            if item in ("", "-"):
                continue
            # Ignore global wildcard because we handle that specifically.
            if item.startswith("-*"):
                continue
            # Drop category wildcard disables since we already use a global
            # wildcard.
            if item.startswith("-") and "*" in item:
                continue
            if item.startswith("-"):
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialize ``{check_name: enabled}`` back into a Checks string."""
        entries = ["-*"] + [
            name if enabled else f"-{name}"
            for name, enabled in sorted(data.items())
        ]
        return ",\n".join(entries) + "\n"


class Key_CheckOptions:
    """Merge/edit helpers for the ``CheckOptions`` key.

    Both the list form (``[{key: ..., value: ...}, ...]``) and the mapping form
    (``{key: value}``) are accepted as input. The list form, sorted by key, is
    always produced as output.
    """

    @staticmethod
    def merge(repo: CheckOptionsValue, ref: CheckOptionsValue) -> List[Dict[str, Any]]:
        """Return ``repo``'s options plus any option only ``ref`` sets.

        When both set the same key, the repository value wins.
        """
        merged = Key_CheckOptions._unroll(repo)
        for key, value in Key_CheckOptions._unroll(ref).items():
            merged.setdefault(key, value)
        return Key_CheckOptions._roll(merged)

    @staticmethod
    def disable(repo: CheckOptionsValue, option: str, drop: bool) -> Any:
        """Remove the options belonging to check ``option`` when ``drop`` is set.

        Option keys are named ``<check>.<OptionName>``, so every key that starts
        with ``option + "."`` is removed. A key exactly equal to ``option`` is
        removed as well. Without ``drop``, ``repo`` is returned unchanged (the
        same object).
        """
        if not drop:
            return repo

        prefix = f"{option}."
        kept = {
            key: value
            for key, value in Key_CheckOptions._unroll(repo).items()
            if key != option and not key.startswith(prefix)
        }
        return Key_CheckOptions._roll(kept)

    @staticmethod
    def _unroll(repo: CheckOptionsValue) -> Dict[str, Any]:
        """Convert either CheckOptions form into a plain ``{key: value}`` dict."""
        if not repo:
            return {}
        if isinstance(repo, dict):
            return dict(repo)
        return {item["key"]: item["value"] for item in repo}

    @staticmethod
    def _roll(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert ``{key: value}`` into the list form, sorted by key."""
        return [{"key": k, "value": v} for k, v in sorted(data.items())]


# Per-key merge strategies used by subcmd_merge; keys not listed here fall
# back to "repository value wins, else reference value".
KEY_MERGERS: Dict[str, Any] = {
    "Checks": Key_Checks,
    "CheckOptions": Key_CheckOptions,
}


def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Load a clang-tidy config.

    ``path`` is treated as a directory containing ``.clang-tidy`` in two cases:
    when it is an existing directory, or when it does not mention
    ``clang-tidy`` at all. Otherwise it is taken to be the config file itself.
    Full-line comments (lines starting with ``#``) are discarded before
    parsing.

    Returns:
        ``(resolved_path, config)``. ``config`` is ``{}`` when the file does
        not exist or holds no YAML document (e.g. it is empty or only comments).
    """
    if os.path.isdir(path) or "clang-tidy" not in path:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r", encoding="utf-8") as f:
        # Lines keep their own terminators, so join with "". Joining with
        # "\n" would double every line break, and inside wrapped or block
        # scalars that turns a folded space into a literal newline.
        data = "".join(line for line in f if not line.startswith("#"))
    # safe_load returns None for an empty document.
    return (path, yaml.safe_load(data) or {})


# A top-level mapping key at column 0, e.g. "Checks:".
_TOP_LEVEL_KEY = re.compile(r"[a-zA-Z0-9]+:")


def format_yaml_output(data: Dict[str, Any]) -> str:
    """Convert to a prettier YAML string.

    - Keys keep their insertion order (``sort_keys=False``), with 4-space
      indentation.
    - All empty lines are filtered out. This is deliberate for multi-line
      strings such as ``Checks``: PyYAML writes them as single-quoted scalars
      whose embedded newlines appear as blank lines. Removing those lines keeps
      one check per line in the file, and the value reads back with spaces
      instead of newlines, which ``Key_Checks._split`` accepts.
    - A blank line is inserted before every top-level key after the first.
    - The result ends with a single trailing newline.
    """
    yaml_string = yaml.dump(data, sort_keys=False, indent=4)
    lines: List[str] = []
    for line in yaml_string.split("\n"):
        # Strip excess new lines.
        if not line:
            continue
        # Add new line between keys.
        if lines and _TOP_LEVEL_KEY.match(line):
            lines.append("")
        lines.append(line)
    lines.append("")

    return "\n".join(lines)


if __name__ == "__main__":
    main()