# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from contextlib import closing, contextmanager
from datetime import datetime
import enum
import asyncio
import logging
import math
import time
from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
import bb.asyncrpc


logger = logging.getLogger('hashserv.server')


class Measurement(object):
    """Times a single operation and reports the elapsed wall-clock time
    to its owning Sample. Usable as a context manager."""

    def __init__(self, sample):
        self.sample = sample

    def start(self):
        """Record the start time using a monotonic, high-resolution clock."""
        self.start_time = time.perf_counter()

    def end(self):
        """Report the time elapsed since start() to the owning sample."""
        self.sample.add(time.perf_counter() - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()


class Sample(object):
    """Accumulates one or more Measurements into a single statistical
    sample; the accumulated total is committed to the owning Stats when
    the sample ends. Usable as a context manager."""

    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        """Return a new Measurement that reports into this sample."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        """Fold one measured duration (seconds) into this sample."""
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        """Commit the accumulated elapsed time to the Stats object.

        A sample that recorded no measurements contributes nothing. The
        sample is reset afterwards so it may be reused.
        """
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0


class Stats(object):
    """Running timing statistics (count, total, max, mean, standard
    deviation) maintained with Welford's online algorithm, so no
    per-sample history needs to be stored."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all counters."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        # Welford running mean (m) and sum of squared differences (s)
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        """Incorporate one sample duration (seconds)."""
        self.num += 1
        if self.num == 1:
            self.m = elapsed
            self.s = 0
        else:
            last_m = self.m
            self.m = last_m + (elapsed - last_m) / self.num
            self.s = self.s + (elapsed - last_m) * (elapsed - self.m)

        self.total_time += elapsed

        if self.max_time < elapsed:
            self.max_time = elapsed

    def start_sample(self):
        """Begin a new Sample that reports into these statistics."""
        return Sample(self)

    @property
    def average(self):
        """Mean sample time in seconds (0 when no samples recorded)."""
        if self.num == 0:
            return 0
        return self.total_time / self.num

    @property
    def stdev(self):
        """Sample standard deviation (0 with fewer than two samples)."""
        if self.num <= 1:
            return 0
        return math.sqrt(self.s / (self.num - 1))

    def todict(self):
        """Return the statistics as a plain dict for RPC serialization."""
        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}


@enum.unique
class Resolve(enum.Enum):
    """SQLite conflict-resolution strategy applied to INSERT statements."""
    FAIL = enum.auto()
    IGNORE = enum.auto()
    REPLACE = enum.auto()


def insert_table(cursor, table, data, on_conflict):
    """Insert *data* (a column-name -> value dict) into *table*.

    *on_conflict* is a Resolve member selecting the SQLite conflict
    clause. Returns a tuple ``(lastrowid, inserted)`` where *inserted*
    is True when a new row was actually added (detected by the cursor's
    lastrowid changing).
    """
    resolve = {
        Resolve.FAIL: "",
        Resolve.IGNORE: " OR IGNORE",
        Resolve.REPLACE: " OR REPLACE",
    }[on_conflict]

    keys = sorted(data.keys())
    query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
        resolve=resolve,
        table=table,
        fields=", ".join(keys),
        values=", ".join(":" + k for k in keys),
    )
    prevrowid = cursor.lastrowid
    cursor.execute(query, data)
    # Use the module logger (not the root logger via logging.debug) so the
    # message carries the 'hashserv.server' name and obeys its configuration
    logger.debug(
        "Inserting %r into %s, %s",
        data,
        table,
        on_conflict
    )
    return (cursor.lastrowid, cursor.lastrowid != prevrowid)

def insert_unihash(cursor, data, on_conflict):
    """Insert a row into the taskhash -> unihash equivalence table."""
    return insert_table(cursor, "unihashes_v2", data, on_conflict)

def insert_outhash(cursor, data, on_conflict):
    """Insert a row into the output-hash table."""
    return insert_table(cursor, "outhashes_v2", data, on_conflict)

