1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import logging
9from datetime import datetime
10from . import User
11
12from sqlalchemy.ext.asyncio import create_async_engine
13from sqlalchemy.pool import NullPool
14from sqlalchemy import (
15    MetaData,
16    Column,
17    Table,
18    Text,
19    Integer,
20    UniqueConstraint,
21    DateTime,
22    Index,
23    select,
24    insert,
25    exists,
26    literal,
27    and_,
28    delete,
29    update,
30    func,
31    inspect,
32)
33import sqlalchemy.engine
34from sqlalchemy.orm import declarative_base
35from sqlalchemy.exc import IntegrityError
36from sqlalchemy.dialects.postgresql import insert as postgres_insert
37
# Declarative base for the current schema. DatabaseEngine.create() runs
# Base.metadata.create_all, so every model registered here gets a table.
Base = declarative_base()
39
40
class UnihashesV3(Base):
    """Current unihash table: one unihash per (method, taskhash) pair.

    Every row also carries a ``gc_mark`` string, which the garbage
    collection queries compare against the "gc-mark" entry of the config
    table.
    """

    __tablename__ = "unihashes_v3"

    id = Column(Integer, autoincrement=True, primary_key=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        # At most one row per (method, taskhash)
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        Index("unihash_lookup_v1", "unihash"),
    )
