# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import collections
import logging
import socket
import asyncio
import bb.asyncrpc
import json
from . import create_async_client


logger = logging.getLogger("hashserv.client")


class Batch:
    """
    Tracks one pipelined batch of stream-mode messages.

    Messages are sent as fast as possible while their responses are read back
    concurrently. A message stays in ``pending`` until its response has been
    received, so that if ``process()`` is run again on the same object after a
    reconnect, the in-flight messages can be resent and the number of results
    stays in step with the number of messages sent.
    """

    def __init__(self):
        # Set by send() when it has finished (successfully or not) so that
        # recv() knows no more messages will be added to pending
        self.done = False
        self.cond = asyncio.Condition()
        # Messages queued for sending whose response has not been received
        # yet, oldest first
        self.pending = collections.deque()
        # Responses received so far, in the order the messages were sent
        self.results = []
        # Number of distinct messages taken from the input so far
        self.sent_count = 0

    async def recv(self, socket):
        """
        Read one response per pending message, until send() has finished and
        every pending message has been answered.
        """
        while True:
            async with self.cond:
                await self.cond.wait_for(lambda: self.pending or self.done)

                if not self.pending:
                    # wait_for() only returns with pending empty once done
                    # has been set, so there is nothing left to read
                    return

            r = await socket.recv()

            # Record the result and retire its message in one step, so a
            # cancellation cannot leave results and pending out of step
            async with self.cond:
                self.results.append(r)
                self.pending.popleft()

    async def send(self, socket, msgs):
        """
        Resend any messages left in flight by a previous attempt, then send
        every remaining message from ``msgs``.

        ``done`` is always set on exit (including on error or cancellation)
        so that recv() can finish.
        """
        try:
            # In the event of a restart due to a reconnect, all in-flight
            # messages need to be resent first to keep the result count in
            # sync. Iterate over a snapshot: recv() removes entries from the
            # front of pending while this loop is awaiting.
            for m in list(self.pending):
                await socket.send(m)

            it = iter(msgs)
            exhausted = object()
            while True:
                # Take the next message and add it to the pending list in one
                # step, before attempting to send it, so that if the send
                # fails (or this task is cancelled) it will be retried rather
                # than lost
                async with self.cond:
                    m = next(it, exhausted)
                    if m is exhausted:
                        break
                    self.pending.append(m)
                    self.sent_count += 1
                    self.cond.notify()

                await socket.send(m)

        finally:
            async with self.cond:
                self.done = True
                self.cond.notify()

    async def process(self, socket, msgs):
        """
        Run one attempt at the batch on ``socket``.

        This may be called again on the same Batch after a connection failure.
        Messages that were in flight are resent, and ``msgs`` continues to be
        consumed. Pass an iterator so that a retry resumes where the failed
        attempt stopped.

        Returns the list of responses, one per message sent.

        Raises ValueError if the number of responses does not match the number
        of messages sent.
        """
        # A previous (failed) attempt leaves done set; clear it so recv() does
        # not exit before this attempt has finished sending
        self.done = False

        recv_task = asyncio.ensure_future(self.recv(socket))
        send_task = asyncio.ensure_future(self.send(socket, msgs))
        try:
            await asyncio.gather(recv_task, send_task)
        except BaseException:
            # gather() does not cancel the other task when one fails. Stop it
            # and wait for it, so it cannot keep using the broken connection
            # or race with a retry of this batch.
            recv_task.cancel()
            send_task.cancel()
            await asyncio.gather(recv_task, send_task, return_exceptions=True)
            raise

        if len(self.results) != self.sent_count:
            raise ValueError(
                f"Unexpected result count {len(self.results)}. Expected {self.sent_count}"
            )

        return self.results


