#!/usr/bin/env python3
"""Maintain a repository's ``.clang-tidy`` configuration.

Subcommands:

* ``merge``   -- merge a reference config into the repository config.
* ``format``  -- rewrite the repository config in canonical form.
* ``enable``  -- enable one check in the repository config.
* ``disable`` -- disable (or drop) one check in the repository config.
"""
import argparse
import os
import re
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

import yaml


def main() -> None:
    """Parse command-line arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): --commit is parsed but never read anywhere in this file.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # A named dest lets argparse report a missing subcommand by name.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    # "format" reuses the merge logic: without --reference the reference
    # config is empty, so the repository config is just rewritten in
    # canonical form.
    parser_format = subparsers.add_parser(
        "format", help="Format a clang-tidy config"
    )
    parser_format.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in the repository clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in the repository clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)


def _write_config(path: str, config: Dict[str, Any]) -> None:
    """Write ``config`` to ``path`` using the canonical YAML formatting."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(format_yaml_output(config))


def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge ``args.reference`` (when present) into the repository config.

    Keys with a ``Key_<name>`` handler class in this module are combined by
    that class's ``merge``. For every other key the repository value wins
    unless it is missing (absent or null), in which case the reference value
    is used. ``Checks`` and ``CheckOptions`` come first in the output; all
    other keys follow in alphabetical order. The result is written back to
    the repository config file.
    """
    repo_path, repo_config = load_config(args.repo)
    # The "format" subcommand defines no --reference argument.
    reference = getattr(args, "reference", None)
    _, ref_config = (
        load_config(reference) if reference is not None else ("", {})
    )

    all_keys = set(repo_config) | set(ref_config)
    special_keys = ["Checks", "CheckOptions"]
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        key_class = globals().get(f"Key_{key}")
        if key_class is not None and hasattr(key_class, "merge"):
            result[key] = key_class.merge(repo_value, ref_value)
        elif repo_value is not None:
            # Explicit None check: deliberate falsy repository values
            # (false, 0, '') must not be replaced by the reference value.
            result[key] = repo_value
        else:
            result[key] = ref_value

    _write_config(repo_path, result)


def subcmd_enable(args: argparse.Namespace) -> None:
    """Enable ``args.check`` in the repository config and rewrite the file.

    If the config has no ``Checks`` key it is left without one (adding a
    ``-*``-prefixed list would disable clang-tidy's defaults); the file is
    still rewritten in canonical form.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" in repo_config:
        repo_config["Checks"] = Key_Checks.enable(
            repo_config["Checks"], args.check
        )

    _write_config(repo_path, repo_config)


def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable ``args.check`` in the repository config and rewrite the file.

    With ``--drop`` the check's entry is removed entirely; otherwise it is
    kept as an explicit ``-<check>`` entry. A config without a ``Checks`` key
    is left without one.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" in repo_config:
        repo_config["Checks"] = Key_Checks.disable(
            repo_config["Checks"], args.check, args.drop
        )

    _write_config(repo_path, repo_config)


# Check lists may be separated by commas, whitespace/newlines, or both.
_CHECK_SEPARATORS = re.compile(r"[,\s]+")


class Key_Checks:
    """Merge/edit helpers for the ``Checks`` key.

    Checks are handled internally as a mapping of check name -> enabled flag.
    The serialized form always starts with the global ``-*`` disable followed
    by every check in sorted order, disabled ones prefixed with ``-``.
    """

    @staticmethod
    def merge(repo: Optional[str], ref: Optional[str]) -> str:
        """Keep the repository's checks; add reference-only checks disabled."""
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: str, check: str) -> str:
        """Return ``repo`` with ``check`` marked as enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: str, check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed if ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: Optional[str]) -> Dict[str, bool]:
        """Parse a Checks string into ``{check_name: enabled}``.

        Entries may be separated by commas, whitespace, or both. Wildcard
        disables (``-*``, ``-category-*``) are dropped because ``_join``
        always emits the global ``-*`` itself.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        for item in _CHECK_SEPARATORS.split(s):
            if not item:
                # Leading/trailing separators produce empty fragments.
                continue
            if item.startswith("-"):
                if "*" in item:
                    continue
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialize ``{check_name: enabled}`` as "-*", then sorted entries."""
        entries = ["-*"] + [
            name if enabled else f"-{name}"
            for name, enabled in sorted(data.items())
        ]
        return ",\n".join(entries) + "\n"


# A CheckOptions value: either a list of {"key": ..., "value": ...} entries
# or a plain {key: value} mapping.
_CheckOptionsValue = Union[List[Dict[str, Any]], Dict[str, Any], None]


class Key_CheckOptions:
    """Merge helper for the ``CheckOptions`` key."""

    @staticmethod
    def merge(
        repo: _CheckOptionsValue, ref: _CheckOptionsValue
    ) -> List[Dict[str, Any]]:
        """Union of both option sets; repository values take precedence.

        Either input may be the list-of-``{key, value}`` form or a plain
        mapping. The result is always a list of ``{"key", "value"}`` dicts
        sorted by key.
        """
        # Within the repo, a later duplicate key overrides an earlier one.
        merged: Dict[str, Any] = dict(Key_CheckOptions._items(repo))
        # Reference entries only fill in keys the repo does not define; the
        # first occurrence of a duplicated reference key wins.
        for key, value in Key_CheckOptions._items(ref):
            merged.setdefault(key, value)

        return [{"key": k, "value": v} for k, v in sorted(merged.items())]

    @staticmethod
    def _items(options: _CheckOptionsValue) -> Iterator[Tuple[str, Any]]:
        """Yield ``(key, value)`` pairs from either CheckOptions form."""
        if not options:
            return
        if isinstance(options, dict):
            yield from options.items()
            return
        entries: Iterable[Dict[str, Any]] = options
        for item in entries:
            yield item["key"], item["value"]


def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Load a clang-tidy config.

    ``path`` is used as the config file itself when it is not a directory and
    contains "clang-tidy"; otherwise ``<path>/.clang-tidy`` is used.

    Returns:
        ``(resolved_path, config)``. ``config`` is ``{}`` when the file does
        not exist or is empty.

    Raises:
        ValueError: if the file's top-level YAML value is not a mapping.
    """
    if os.path.isdir(path) or "clang-tidy" not in path:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r", encoding="utf-8") as f:
        # Drop whole-line comments. Lines keep their own "\n", so join with ""
        # (joining with "\n" would double every line break).
        data = "".join(line for line in f if not line.startswith("#"))

    config = yaml.safe_load(data)
    if config is None:
        # Empty (or comment-only) file.
        config = {}
    if not isinstance(config, dict):
        raise ValueError(f"{path}: expected a YAML mapping at the top level")
    return (path, config)


# A top-level "Key:" line in PyYAML's block output (nested lines are indented
# or start with "-", so they do not match).
_TOP_LEVEL_KEY = re.compile(r"[a-zA-Z0-9]+:")


def format_yaml_output(data: Dict[str, Any]) -> str:
    """Convert to a prettier YAML string:
    - filter out excess empty lines
    - insert new lines between keys

    NOTE(review): dropping empty lines also removes blank lines PyYAML may
    emit to encode line breaks inside multi-line quoted scalars, so such
    values reload with the breaks folded into spaces. That is harmless for
    Checks, whose entries may be whitespace-separated; confirm for other keys.
    """
    yaml_string = yaml.dump(data, sort_keys=False, indent=4)
    lines: List[str] = []
    for line in yaml_string.split("\n"):
        # Strip excess new lines.
        if not line:
            continue
        # Add a blank line before every top-level key except the first.
        if lines and _TOP_LEVEL_KEY.match(line):
            lines.append("")
        lines.append(line)
    # Terminate the output with a trailing newline.
    lines.append("")

    return "\n".join(lines)


if __name__ == "__main__":
    main()