async def copy_unihash_from_upstream(client, db, method, taskhash):
    """Fetch the unihash for (method, taskhash) from the upstream server
    and cache it locally (ignoring duplicates).

    Returns the upstream reply dict, or None if upstream had no entry.
    """
    d = await client.get_taskhash(method, taskhash)
    if d is not None:
        with closing(db.cursor()) as cursor:
            insert_unihash(
                cursor,
                {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
                Resolve.IGNORE,
            )
            db.commit()
    return d


class ServerCursor(object):
    """Bundle of a database connection, a cursor on it, and the upstream
    address (if any) for passing through request handlers."""

    def __init__(self, db, cursor, upstream):
        self.db = db
        self.cursor = cursor
        self.upstream = upstream


class ServerClient(bb.asyncrpc.AsyncServerConnection):
    """Handles the RPC conversation with one connected hashserv client.

    Read-only servers register only the query handlers; writable servers
    additionally register the report/reset/backfill handlers.
    """

    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
        super().__init__(reader, writer, 'OEHASHEQUIV', logger)
        self.db = db
        self.request_stats = request_stats
        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
        self.backfill_queue = backfill_queue
        self.upstream = upstream

        self.handlers.update({
            'get': self.handle_get,
            'get-outhash': self.handle_get_outhash,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
        })

        if not read_only:
            self.handlers.update({
                'report': self.handle_report,
                'report-equiv': self.handle_equivreport,
                'reset-stats': self.handle_reset_stats,
                'backfill-wait': self.handle_backfill_wait,
            })

    def validate_proto_version(self):
        """Accept protocol versions newer than 1.0 up to and including 1.1."""
        return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))

    async def process_requests(self):
        """Open the upstream client (if configured) for the lifetime of
        this connection, then run the normal request loop."""
        if self.upstream is not None:
            self.upstream_client = await create_async_client(self.upstream)
        else:
            self.upstream_client = None

        await super().process_requests()

        if self.upstream_client is not None:
            await self.upstream_client.close()

    async def dispatch_message(self, msg):
        """Dispatch *msg* to the registered handler named by its key.

        Raises bb.asyncrpc.ClientError for unknown commands. Streaming
        handlers manage their own statistics; everything else is timed
        via the shared request statistics.
        """
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s' % k)
                if 'stream' in k:
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)

    async def handle_get(self, request):
        """'get' command: look up the unihash for a (method, taskhash)."""
        method = request['method']
        taskhash = request['taskhash']
        fetch_all = request.get('all', False)

        with closing(self.db.cursor()) as cursor:
            d = await self.get_unihash(cursor, method, taskhash, fetch_all)

        self.write_message(d)

    async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
        """Resolve (method, taskhash) to a unihash row.

        With fetch_all, the oldest matching outhash row joined with its
        unihash is returned; otherwise only the equivalence row. Falls
        back to the upstream server (caching the result locally) when
        the local database has no entry. Returns a dict or None.
        """
        d = None

        if fetch_all:
            cursor.execute(
                '''
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                ''',
                {
                    'method': method,
                    'taskhash': taskhash,
                }

            )
            row = cursor.fetchone()

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash, True)
                self.update_unified(cursor, d)
                self.db.commit()
        else:
            row = self.query_equivalent(cursor, method, taskhash)

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash)
                d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
                insert_unihash(cursor, d, Resolve.IGNORE)
                self.db.commit()

        return d

    async def handle_get_outhash(self, request):
        """'get-outhash' command: look up an output hash entry."""
        method = request['method']
        outhash = request['outhash']
        taskhash = request['taskhash']

        with closing(self.db.cursor()) as cursor:
            d = await self.get_outhash(cursor, method, outhash, taskhash)

        self.write_message(d)

    async def get_outhash(self, cursor, method, outhash, taskhash):
        """Return the oldest outhash row (joined with its unihash) for
        (method, outhash), consulting upstream and caching the result
        when the local database has no entry. Returns a dict or None."""
        d = None
        cursor.execute(
            '''
            SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
            INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
            WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
            ORDER BY outhashes_v2.created ASC
            LIMIT 1
            ''',
            {
                'method': method,
                'outhash': outhash,
            }
        )
        row = cursor.fetchone()

        if row is not None:
            d = {k: row[k] for k in row.keys()}
        elif self.upstream_client is not None:
            d = await self.upstream_client.get_outhash(method, outhash, taskhash)
            self.update_unified(cursor, d)
            self.db.commit()

        return d

    def update_unified(self, cursor, data):
        """Cache an upstream reply locally, splitting it into the unihash
        and outhash tables (duplicates ignored). No-op when data is None."""
        if data is None:
            return

        insert_unihash(
            cursor,
            {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )
        insert_outhash(
            cursor,
            {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )

    async def handle_get_stream(self, request):
        """'get-stream' command: switch to a line-oriented streaming mode.

        Each request line is "<method> <taskhash>"; the reply line is the
        unihash or empty when unknown. The literal line 'END' terminates
        the stream. Upstream hits are queued for background backfill.
        """
        self.write_message('ok')

        while True:
            upstream = None

            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out).
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                cursor = self.db.cursor()
                try:
                    row = self.query_equivalent(cursor, method, taskhash)
                finally:
                    cursor.close()

                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
                elif self.upstream_client is not None:
                    upstream = await self.upstream_client.get_unihash(method, taskhash)
                    if upstream:
                        msg = ("%s\n" % upstream).encode("utf-8")
                    else:
                        msg = "\n".encode("utf-8")
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

            # Post to the backfill queue after writing the result to minimize
            # the turn around time on a request
            if upstream is not None:
                await self.backfill_queue.put((method, taskhash))

    async def handle_report(self, data):
        """'report' command: record a task's output hash and assign a unihash.

        If the output hash matches an existing (older) entry with a
        different taskhash, the tasks are equivalent and share that
        entry's unihash; otherwise the reported unihash (or an upstream
        one, if available) is recorded. Replies with the unihash now in
        effect for the reported taskhash.
        """
        with closing(self.db.cursor()) as cursor:
            outhash_data = {
                'method': data['method'],
                'outhash': data['outhash'],
                'taskhash': data['taskhash'],
                'created': datetime.now()
            }

            # Optional metadata fields are copied through when present
            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    outhash_data[k] = data[k]

            # Insert the new entry, unless it already exists
            (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)

            if inserted:
                # If this row is new, check if it is equivalent to another
                # output hash
                cursor.execute(
                    '''
                    SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
                    INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                    -- Select any matching output hash except the one we just inserted
                    WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
                    -- Pick the oldest hash
                    ORDER BY outhashes_v2.created ASC
                    LIMIT 1
                    ''',
                    {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                    }
                )
                row = cursor.fetchone()

                if row is not None:
                    # A matching output hash was found. Set our taskhash to the
                    # same unihash since they are equivalent
                    unihash = row['unihash']
                    resolve = Resolve.IGNORE
                else:
                    # No matching output hash was found. This is probably the
                    # first outhash to be added.
                    unihash = data['unihash']
                    resolve = Resolve.IGNORE

                    # Query upstream to see if it has a unihash we can use
                    if self.upstream_client is not None:
                        upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
                        if upstream_data is not None:
                            unihash = upstream_data['unihash']


                insert_unihash(
                    cursor,
                    {
                        'method': data['method'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                    },
                    resolve
                )

            # Re-read the effective unihash: an older equivalence may already
            # exist (the insert above is OR IGNORE)
            unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
            if unihash_data is not None:
                unihash = unihash_data['unihash']
            else:
                unihash = data['unihash']

            self.db.commit()

            d = {
                'taskhash': data['taskhash'],
                'method': data['method'],
                'unihash': unihash,
            }

        self.write_message(d)

    async def handle_equivreport(self, data):
        """'report-equiv' command: record a client-asserted taskhash ->
        unihash equivalence and reply with the mapping now in effect."""
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
            }
            insert_unihash(cursor, insert_data, Resolve.IGNORE)
            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(cursor, data['method'], data['taskhash'])

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)


    async def handle_get_stats(self, request):
        """'get-stats' command: report request timing statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        """'reset-stats' command: report, then zero, the statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    async def handle_backfill_wait(self, request):
        """'backfill-wait' command: block until the backfill queue drains,
        replying with the number of tasks that were pending."""
        d = {
            'tasks': self.backfill_queue.qsize(),
        }
        await self.backfill_queue.join()
        self.write_message(d)

    def query_equivalent(self, cursor, method, taskhash):
        """Look up the local equivalence row for (method, taskhash)."""
        # This is part of the inner loop and must be as fast as possible
        cursor.execute(
            'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
            {
                'method': method,
                'taskhash': taskhash,
            }
        )
        return cursor.fetchone()