class AsyncClient(bb.asyncrpc.AsyncClient):
    """Asynchronous client for the hash equivalence server ("OEHASHEQUIV" 1.1)."""

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2
    MODE_MARK_STREAM = 3

    # Server command sent (from normal mode) to enter each stream mode
    _STREAM_MODE_COMMANDS = {
        MODE_GET_STREAM: "get-stream",
        MODE_EXIST_STREAM: "exists-stream",
        MODE_MARK_STREAM: "gc-mark-stream",
    }

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        self.username = username
        self.password = password
        self.saved_become_user = None

    async def setup_connection(self):
        """
        Prepare a new connection. Resets the mode to normal, then
        re-authenticates and restores any "become user" if credentials are set.
        """
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream_batch(self, mode, msgs):
        """
        Does a "batch" process of stream messages. This sends the query
        messages as fast as possible, and simultaneously attempts to read the
        messages back. This helps to mitigate the effects of latency to the
        hash equivalence server by allowing multiple queries to be "in-flight"
        at once.

        The implementation does more complicated tracking using a count of sent
        messages so that `msgs` can be a generator function (i.e. its length is
        unknown).

        Returns the list of raw responses, one per message.
        """
        b = Batch()
        # Take a single iterator up front. _send_wrapper() may run proc() again
        # after a reconnect, and the batch must then resume where it left off.
        # A re-iterable input such as a list would otherwise restart from the
        # beginning and every message would be sent (and counted) twice.
        remaining = iter(msgs)

        async def proc():
            await self._set_mode(mode)
            return await b.process(self.socket, remaining)

        return await self._send_wrapper(proc)

    async def invoke(self, *args, skip_mode=False, **kwargs):
        """
        Send a normal-mode request. Unless ``skip_mode`` is set, switch back to
        normal mode first.
        """
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        if not skip_mode:
            await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        """
        Switch the connection to ``new_mode``.

        Stream modes are only entered from normal mode, so switching between
        two stream modes goes via normal mode.

        Raises ConnectionError if the server rejects a transition, and
        Exception if ``new_mode`` is not a known mode.
        """

        async def stream_to_normal():
            # Check if already in normal mode (e.g. due to a connection reset)
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            r = await self.invoke({command: None}, skip_mode=True)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )
            self.logger.debug("Mode is now %s", command)

        if new_mode == self.mode:
            return

        # Reject unknown modes before touching the connection, so the client
        # and server cannot end up disagreeing about the current mode
        if new_mode != self.MODE_NORMAL and new_mode not in self._STREAM_MODE_COMMANDS:
            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")

        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            # Record the intermediate state. If entering the new stream mode
            # fails below, self.mode must not still name the old stream mode,
            # which would cause a spurious "END" on the next transition.
            self.mode = self.MODE_NORMAL
            self.logger.debug("Mode is now normal")

        command = self._STREAM_MODE_COMMANDS.get(new_mode)
        if command is not None:
            await normal_to_stream(command)

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Return the unihash for a single (method, taskhash), or None."""
        r = await self.get_unihash_batch([(method, taskhash)])
        return r[0]

    async def get_unihash_batch(self, args):
        """
        Look up unihashes for an iterable of (method, taskhash) pairs.

        Returns one entry per pair: the server's response, or None when the
        response is empty.
        """
        result = await self.send_stream_batch(
            self.MODE_GET_STREAM,
            (f"{method} {taskhash}" for method, taskhash in args),
        )
        return [r if r else None for r in result]

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """
        Send a "report" request. ``extra`` is an optional dict of additional
        fields; it is copied, not modified.
        """
        m = extra.copy() if extra is not None else {}
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """
        Send a "report-equiv" request. ``extra`` is an optional dict of
        additional fields; it is copied, not modified.
        """
        m = extra.copy() if extra is not None else {}
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        """Return True if the server reports that ``unihash`` exists."""
        r = await self.unihash_exists_batch([unihash])
        return r[0]

    async def unihash_exists_batch(self, unihashes):
        """Return one bool per unihash: True where the server answered "true"."""
        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
        return [r == "true" for r in result]

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        """Send "backfill-wait" and return the "tasks" field of the response."""
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        """
        Authenticate. The credentials are remembered for reconnects, and any
        saved "become user" is cleared.
        """
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        """
        Request a new token (for ``username``, or the current user if omitted).
        If it is for the user this client authenticates as, the stored
        password is updated.
        """
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        """
        Act as ``username``. The choice is remembered so that it can be
        restored by setup_connection() after a reconnect.
        """
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will be automatically marked with "mark"
        """
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_mark_stream(self, mark, rows):
        """
        Similar to `gc-mark`, but accepts a list of "where" key-value pair
        conditions. It utilizes stream mode to mark hashes, which helps reduce
        the impact of latency when communicating with the hash equivalence
        server.

        Each row is a whitespace-separated string "key1 value1 key2 value2 ...".
        Returns {"count": total} summed over the per-row responses.
        """

        def row_to_dict(row):
            pairs = row.split()
            return dict(zip(pairs[::2], pairs[1::2]))

        responses = await self.send_stream_batch(
            self.MODE_MARK_STREAM,
            (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
        )

        return {"count": sum(int(json.loads(r)["count"]) for r in responses)}

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this to
        cleanup any dangling outhashes
        """
        return await self.invoke({"gc-sweep": {"mark": mark}})


class Client(bb.asyncrpc.Client):
    """
    Client wrapper around AsyncClient. The method names listed below are
    exposed via bb.asyncrpc.Client._add_methods (presumably as blocking
    wrappers of the AsyncClient coroutines; see bb.asyncrpc).
    """

    def __init__(self, username=None, password=None):
        self.username = username
        self.password = password

        super().__init__()
        self._add_methods(
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "get_unihash_batch",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "unihash_exists_batch",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_mark_stream",
            "gc_sweep",
        )

    def _get_async_client(self):
        return AsyncClient(self.username, self.password)