54
55
class OuthashesV2(Base):
    """Output hash table, unique per (method, taskhash, outhash).

    Only ``method``, ``taskhash`` and ``outhash`` are mandatory; the
    remaining columns are nullable metadata. ``created`` is the column the
    lookup queries sort on and that clean_unused() compares against.
    """

    __tablename__ = "outhashes_v2"

    id = Column(Integer, autoincrement=True, primary_key=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)

    # Optional metadata
    created = Column(DateTime)
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )
74
75
class Users(Base):
    """User accounts, unique by ``username``.

    ``permissions`` is stored as a single text value; new_user() and
    set_user_perms() write it as a space-joined list and map_user() splits
    it on whitespace again.
    """

    __tablename__ = "users"

    id = Column(Integer, autoincrement=True, primary_key=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    permissions = Column(Text)

    __table_args__ = (
        UniqueConstraint("username"),
    )
84
85
class Config(Base):
    """Name/value configuration store with unique names.

    In this module it holds the "gc-mark" entry used by the garbage
    collection methods; ``value`` is nullable.
    """

    __tablename__ = "config"

    id = Column(Integer, autoincrement=True, primary_key=True)
    name = Column(Text, nullable=False)
    value = Column(Text)

    __table_args__ = (
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )
95
96
#
# Old table versions
#
# These models are registered on a separate declarative base, so they are not
# part of Base.metadata and Base.metadata.create_all does not create them.
#
DeprecatedBase = declarative_base()
101
102
class UnihashesV2(DeprecatedBase):
    """Previous unihash table layout (no ``gc_mark`` column).

    Kept only so DatabaseEngine.create() can detect an existing table, copy
    its rows into UnihashesV3 and drop it.
    """

    __tablename__ = "unihashes_v2"

    id = Column(Integer, autoincrement=True, primary_key=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
114
115
class DatabaseEngine(object):
    """Owns the async SQLAlchemy engine used by the hash server.

    Typical use: construct, ``await create()`` once to build the engine and
    the schema, then call ``connect()`` for each Database wrapper needed.
    """

    def __init__(self, url, username=None, password=None):
        """Parse ``url`` and optionally override its credentials.

        url:      database URL accepted by sqlalchemy.engine.make_url()
        username: if not None, replaces the username in the URL
        password: if not None, replaces the password in the URL
        """
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine, create missing tables and migrate old ones.

        If a unihashes_v2 table exists, its rows are copied into
        unihashes_v3 (with an empty gc_mark) and the old table is dropped,
        all inside the same transaction as the table creation.
        """

        def check_table_exists(conn, name):
            return inspect(conn).has_table(name)

        # Render the URL explicitly with the password masked. str(URL) only
        # masks the password starting with SQLAlchemy 2.0, so logging the
        # URL object directly can leak credentials on older versions.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == 'postgresql+psycopg':
            # Psycopg 3 (psycopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                # The "id" column is deliberately not copied. Inserting
                # explicit ids does not advance the id sequence on
                # PostgreSQL, so later inserts would be handed ids that
                # already exist and fail on the primary key. Nothing
                # references these ids, so new ones are generated instead.
                statement = insert(UnihashesV3).from_select(
                    ["method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a Database wrapper bound to this engine.

        create() must have been awaited first, since it is what sets
        ``self.engine``.
        """
        return Database(self.engine, logger)
163
164
def map_row(row):
    """Return the row's ``_mapping`` as a plain dict, or None for no row."""
    return None if row is None else dict(**row._mapping)
169
170
def map_user(row):
    """Build a User from a users-table row, or return None for no row.

    The stored permissions text is split on whitespace into a set. The
    permissions column is nullable, so a NULL value is treated as "no
    permissions" instead of raising AttributeError.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
178
179
def _make_condition_statement(table, condition):
    """Turn a ``condition`` dict into a list of column equality clauses.

    For every column of ``table`` (in table order) whose key is present in
    ``condition`` with a value other than None, one ``column == value``
    expression is produced. Keys that do not name a column are ignored, so
    the result may be empty.
    """
    return [
        column == condition[column.key]
        for column in table.__table__.columns
        if column.key in condition and condition[column.key] is not None
    ]
187
188
189class Database(object):
190    def __init__(self, engine, logger):
191        self.engine = engine
192        self.db = None
193        self.logger = logger
194
195    async def __aenter__(self):
196        self.db = await self.engine.connect()
197        return self
198
199    async def __aexit__(self, exc_type, exc_value, traceback):
200        await self.close()
201
202    async def close(self):
203        await self.db.close()
204        self.db = None
205
206    async def _execute(self, statement):
207        self.logger.debug("%s", statement)
208        return await self.db.execute(statement)
209
210    async def _set_config(self, name, value):
211        while True:
212            result = await self._execute(
213                update(Config).where(Config.name == name).values(value=value)
214            )
215
216            if result.rowcount == 0:
217                self.logger.debug("Config '%s' not found. Adding it", name)
218                try:
219                    await self._execute(insert(Config).values(name=name, value=value))
220                except IntegrityError:
221                    # Race. Try again
222                    continue
223
224            break
225
226    def _get_config_subquery(self, name, default=None):
227        if default is not None:
228            return func.coalesce(
229                select(Config.value).where(Config.name == name).scalar_subquery(),
230                default,
231            )
232        return select(Config.value).where(Config.name == name).scalar_subquery()
233
234    async def _get_config(self, name):
235        result = await self._execute(select(Config.value).where(Config.name == name))
236        row = result.first()
237        if row is None:
238            return None
239        return row.value
240
241    async def get_unihash_by_taskhash_full(self, method, taskhash):
242        async with self.db.begin():
243            result = await self._execute(
244                select(
245                    OuthashesV2,
246                    UnihashesV3.unihash.label("unihash"),
247                )
248                .join(
249                    UnihashesV3,
250                    and_(
251                        UnihashesV3.method == OuthashesV2.method,
252                        UnihashesV3.taskhash == OuthashesV2.taskhash,
253                    ),
254                )
255                .where(
256                    OuthashesV2.method == method,
257                    OuthashesV2.taskhash == taskhash,
258                )
259                .order_by(
260                    OuthashesV2.created.asc(),
261                )
262                .limit(1)
263            )
264            return map_row(result.first())
265
266    async def get_unihash_by_outhash(self, method, outhash):
267        async with self.db.begin():
268            result = await self._execute(
269                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
270                .join(
271                    UnihashesV3,
272                    and_(
273                        UnihashesV3.method == OuthashesV2.method,
274                        UnihashesV3.taskhash == OuthashesV2.taskhash,
275                    ),
276                )
277                .where(
278                    OuthashesV2.method == method,
279                    OuthashesV2.outhash == outhash,
280                )
281                .order_by(
282                    OuthashesV2.created.asc(),
283                )
284                .limit(1)
285            )
286            return map_row(result.first())
287
288    async def unihash_exists(self, unihash):
289        async with self.db.begin():
290            result = await self._execute(
291                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
292            )
293
294            return result.first() is not None
295
296    async def get_outhash(self, method, outhash):
297        async with self.db.begin():
298            result = await self._execute(
299                select(OuthashesV2)
300                .where(
301                    OuthashesV2.method == method,
302                    OuthashesV2.outhash == outhash,
303                )
304                .order_by(
305                    OuthashesV2.created.asc(),
306                )
307                .limit(1)
308            )
309            return map_row(result.first())
310
311    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
312        async with self.db.begin():
313            result = await self._execute(
314                select(
315                    OuthashesV2.taskhash.label("taskhash"),
316                    UnihashesV3.unihash.label("unihash"),
317                )
318                .join(
319                    UnihashesV3,
320                    and_(
321                        UnihashesV3.method == OuthashesV2.method,
322                        UnihashesV3.taskhash == OuthashesV2.taskhash,
323                    ),
324                )
325                .where(
326                    OuthashesV2.method == method,
327                    OuthashesV2.outhash == outhash,
328                    OuthashesV2.taskhash != taskhash,
329                )
330                .order_by(
331                    OuthashesV2.created.asc(),
332                )
333                .limit(1)
334            )
335            return map_row(result.first())
336
337    async def get_equivalent(self, method, taskhash):
338        async with self.db.begin():
339            result = await self._execute(
340                select(
341                    UnihashesV3.unihash,
342                    UnihashesV3.method,
343                    UnihashesV3.taskhash,
344                ).where(
345                    UnihashesV3.method == method,
346                    UnihashesV3.taskhash == taskhash,
347                )
348            )
349            return map_row(result.first())
350
351    async def remove(self, condition):
352        async def do_remove(table):
353            where = _make_condition_statement(table, condition)
354            if where:
355                async with self.db.begin():
356                    result = await self._execute(delete(table).where(*where))
357                return result.rowcount
358
359            return 0
360
361        count = 0
362        count += await do_remove(UnihashesV3)
363        count += await do_remove(OuthashesV2)
364
365        return count
366
367    async def get_current_gc_mark(self):
368        async with self.db.begin():
369            return await self._get_config("gc-mark")
370
371    async def gc_status(self):
372        async with self.db.begin():
373            gc_mark_subquery = self._get_config_subquery("gc-mark", "")
374
375            result = await self._execute(
376                select(func.count())
377                .select_from(UnihashesV3)
378                .where(UnihashesV3.gc_mark == gc_mark_subquery)
379            )
380            keep_rows = result.scalar()
381
382            result = await self._execute(
383                select(func.count())
384                .select_from(UnihashesV3)
385                .where(UnihashesV3.gc_mark != gc_mark_subquery)
386            )
387            remove_rows = result.scalar()
388
389            return (keep_rows, remove_rows, await self._get_config("gc-mark"))
390
391    async def gc_mark(self, mark, condition):
392        async with self.db.begin():
393            await self._set_config("gc-mark", mark)
394
395            where = _make_condition_statement(UnihashesV3, condition)
396            if not where:
397                return 0
398
399            result = await self._execute(
400                update(UnihashesV3)
401                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
402                .where(*where)
403            )
404            return result.rowcount
405
406    async def gc_sweep(self):
407        async with self.db.begin():
408            result = await self._execute(
409                delete(UnihashesV3).where(
410                    # A sneaky conditional that provides some errant use
411                    # protection: If the config mark is NULL, this will not
412                    # match any rows because No default is specified in the
413                    # select statement
414                    UnihashesV3.gc_mark
415                    != self._get_config_subquery("gc-mark")
416                )
417            )
418            await self._set_config("gc-mark", None)
419
420            return result.rowcount
421
422    async def clean_unused(self, oldest):
423        async with self.db.begin():
424            result = await self._execute(
425                delete(OuthashesV2).where(
426                    OuthashesV2.created < oldest,
427                    ~(
428                        select(UnihashesV3.id)
429                        .where(
430                            UnihashesV3.method == OuthashesV2.method,
431                            UnihashesV3.taskhash == OuthashesV2.taskhash,
432                        )
433                        .limit(1)
434                        .exists()
435                    ),
436                )
437            )
438            return result.rowcount
439
440    async def insert_unihash(self, method, taskhash, unihash):
441        # Postgres specific ignore on insert duplicate
442        if self.engine.name == "postgresql":
443            statement = (
444                postgres_insert(UnihashesV3)
445                .values(
446                    method=method,
447                    taskhash=taskhash,
448                    unihash=unihash,
449                    gc_mark=self._get_config_subquery("gc-mark", ""),
450                )
451                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
452            )
453        else:
454            statement = insert(UnihashesV3).values(
455                method=method,
456                taskhash=taskhash,
457                unihash=unihash,
458                gc_mark=self._get_config_subquery("gc-mark", ""),
459            )
460
461        try:
462            async with self.db.begin():
463                result = await self._execute(statement)
464                return result.rowcount != 0
465        except IntegrityError:
466            self.logger.debug(
467                "%s, %s, %s already in unihash database", method, taskhash, unihash
468            )
469            return False
470
471    async def insert_outhash(self, data):
472        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
473
474        data = {k: v for k, v in data.items() if k in outhash_columns}
475
476        if "created" in data and not isinstance(data["created"], datetime):
477            data["created"] = datetime.fromisoformat(data["created"])
478
479        # Postgres specific ignore on insert duplicate
480        if self.engine.name == "postgresql":
481            statement = (
482                postgres_insert(OuthashesV2)
483                .values(**data)
484                .on_conflict_do_nothing(
485                    index_elements=("method", "taskhash", "outhash")
486                )
487            )
488        else:
489            statement = insert(OuthashesV2).values(**data)
490
491        try:
492            async with self.db.begin():
493                result = await self._execute(statement)
494                return result.rowcount != 0
495        except IntegrityError:
496            self.logger.debug(
497                "%s, %s already in outhash database", data["method"], data["outhash"]
498            )
499            return False
500
501    async def _get_user(self, username):
502        async with self.db.begin():
503            result = await self._execute(
504                select(
505                    Users.username,
506                    Users.permissions,
507                    Users.token,
508                ).where(
509                    Users.username == username,
510                )
511            )
512            return result.first()
513
514    async def lookup_user_token(self, username):
515        row = await self._get_user(username)
516        if not row:
517            return None, None
518        return map_user(row), row.token
519
520    async def lookup_user(self, username):
521        return map_user(await self._get_user(username))
522
523    async def set_user_token(self, username, token):
524        async with self.db.begin():
525            result = await self._execute(
526                update(Users)
527                .where(
528                    Users.username == username,
529                )
530                .values(
531                    token=token,
532                )
533            )
534            return result.rowcount != 0
535
536    async def set_user_perms(self, username, permissions):
537        async with self.db.begin():
538            result = await self._execute(
539                update(Users)
540                .where(Users.username == username)
541                .values(permissions=" ".join(permissions))
542            )
543            return result.rowcount != 0
544
545    async def get_all_users(self):
546        async with self.db.begin():
547            result = await self._execute(
548                select(
549                    Users.username,
550                    Users.permissions,
551                )
552            )
553            return [map_user(row) for row in result]
554
555    async def new_user(self, username, permissions, token):
556        try:
557            async with self.db.begin():
558                await self._execute(
559                    insert(Users).values(
560                        username=username,
561                        permissions=" ".join(permissions),
562                        token=token,
563                    )
564                )
565            return True
566        except IntegrityError as e:
567            self.logger.debug("Cannot create new user %s: %s", username, e)
568            return False
569
570    async def delete_user(self, username):
571        async with self.db.begin():
572            result = await self._execute(
573                delete(Users).where(Users.username == username)
574            )
575            return result.rowcount != 0
576
577    async def get_usage(self):
578        usage = {}
579        async with self.db.begin() as session:
580            for name, table in Base.metadata.tables.items():
581                result = await self._execute(
582                    statement=select(func.count()).select_from(table)
583                )
584                usage[name] = {
585                    "rows": result.scalar(),
586                }
587
588        return usage
589
590    async def get_query_columns(self):
591        columns = set()
592        for table in (UnihashesV3, OuthashesV2):
593            for c in table.__table__.columns:
594                if not isinstance(c.type, Text):
595                    continue
596                columns.add(c.key)
597
598        return list(columns)
599