class Server(bb.asyncrpc.AsyncServer):
    """The hash-equivalence server: accepts clients, shares one database
    and statistics object, and optionally backfills from an upstream
    server via a background worker."""

    def __init__(self, db, upstream=None, read_only=False):
        # A read-only server cannot write the rows a pull from upstream needs
        if upstream and read_only:
            raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")

        super().__init__(logger)

        self.request_stats = Stats()
        self.db = db
        self.upstream = upstream
        self.read_only = read_only

    def accept_client(self, reader, writer):
        """Create the per-connection handler for a new client."""
        return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)

    @contextmanager
    def _backfill_worker(self):
        """Run the upstream backfill worker for the duration of the context.

        The worker copies unihashes named on the backfill queue from the
        upstream server into the local database; a None item tells it to
        stop. No worker is started when there is no upstream.
        """
        async def backfill_worker_task():
            client = await create_async_client(self.upstream)
            try:
                while True:
                    item = await self.backfill_queue.get()
                    if item is None:
                        self.backfill_queue.task_done()
                        break
                    method, taskhash = item
                    await copy_unihash_from_upstream(client, self.db, method, taskhash)
                    self.backfill_queue.task_done()
            finally:
                await client.close()

        async def join_worker(worker):
            # Sentinel wakes the worker so it can exit, then wait for it
            await self.backfill_queue.put(None)
            await worker

        if self.upstream is not None:
            worker = asyncio.ensure_future(backfill_worker_task())
            try:
                yield
            finally:
                self.loop.run_until_complete(join_worker(worker))
        else:
            yield

    def run_loop_forever(self):
        """Create the backfill queue, then run the server event loop with
        the backfill worker active."""
        self.backfill_queue = asyncio.Queue()

        with self._backfill_worker():
            super().run_loop_forever()