#!/usr/bin/env python3
"""Maintain a repository's ``.clang-tidy`` configuration.

Subcommands:

* ``merge``   -- fold a reference config into the repository config; checks
  that only the reference lists are added in disabled form.
* ``enable``  -- mark a single check as enabled in the repository config.
* ``disable`` -- mark a single check as disabled (or drop it with ``--drop``).

The ``Checks`` value is always rewritten as ``-*`` followed by an explicit,
sorted list of checks, one per line.
"""
import argparse
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import yaml


def main() -> None:
    """Parse the command line and dispatch to the selected subcommand."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", help="Path to the repository", default=".")
    # NOTE(review): --commit is parsed, but no subcommand reads args.commit,
    # so it currently has no effect.
    parser.add_argument(
        "--commit",
        help="Commit the changes",
        default=False,
        action="store_true",
    )

    # Give the subparsers a dest: with required=True and no dest, older
    # argparse versions raise TypeError instead of a usage error when the
    # subcommand is omitted.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    parser_merge = subparsers.add_parser(
        "merge", help="Merge a reference clang-tidy config"
    )
    parser_merge.add_argument(
        "--reference", help="Path to reference clang-tidy", required=True
    )
    parser_merge.set_defaults(func=subcmd_merge)

    parser_enable = subparsers.add_parser(
        "enable", help="Enable a rule in the repository clang-tidy config"
    )
    parser_enable.add_argument("check", help="Check to enable")
    parser_enable.set_defaults(func=subcmd_enable)

    parser_disable = subparsers.add_parser(
        "disable", help="Disable a rule in the repository clang-tidy config"
    )
    parser_disable.add_argument("check", help="Check to disable")
    parser_disable.add_argument(
        "--drop", help="Delete the check from the config", action="store_true"
    )
    parser_disable.set_defaults(func=subcmd_disable)

    args = parser.parse_args()
    args.func(args)


def subcmd_merge(args: argparse.Namespace) -> None:
    """Merge the config at ``args.reference`` into the config at ``args.repo``.

    ``Checks`` and ``CheckOptions`` are combined by their dedicated mergers.
    For every other key, the repository value wins whenever the repository
    sets it to something other than null -- including falsy values such as
    ``false`` or ``''`` -- and the reference value is used otherwise.

    Output order: ``Checks``, ``CheckOptions``, then the remaining keys
    alphabetically. The result is written back to the repository config path.
    """
    repo_path, repo_config = load_config(args.repo)
    _, ref_config = load_config(args.reference)

    mergers: Dict[str, Callable[[Any, Any], Any]] = {
        "Checks": Key_Checks.merge,
        "CheckOptions": Key_CheckOptions.merge,
    }

    all_keys = set(repo_config) | set(ref_config)
    special_keys = ["Checks", "CheckOptions"]
    # Special keys first (if present, in their defined order), followed by
    # the rest of the keys sorted alphabetically.
    ordered_keys = [k for k in special_keys if k in all_keys] + sorted(
        all_keys - set(special_keys)
    )

    result: Dict[str, Any] = {}
    for key in ordered_keys:
        repo_value = repo_config.get(key)
        ref_value = ref_config.get(key)

        merger = mergers.get(key)
        if merger is not None:
            result[key] = merger(repo_value, ref_value)
        elif repo_value is not None:
            # Compare against None, not truthiness, so an explicit
            # ``false``/``''``/``0`` in the repository is preserved.
            result[key] = repo_value
        else:
            result[key] = ref_value

    with open(repo_path, "w", encoding="utf-8") as f:
        f.write(format_yaml_output(result))


def subcmd_enable(args: argparse.Namespace) -> None:
    """Enable ``args.check`` in the repository config and write it back.

    If the config has no ``Checks`` key (including when the config file does
    not exist yet), one is created and placed first, so the request is not
    silently dropped.
    """
    repo_path, repo_config = load_config(args.repo)

    checks = Key_Checks.enable(repo_config.get("Checks"), args.check)
    if "Checks" in repo_config:
        repo_config["Checks"] = checks
    else:
        # Put a newly created Checks key first, matching subcmd_merge's order.
        repo_config = {"Checks": checks, **repo_config}

    with open(repo_path, "w", encoding="utf-8") as f:
        f.write(format_yaml_output(repo_config))


def subcmd_disable(args: argparse.Namespace) -> None:
    """Disable (or, with ``--drop``, remove) ``args.check`` in the repo config.

    The ``Checks`` key is only rewritten when it already exists. The config
    is always written back (re-formatted) to the repository config path.
    """
    repo_path, repo_config = load_config(args.repo)

    if "Checks" in repo_config:
        repo_config["Checks"] = Key_Checks.disable(
            repo_config["Checks"], args.check, args.drop
        )

    with open(repo_path, "w", encoding="utf-8") as f:
        f.write(format_yaml_output(repo_config))


class Key_Checks:
    """Helpers for the ``Checks`` key: a comma-separated list of check globs.

    Internally a ``Checks`` value is a ``{check_name: enabled}`` dict. A
    leading ``-`` in the string form means "disabled".
    """

    @staticmethod
    def merge(repo: Optional[str], ref: Optional[str]) -> str:
        """Return the repo's checks plus ref-only checks, the latter disabled.

        Checks the repository already lists keep their repository state.
        Checks that only the reference lists are added as ``-check``, so they
        become visible in the repository config without being switched on.
        """
        result = Key_Checks._split(repo)
        for check in Key_Checks._split(ref):
            result.setdefault(check, False)
        return Key_Checks._join(result)

    @staticmethod
    def enable(repo: Optional[str], check: str) -> str:
        """Return ``repo`` with ``check`` marked as enabled."""
        repo_checks = Key_Checks._split(repo)
        repo_checks[check] = True
        return Key_Checks._join(repo_checks)

    @staticmethod
    def disable(repo: Optional[str], check: str, drop: bool) -> str:
        """Return ``repo`` with ``check`` disabled, or removed if ``drop``."""
        repo_checks = Key_Checks._split(repo)
        if drop:
            repo_checks.pop(check, None)
        else:
            repo_checks[check] = False
        return Key_Checks._join(repo_checks)

    @staticmethod
    def _split(s: Optional[str]) -> Dict[str, bool]:
        """Parse a ``Checks`` string into ``{check: enabled}``.

        Entries may be separated by commas, whitespace, or both. Both the
        single-line form ``"-*,foo,-bar"`` and the one-per-line form written
        by ``_join`` therefore parse. The global ``-*`` and category-wide
        disables such as ``-foo-*`` are dropped, because ``_join`` always
        re-emits a leading ``-*`` that covers them.
        """
        result: Dict[str, bool] = {}
        if not s:
            return result
        for item in re.split(r"[\s,]+", s):
            if not item:
                continue
            if item.startswith("-"):
                # Ignore the global wildcard and category wildcard disables;
                # the global "-*" emitted by _join already handles them.
                if "*" in item:
                    continue
                result[item[1:]] = False
            else:
                result[item] = True
        return result

    @staticmethod
    def _join(data: Dict[str, bool]) -> str:
        """Serialize ``{check: enabled}`` as ``-*`` plus sorted entries.

        Entries are comma-terminated and one per line; the string ends with
        a newline.
        """
        entries = ["-*"] + [k if v else f"-{k}" for k, v in sorted(data.items())]
        return ",\n".join(entries) + "\n"


class Key_CheckOptions:
    """Helpers for the ``CheckOptions`` key."""

    @staticmethod
    def merge(repo: Any, ref: Any) -> List[Dict[str, Any]]:
        """Union the repo and ref options; the repo value wins on a shared key.

        Each side may be ``None``, the list form (``[{key: ..., value: ...}]``)
        or the mapping form (``{key: value}``). The result is always the list
        form, sorted by key.
        """
        merged = Key_CheckOptions._unroll(ref)
        merged.update(Key_CheckOptions._unroll(repo))
        return [{"key": k, "value": v} for k, v in sorted(merged.items())]

    @staticmethod
    def _unroll(options: Any) -> Dict[str, Any]:
        """Normalize a ``CheckOptions`` value (None, list or mapping) to a dict."""
        if not options:
            return {}
        if isinstance(options, dict):
            return dict(options)
        return {item["key"]: item["value"] for item in options}


def load_config(path: str) -> Tuple[str, Dict[str, Any]]:
    """Load a clang-tidy config and return ``(config_path, parsed_config)``.

    ``path`` may name a config file directly or a directory. ``.clang-tidy``
    is appended when ``path`` is an existing directory, or when its final
    component does not mention ``clang-tidy``. A missing file, or one with
    no YAML content, yields ``{}``.
    """
    last_component = os.path.basename(os.path.normpath(path))
    if os.path.isdir(path) or "clang-tidy" not in last_component:
        path = os.path.join(path, ".clang-tidy")

    if not os.path.exists(path):
        return (path, {})

    with open(path, "r", encoding="utf-8") as f:
        # Drop full-line comments. Lines read from a file keep their "\n",
        # so concatenate them directly; re-joining with "\n" would double
        # every line break.
        data = "".join(line for line in f if not line.startswith("#"))
    # safe_load returns None for an empty or comment-only document.
    return (path, yaml.safe_load(data) or {})


def format_yaml_output(data: Dict[str, Any]) -> str:
    """Convert ``data`` to a prettier YAML string.

    - every empty line produced by the dumper is removed;
    - a single blank line is inserted before each top-level key except the
      first;
    - the result ends with a newline.
    """
    yaml_string = yaml.dump(data, sort_keys=False, indent=4)
    lines: List[str] = []
    for line in yaml_string.split("\n"):
        # NOTE(review): this also removes the blank lines PyYAML uses to
        # encode line breaks inside multi-line quoted scalars (e.g. Checks),
        # so those breaks read back as spaces; Key_Checks._split accepts both.
        if not line:
            continue
        # An unindented "Name:" line is a top-level key; separate it from
        # whatever precedes it.
        if lines and re.match(r"[a-zA-Z0-9]+:", line):
            lines.append("")
        lines.append(line)
    lines.append("")

    return "\n".join(lines)


if __name__ == "__main__":
    main()