# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import logging
import socket
import asyncio
import bb.asyncrpc
import json
from . import create_async_client


logger = logging.getLogger("hashserv.client")


class Batch(object):
    """
    Pipelines a sequence of stream-mode messages over one connection.

    Messages are sent as fast as possible while their responses are read back
    concurrently. Each message is appended to ``pending`` *before* it is sent
    and removed once a response has been read for it, so that if the
    connection is lost, calling ``process()`` again with a new socket resends
    the in-flight messages first and keeps the response count in sync.
    """

    def __init__(self):
        # True once send() has stopped (normally or by exception); recv()
        # exits when this is set and nothing is left in flight.
        self.done = False
        # Guards ``pending``/``done`` and wakes recv() when they change.
        self.cond = asyncio.Condition()
        # Messages sent (or about to be sent) whose response has not been
        # read yet, oldest first.
        self.pending = []
        # Responses read so far, in message order.
        self.results = []
        # Number of messages taken from ``msgs``; resends of in-flight
        # messages after a reconnect are not counted again.
        self.sent_count = 0

    async def recv(self, socket):
        """
        Read one response per pending message, appending each to
        ``results``, until send() is done and nothing remains in flight.
        """
        while True:
            async with self.cond:
                await self.cond.wait_for(lambda: self.pending or self.done)

                if not self.pending:
                    if self.done:
                        return
                    continue

            r = await socket.recv()
            self.results.append(r)

            # The oldest in-flight message now has its response
            async with self.cond:
                self.pending.pop(0)

    async def send(self, socket, msgs):
        """
        Resend any messages left in flight by a previous attempt, then send
        every message produced by ``msgs``. ``done`` is always set on exit,
        even if sending fails, so that recv() can finish.
        """
        try:
            # In the event of a restart due to a reconnect, all in-flight
            # messages need to be resent first to keep the result count in sync
            for m in self.pending:
                await socket.send(m)

            for m in msgs:
                # Add the message to the pending list before attempting to send
                # it so that if the send fails it will be retried
                async with self.cond:
                    self.pending.append(m)
                    self.cond.notify()
                    self.sent_count += 1

                await socket.send(m)

        finally:
            async with self.cond:
                self.done = True
                self.cond.notify()

    async def process(self, socket, msgs):
        """
        Run send() and recv() concurrently on ``socket`` and return the list
        of responses, one per message produced by ``msgs``.

        May be called again with a new socket after a failed attempt: the
        messages still in flight are resent first and ``msgs`` is iterated
        further, so pass an iterator to avoid resending messages that were
        already answered.

        Raises ValueError if the number of responses does not match the
        number of messages sent.
        """
        # send()'s ``finally`` leaves ``done`` set after a failed attempt;
        # clear it so that recv() on the new connection does not stop early
        # while send() is still producing messages.
        self.done = False

        await asyncio.gather(
            self.recv(socket),
            self.send(socket, msgs),
        )

        if len(self.results) != self.sent_count:
            raise ValueError(
                f"Got result count {len(self.results)}. Expected {self.sent_count}"
            )

        return self.results


