# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import logging
import socket
import bb.asyncrpc
import json
from . import create_async_client


logger = logging.getLogger("hashserv.client")


class AsyncClient(bb.asyncrpc.AsyncClient):
    """Asynchronous client for the hash equivalence server.

    Speaks protocol "OEHASHEQUIV" version 1.1. Besides normal dict-based
    invocations (see invoke()), the connection can be switched into one of
    two streaming modes in which each request is sent as a bare string and
    each reply is read back as a bare string (see send_stream()). ``self.mode``
    records which mode the connection is currently in.
    """

    # Connection modes managed by _set_mode()
    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        # Credentials re-sent by setup_connection() on every new connection
        self.username = username
        self.password = password
        # User assumed via become_user(); None when acting as ourselves
        self.saved_become_user = None

    async def setup_connection(self):
        """Reset per-connection state after a new connection is made.

        The mode is reset to normal. If a username is configured, the client
        re-authenticates and then re-assumes any user previously selected with
        become_user().
        """
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream(self, mode, msg):
        """Send the string ``msg`` in streaming ``mode`` and return the reply.

        The mode switch and the send are run as one unit through
        _send_wrapper(), so if that wrapper re-runs the unit on a fresh
        connection (presumably after a connection error -- see
        bb.asyncrpc), the mode switch is redone as well.
        """

        async def proc():
            await self._set_mode(mode)
            await self.socket.send(msg)
            return await self.socket.recv()

        return await self._send_wrapper(proc)

    async def invoke(self, *args, **kwargs):
        """Return to normal mode, then perform a regular invocation."""
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        """Switch the connection to ``new_mode``.

        Streaming modes are only entered from normal mode, so switching from
        one streaming mode to another passes through normal mode first.

        Raises:
            ConnectionError: the server did not acknowledge a transition.
            Exception: ``new_mode`` is not one of the MODE_* constants.
        """

        async def stream_to_normal():
            # If _send_wrapper() re-ran us on a fresh connection,
            # setup_connection() already put us back in normal mode; sending
            # "END" then would confuse the server.
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            r = await self.invoke({command: None})
            if r != "ok":
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )

            self.logger.debug("Mode is now %s", command)

        if new_mode == self.mode:
            return

        old_mode = self.mode
        self.logger.debug("Transitioning mode %s -> %s", old_mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            # Record the intermediate state right away: normal_to_stream()
            # goes through invoke(), which calls _set_mode(MODE_NORMAL) and
            # would otherwise send a second "END" to a server that is already
            # in normal mode.
            self.mode = self.MODE_NORMAL
            self.logger.debug("Mode is now normal")

        if new_mode == self.MODE_GET_STREAM:
            await normal_to_stream("get-stream")
        elif new_mode == self.MODE_EXIST_STREAM:
            await normal_to_stream("exists-stream")
        elif new_mode != self.MODE_NORMAL:
            raise Exception(f"Undefined mode transition {old_mode!r} -> {new_mode!r}")

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Look up the unihash for ``taskhash`` under ``method`` via the get stream.

        Returns the server's reply, or None if the reply is empty.
        """
        r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash))
        if not r:
            return None
        return r

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a "report" request and return the server's reply.

        ``extra`` is an optional mapping of additional fields. It is copied,
        never modified, and the named arguments override same-named keys.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """Send a "report-equiv" request and return the server's reply.

        ``extra`` is an optional mapping of additional fields. It is copied,
        never modified, and the named arguments override same-named keys.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        """Send a "get" request for ``taskhash``; ``all_properties`` is passed as "all"."""
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        """Return True if the exists stream replies "true" for ``unihash``."""
        r = await self.send_stream(self.MODE_EXIST_STREAM, unihash)
        return r == "true"

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        """Send a "get-outhash" request and return the server's reply."""
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        """Send a "get-stats" request and return the server's reply."""
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        """Send a "reset-stats" request and return the server's reply."""
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        """Send a "backfill-wait" request and return the reply's "tasks" field."""
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        """Send a "remove" request with the ``where`` clause."""
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        """Send a "clean-unused" request; ``max_age`` is sent as seconds."""
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        """Authenticate as ``username`` with ``token``.

        On success the credentials are remembered so setup_connection() can
        re-authenticate, and any previously assumed user is forgotten.
        """
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        """Request a new token for ``username`` (or the current user if None).

        If the refreshed token belongs to the user we authenticate as (and we
        are not impersonating someone else), the stored password is updated
        so later reconnections use the new token.
        """
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        """Send a "set-user-perms" request for ``username``."""
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        """Send a "get-user" request, for ``username`` if given."""
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        """Return the "users" field of a "get-all-users" reply."""
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        """Send a "new-user" request and return the server's reply."""
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        """Send a "delete-user" request for ``username``."""
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        """Act as ``username``; remembered so setup_connection() can restore it.

        Becoming our own authenticated user clears the saved value.
        """
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        """Return the "usage" field of a "get-db-usage" reply."""
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        """Return the "columns" field of a "get-db-query-columns" reply."""
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        """Send a "gc-status" request and return the server's reply."""
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will be automatically marked with "mark".
        """
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this to
        clean up any dangling outhashes.
        """
        return await self.invoke({"gc-sweep": {"mark": mark}})


class Client(bb.asyncrpc.Client):
    """Synchronous front-end to AsyncClient.

    The method names passed to _add_methods() are exposed on this object;
    presumably the base class runs the matching AsyncClient coroutine to
    completion for each call (see bb.asyncrpc).
    """

    def __init__(self, username=None, password=None):
        # Set before super().__init__() because that may already need
        # _get_async_client() (assumption -- confirm against bb.asyncrpc).
        self.username = username
        self.password = password

        super().__init__()
        self._add_methods(
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_sweep",
        )

    def _get_async_client(self):
        return AsyncClient(self.username, self.password)


class ClientPool(bb.asyncrpc.ClientPool):
    """Pool of up to ``max_clients`` connections to ``address`` for parallel queries."""

    def __init__(
        self,
        address,
        max_clients,
        *,
        username=None,
        password=None,
        become=None,
    ):
        super().__init__(max_clients)
        self.address = address
        self.username = username
        self.password = password
        self.become = become

    async def _new_client(self):
        """Create a connected client, assuming ``self.become`` if set.

        If become_user() fails, the new client is closed before the error
        propagates so its connection is not leaked.
        """
        client = await create_async_client(
            self.address,
            username=self.username,
            password=self.password,
        )
        if self.become:
            try:
                await client.become_user(self.become)
            except BaseException:
                await client.close()
                raise
        return client

    def _run_key_tasks(self, queries, call):
        """Run ``call(client, args)`` for every (key, args) in ``queries``.

        Returns a dict with the same keys as ``queries``. Each value is the
        result of that key's call, or None if its task never stored a result.
        """
        # Every key starts as None; each task fills in its own slot.
        results = dict.fromkeys(queries)

        def make_task(key, args):
            # Factory so each task binds its own key/args (avoids the
            # late-binding-closure pitfall).
            async def task(client):
                results[key] = await call(client, args)

            return task

        self.run_tasks(make_task(key, args) for key, args in queries.items())
        return results

    def get_unihashes(self, queries):
        """
        Query multiple unihashes in parallel.

        The queries argument is a dictionary with arbitrary keys. The values
        must be tuples of (method, taskhash).

        Returns a dictionary with a corresponding key for each input key, and
        the value is the queried unihash (which might be None if the query
        failed)
        """

        async def call(client, args):
            method, taskhash = args
            return await client.get_unihash(method, taskhash)

        return self._run_key_tasks(queries, call)

    def unihashes_exist(self, queries):
        """
        Query multiple unihash existence checks in parallel.

        The queries argument is a dictionary with arbitrary keys. The values
        must be unihashes.

        Returns a dictionary with a corresponding key for each input key, and
        the value is True or False if the unihash is known by the server (or
        None if there was a failure)
        """

        async def call(client, unihash):
            return await client.unihash_exists(unihash)

        return self._run_key_tasks(queries, call)