#! /usr/bin/env python3
#
# Copyright (C) 2023 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import logging
from datetime import datetime
from . import User

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from sqlalchemy import (
    MetaData,
    Column,
    Table,
    Text,
    Integer,
    UniqueConstraint,
    DateTime,
    Index,
    select,
    insert,
    exists,
    literal,
    and_,
    delete,
    update,
    func,
    inspect,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
from sqlalchemy.dialects.postgresql import insert as postgres_insert

Base = declarative_base()


class UnihashesV3(Base):
    """Table ``unihashes_v3``: one unihash (plus GC mark) per (method, taskhash)."""

    __tablename__ = "unihashes_v3"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    # Compared against the "gc-mark" config value by the gc_* methods
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        Index("unihash_lookup_v1", "unihash"),
    )


class OuthashesV2(Base):
    """Table ``outhashes_v2``: one row per unique (method, taskhash, outhash)."""

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    created = Column(DateTime)
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )


class Users(Base):
    """Table ``users``: unique username with its token and permission string."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Whitespace separated permission names; nullable
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)


class Config(Base):
    """Table ``config``: unique name -> (nullable) text value."""

    __tablename__ = "config"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False)
    value = Column(Text)
    __table_args__ = (
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )


#
# Old table versions
#
DeprecatedBase = declarative_base()


class UnihashesV2(DeprecatedBase):
    """Deprecated ``unihashes_v2`` table; migrated to V3 by DatabaseEngine.create()."""

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )


class DatabaseEngine(object):
    """Owns the async SQLAlchemy engine and hands out Database connections."""

    def __init__(self, url, username=None, password=None):
        """Parse ``url``; ``username``/``password`` override the URL when not None."""
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine and tables, upgrading an existing V2 unihash table.

        Must be awaited before connect(), since it is what sets self.engine.
        """

        def check_table_exists(conn, name):
            return inspect(conn).has_table(name)

        # Render the URL explicitly with the password masked so that
        # credentials never end up in the log, whatever str(url) does
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == 'postgresql+psycopg':
            # Psycopg 3 (psycopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                # Copy every V2 row into V3 with an empty GC mark
                statement = insert(UnihashesV3).from_select(
                    ["id", "method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.id,
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a new (not yet opened) Database bound to this engine."""
        return Database(self.engine, logger)


def map_row(row):
    """Convert a result row to a plain dict; None is passed through."""
    if row is None:
        return None
    return dict(**row._mapping)


def map_user(row):
    """Build a User from a row with username/permissions; None is passed through."""
    if row is None:
        return None
    return User(
        username=row.username,
        # The permissions column is nullable; treat NULL as "no permissions"
        # instead of failing on None.split()
        permissions=set((row.permissions or "").split()),
    )


def _make_condition_statement(table, condition):
    """Return equality clauses for each column of ``table`` named in ``condition``.

    Keys of ``condition`` that are not columns of ``table``, and keys whose
    value is None, are ignored. The result may be an empty list.
    """
    return [
        (c == condition[c.key])
        for c in table.__table__.columns
        if c.key in condition and condition[c.key] is not None
    ]


class Database(object):
    """A single database connection, used as an async context manager."""

    def __init__(self, engine, logger):
        self.engine = engine
        # Set by __aenter__(), cleared by close()
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection; does nothing if it is not open."""
        if self.db is None:
            return
        await self.db.close()
        self.db = None

    async def _execute(self, statement):
        """Log ``statement`` at debug level and execute it on the connection."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def _set_config(self, name, value):
        """Set config ``name`` to ``value``, inserting the row if it is missing."""
        while True:
            result = await self._execute(
                update(Config).where(Config.name == name).values(value=value)
            )

            if result.rowcount == 0:
                self.logger.debug("Config '%s' not found. Adding it", name)
                try:
                    await self._execute(insert(Config).values(name=name, value=value))
                except IntegrityError:
                    # Race. Try again
                    continue

            break

    def _get_config_subquery(self, name, default=None):
        """Return a scalar subquery for config ``name``.

        When ``default`` is not None the subquery is wrapped in COALESCE so
        a missing/NULL value yields ``default`` instead of NULL.
        """
        if default is not None:
            return func.coalesce(
                select(Config.value).where(Config.name == name).scalar_subquery(),
                default,
            )
        return select(Config.value).where(Config.name == name).scalar_subquery()

    async def _get_config(self, name):
        """Return the value of config ``name``, or None if there is no such row."""
        result = await self._execute(select(Config.value).where(Config.name == name))
        row = result.first()
        if row is None:
            return None
        return row.value

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row (with unihash) for method/taskhash, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2,
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.taskhash == taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row (with unihash) for method/outhash, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def unihash_exists(self, unihash):
        """Return True if any unihashes_v3 row has this unihash."""
        async with self.db.begin():
            result = await self._execute(
                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
            )

            return result.first() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for method/outhash as a dict, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2)
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Return taskhash/unihash of the oldest row with this method/outhash
        whose taskhash differs from ``taskhash``, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2.taskhash.label("taskhash"),
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                    OuthashesV2.taskhash != taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return unihash/method/taskhash for method/taskhash as a dict, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    UnihashesV3.unihash,
                    UnihashesV3.method,
                    UnihashesV3.taskhash,
                ).where(
                    UnihashesV3.method == method,
                    UnihashesV3.taskhash == taskhash,
                )
            )
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching ``condition``.

        A table is skipped entirely when ``condition`` names none of its
        columns, so an empty condition never deletes everything. Returns
        the total number of rows deleted.
        """

        async def do_remove(table):
            where = _make_condition_statement(table, condition)
            if where:
                async with self.db.begin():
                    result = await self._execute(delete(table).where(*where))
                    return result.rowcount

            return 0

        count = 0
        count += await do_remove(UnihashesV3)
        count += await do_remove(OuthashesV2)

        return count

    async def get_current_gc_mark(self):
        """Return the "gc-mark" config value, or None if it is not set."""
        async with self.db.begin():
            return await self._get_config("gc-mark")

    async def gc_status(self):
        """Return (rows matching the current mark, rows not matching, mark)."""
        async with self.db.begin():
            gc_mark_subquery = self._get_config_subquery("gc-mark", "")

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark == gc_mark_subquery)
            )
            keep_rows = result.scalar()

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark != gc_mark_subquery)
            )
            remove_rows = result.scalar()

            return (keep_rows, remove_rows, await self._get_config("gc-mark"))

    async def gc_mark(self, mark, condition):
        """Set the "gc-mark" config to ``mark`` and stamp matching unihash rows.

        The config is updated even when ``condition`` names no unihash
        column; in that case no rows are touched and 0 is returned.
        Otherwise returns the number of rows updated.
        """
        async with self.db.begin():
            await self._set_config("gc-mark", mark)

            where = _make_condition_statement(UnihashesV3, condition)
            if not where:
                return 0

            result = await self._execute(
                update(UnihashesV3)
                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
                .where(*where)
            )
            return result.rowcount

    async def gc_sweep(self):
        """Delete unihash rows whose mark differs from the config mark, then
        clear the config mark. Returns the number of rows deleted."""
        async with self.db.begin():
            result = await self._execute(
                delete(UnihashesV3).where(
                    # A sneaky conditional that provides some errant use
                    # protection: If the config mark is NULL, this will not
                    # match any rows because No default is specified in the
                    # select statement
                    UnihashesV3.gc_mark
                    != self._get_config_subquery("gc-mark")
                )
            )
            await self._set_config("gc-mark", None)

            return result.rowcount

    async def clean_unused(self, oldest):
        """Delete outhash rows created before ``oldest`` that have no unihash
        row with the same method and taskhash. Returns the rows deleted."""
        async with self.db.begin():
            result = await self._execute(
                delete(OuthashesV2).where(
                    OuthashesV2.created < oldest,
                    ~(
                        select(UnihashesV3.id)
                        .where(
                            UnihashesV3.method == OuthashesV2.method,
                            UnihashesV3.taskhash == OuthashesV2.taskhash,
                        )
                        .limit(1)
                        .exists()
                    ),
                )
            )
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row stamped with the current GC mark.

        Returns True if a row was inserted, False if (method, taskhash)
        already exists.
        """
        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(UnihashesV3)
                .values(
                    method=method,
                    taskhash=taskhash,
                    unihash=unihash,
                    gc_mark=self._get_config_subquery("gc-mark", ""),
                )
                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
            )
        else:
            statement = insert(UnihashesV3).values(
                method=method,
                taskhash=taskhash,
                unihash=unihash,
                gc_mark=self._get_config_subquery("gc-mark", ""),
            )

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row from ``data``.

        Keys that are not outhashes_v2 columns are dropped, and a
        non-datetime "created" value is parsed with
        datetime.fromisoformat(). Returns True if a row was inserted,
        False if the insert hit an integrity error (duplicate).
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(OuthashesV2)
                .values(**data)
                .on_conflict_do_nothing(
                    index_elements=("method", "taskhash", "outhash")
                )
            )
        else:
            statement = insert(OuthashesV2).values(**data)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s already in outhash database", data["method"], data["outhash"]
            )
            return False

    async def _get_user(self, username):
        """Return the (username, permissions, token) row for ``username`` or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                    Users.token,
                ).where(
                    Users.username == username,
                )
            )
            return result.first()

    async def lookup_user_token(self, username):
        """Return (User, token) for ``username``, or (None, None) if unknown."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the User for ``username``, or None if unknown."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the user's token; returns True if a row was updated."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(
                    Users.username == username,
                )
                .values(
                    token=token,
                )
            )
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Store ``permissions`` space-joined; returns True if a row was updated."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(Users.username == username)
                .values(permissions=" ".join(permissions))
            )
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list of User objects, one per row of the users table."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                )
            )
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user; returns False if the insert hits an integrity error."""
        try:
            async with self.db.begin():
                await self._execute(
                    insert(Users).values(
                        username=username,
                        permissions=" ".join(permissions),
                        token=token,
                    )
                )
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete ``username``; returns True if a row was removed."""
        async with self.db.begin():
            result = await self._execute(
                delete(Users).where(Users.username == username)
            )
            return result.rowcount != 0

    async def get_usage(self):
        """Return {table name: {"rows": row count}} for every current table."""
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                result = await self._execute(
                    statement=select(func.count()).select_from(table)
                )
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the distinct names of the Text columns of the unihash and
        outhash tables, in no particular order."""
        columns = set()
        for table in (UnihashesV3, OuthashesV2):
            for c in table.__table__.columns:
                if not isinstance(c.type, Text):
                    continue
                columns.add(c.key)

        return list(columns)