class AsyncClient(bb.asyncrpc.AsyncClient):
    """
    Asynchronous hash equivalence client.

    Passes "OEHASHEQUIV" / "1.1" to the base client (presumably the protocol
    name and version). Besides normal request/response invocations, the
    connection can be switched into a "get" or "exists" stream mode in which
    plain text queries are pipelined; ``self.mode`` tracks the current mode.
    """

    # Connection modes; see _set_mode() for the transitions between them
    MODE_NORMAL = 0
    MODE_GET_STREAM = 1
    MODE_EXIST_STREAM = 2

    def __init__(self, username=None, password=None):
        super().__init__("OEHASHEQUIV", "1.1", logger)
        self.mode = self.MODE_NORMAL
        # Credentials re-sent by setup_connection() on each connection
        self.username = username
        self.password = password
        # User assumed via become_user(); restored by setup_connection()
        self.saved_become_user = None

    async def setup_connection(self):
        """
        Prepare a newly established connection: reset ``self.mode`` to
        normal, then re-authenticate (if credentials are set) and restore any
        user previously assumed with become_user().
        """
        await super().setup_connection()
        self.mode = self.MODE_NORMAL
        if self.username:
            # Save off become user temporarily because auth() resets it
            become = self.saved_become_user
            await self.auth(self.username, self.password)

            if become:
                await self.become_user(become)

    async def send_stream_batch(self, mode, msgs):
        """
        Does a "batch" process of stream messages. This sends the query
        messages as fast as possible, and simultaneously attempts to read the
        messages back. This helps to mitigate the effects of latency to the
        hash equivalence server by allowing multiple queries to be "in-flight"
        at once

        The implementation does more complicated tracking using a count of sent
        messages so that `msgs` can be a generator function (i.e. its length is
        unknown)

        Returns the list of responses, in the same order as `msgs`.
        """
        # Take a single iterator up front: if _send_wrapper re-runs proc()
        # after a reconnect (the case Batch's resend logic handles), sending
        # resumes after the messages already consumed instead of restarting a
        # list/tuple from the beginning, which would duplicate queries and
        # results.
        msgs = iter(msgs)
        b = Batch()

        async def proc():
            await self._set_mode(mode)
            return await b.process(self.socket, msgs)

        return await self._send_wrapper(proc)

    async def invoke(self, *args, skip_mode=False, **kwargs):
        """
        Send a normal-mode request, first leaving any stream mode unless
        ``skip_mode`` is set (used by _set_mode() itself).
        """
        # It's OK if connection errors cause a failure here, because the mode
        # is also reset to normal on a new connection
        if not skip_mode:
            await self._set_mode(self.MODE_NORMAL)
        return await super().invoke(*args, **kwargs)

    async def _set_mode(self, new_mode):
        """
        Switch the connection to ``new_mode``, always passing through normal
        mode.

        Raises ConnectionError when the server answers a transition with
        something other than "ok" (after giving check_invoke_error() a chance
        to raise its own error), and Exception for an unknown mode.
        """

        async def stream_to_normal():
            # Check if already in normal mode (e.g. due to a connection reset)
            if self.mode == self.MODE_NORMAL:
                return "ok"
            await self.socket.send("END")
            return await self.socket.recv()

        async def normal_to_stream(command):
            r = await self.invoke({command: None}, skip_mode=True)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to stream mode: Bad response from server {r!r}"
                )
            self.logger.debug("Mode is now %s", command)

        if new_mode == self.mode:
            return

        old_mode = self.mode
        self.logger.debug("Transitioning mode %s -> %s", old_mode, new_mode)

        # Always transition to normal mode before switching to any other mode
        if self.mode != self.MODE_NORMAL:
            r = await self._send_wrapper(stream_to_normal)
            if r != "ok":
                self.check_invoke_error(r)
                raise ConnectionError(
                    f"Unable to transition to normal mode: Bad response from server {r!r}"
                )
            # Record the intermediate state so that, if entering the new
            # stream mode below fails, the next call does not send "END" to a
            # server that is already in normal mode.
            self.mode = self.MODE_NORMAL
            self.logger.debug("Mode is now normal")

        if new_mode == self.MODE_GET_STREAM:
            await normal_to_stream("get-stream")
        elif new_mode == self.MODE_EXIST_STREAM:
            await normal_to_stream("exists-stream")
        elif new_mode != self.MODE_NORMAL:
            raise Exception(f"Undefined mode transition {old_mode!r} -> {new_mode!r}")

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Return the unihash for (method, taskhash), or None if the server
        returned an empty response."""
        r = await self.get_unihash_batch([(method, taskhash)])
        return r[0]

    async def get_unihash_batch(self, args):
        """
        Look up unihashes for an iterable of (method, taskhash) pairs using
        the get stream. Returns one entry per pair, with empty responses
        converted to None.
        """
        result = await self.send_stream_batch(
            self.MODE_GET_STREAM,
            (f"{method} {taskhash}" for method, taskhash in args),
        )
        return [r if r else None for r in result]

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """
        Send a "report" request. Entries in ``extra`` (if given) are
        included; the four named fields override same-named keys in it.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.invoke({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """
        Send a "report-equiv" request. Entries in ``extra`` (if given) are
        included; the named fields override same-named keys in it.
        """
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.invoke({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        """Send a "get" request for (method, taskhash); ``all_properties`` is
        forwarded as the "all" flag."""
        return await self.invoke(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def unihash_exists(self, unihash):
        """Return True if the server reports that ``unihash`` exists."""
        r = await self.unihash_exists_batch([unihash])
        return r[0]

    async def unihash_exists_batch(self, unihashes):
        """
        Query the exists stream for each unihash. Returns a list of booleans,
        True where the server answered "true".
        """
        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
        return [r == "true" for r in result]

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        """Send a "get-outhash" request."""
        return await self.invoke(
            {
                "get-outhash": {
                    "outhash": outhash,
                    "taskhash": taskhash,
                    "method": method,
                    "with_unihash": with_unihash,
                }
            }
        )

    async def get_stats(self):
        """Send a "get-stats" request and return the server's response."""
        return await self.invoke({"get-stats": None})

    async def reset_stats(self):
        """Send a "reset-stats" request and return the server's response."""
        return await self.invoke({"reset-stats": None})

    async def backfill_wait(self):
        """Send a "backfill-wait" request and return its "tasks" field."""
        return (await self.invoke({"backfill-wait": None}))["tasks"]

    async def remove(self, where):
        """Send a "remove" request with the given ``where`` clause."""
        return await self.invoke({"remove": {"where": where}})

    async def clean_unused(self, max_age):
        """Send a "clean-unused" request; ``max_age`` is in seconds."""
        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})

    async def auth(self, username, token):
        """
        Authenticate as ``username`` with ``token``. The credentials are
        remembered for setup_connection(), and any saved become-user is
        cleared.
        """
        result = await self.invoke({"auth": {"username": username, "token": token}})
        self.username = username
        self.password = token
        self.saved_become_user = None
        return result

    async def refresh_token(self, username=None):
        """
        Request a new token (for ``username`` if given). If the new token
        belongs to the authenticated user and no other user has been assumed,
        it replaces the stored password used on reconnect.
        """
        m = {}
        if username:
            m["username"] = username
        result = await self.invoke({"refresh-token": m})
        if (
            self.username
            and not self.saved_become_user
            and result["username"] == self.username
        ):
            self.password = result["token"]
        return result

    async def set_user_perms(self, username, permissions):
        """Send a "set-user-perms" request for ``username``."""
        return await self.invoke(
            {"set-user-perms": {"username": username, "permissions": permissions}}
        )

    async def get_user(self, username=None):
        """Send a "get-user" request; the username is omitted when not
        given (presumably meaning the current user)."""
        m = {}
        if username:
            m["username"] = username
        return await self.invoke({"get-user": m})

    async def get_all_users(self):
        """Send a "get-all-users" request and return its "users" field."""
        return (await self.invoke({"get-all-users": {}}))["users"]

    async def new_user(self, username, permissions):
        """Send a "new-user" request."""
        return await self.invoke(
            {"new-user": {"username": username, "permissions": permissions}}
        )

    async def delete_user(self, username):
        """Send a "delete-user" request."""
        return await self.invoke({"delete-user": {"username": username}})

    async def become_user(self, username):
        """
        Send a "become-user" request. Unless ``username`` is the
        authenticated user, it is remembered so setup_connection() can
        restore it after a reconnect.
        """
        result = await self.invoke({"become-user": {"username": username}})
        if username == self.username:
            self.saved_become_user = None
        else:
            self.saved_become_user = username
        return result

    async def get_db_usage(self):
        """Send a "get-db-usage" request and return its "usage" field."""
        return (await self.invoke({"get-db-usage": {}}))["usage"]

    async def get_db_query_columns(self):
        """Send a "get-db-query-columns" request and return its "columns"
        field."""
        return (await self.invoke({"get-db-query-columns": {}}))["columns"]

    async def gc_status(self):
        """Send a "gc-status" request and return the server's response."""
        return await self.invoke({"gc-status": {}})

    async def gc_mark(self, mark, where):
        """
        Starts a new garbage collection operation identified by "mark". If
        garbage collection is already in progress with "mark", the collection
        is continued.

        All unihash entries that match the "where" clause are marked to be
        kept. In addition, any new entries added to the database after this
        command will be automatically marked with "mark"
        """
        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})

    async def gc_sweep(self, mark):
        """
        Finishes garbage collection for "mark". All unihash entries that have
        not been marked will be deleted.

        It is recommended to clean unused outhash entries after running this to
        clean up any dangling outhashes
        """
        return await self.invoke({"gc-sweep": {"mark": mark}})


class Client(bb.asyncrpc.Client):
    """
    Synchronous client that forwards the methods registered with
    ``_add_methods`` to an AsyncClient built by _get_async_client()
    (presumably running each call to completion in the base class).
    """

    def __init__(self, username=None, password=None):
        # Set before super().__init__() in case the base constructor already
        # calls _get_async_client(). NOTE(review): confirm against
        # bb.asyncrpc.Client.
        self.username = username
        self.password = password

        super().__init__()
        self._add_methods(
            "connect_tcp",
            "connect_websocket",
            "get_unihash",
            "get_unihash_batch",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "unihash_exists",
            "unihash_exists_batch",
            "get_outhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
            "remove",
            "clean_unused",
            "auth",
            "refresh_token",
            "set_user_perms",
            "get_user",
            "get_all_users",
            "new_user",
            "delete_user",
            "become_user",
            "get_db_usage",
            "get_db_query_columns",
            "gc_status",
            "gc_mark",
            "gc_sweep",
        )

    def _get_async_client(self):
        """Create the AsyncClient backing this client, using the stored
        credentials."""
        return AsyncClient(self.username, self.password)