1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7import sqlite3
8import logging
9from contextlib import closing
10from . import User
11
logger = logging.getLogger("hashserv.sqlite")


def _column_names(definition):
    """Return the column names of a table *definition*, in declaration order."""
    return tuple(column for column, _sql_type, _flags in definition)


# Each *_TABLE_DEFINITION is a sequence of (column name, SQL type, flags)
# triples consumed by _make_table(). Every column whose flags contain "UNIQUE"
# becomes part of that table's single composite UNIQUE constraint. The
# matching *_TABLE_COLUMNS tuple lists just the column names, in the same order.

# Links a (method, taskhash) pair to its unihash; unique per (method, taskhash).
UNIHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("unihash", "TEXT NOT NULL", ""),
    ("gc_mark", "TEXT NOT NULL", ""),
)
UNIHASH_TABLE_COLUMNS = _column_names(UNIHASH_TABLE_DEFINITION)

# Output hash records; unique per (method, taskhash, outhash).
OUTHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("outhash", "TEXT NOT NULL", "UNIQUE"),
    ("created", "DATETIME", ""),
    # Optional fields (nullable)
    ("owner", "TEXT", ""),
    ("PN", "TEXT", ""),
    ("PV", "TEXT", ""),
    ("PR", "TEXT", ""),
    ("task", "TEXT", ""),
    ("outhash_siginfo", "TEXT", ""),
)
OUTHASH_TABLE_COLUMNS = _column_names(OUTHASH_TABLE_DEFINITION)

# User accounts; permissions are stored as a space-separated string.
USERS_TABLE_DEFINITION = (
    ("username", "TEXT NOT NULL", "UNIQUE"),
    ("token", "TEXT NOT NULL", ""),
    ("permissions", "TEXT NOT NULL", ""),
)
USERS_TABLE_COLUMNS = _column_names(USERS_TABLE_DEFINITION)

# Named server settings (e.g. the current "gc-mark").
CONFIG_TABLE_DEFINITION = (
    ("name", "TEXT NOT NULL", "UNIQUE"),
    ("value", "TEXT", ""),
)
CONFIG_TABLE_COLUMNS = _column_names(CONFIG_TABLE_DEFINITION)
54
55
def _make_table(cursor, name, definition):
    """Create table *name* from *definition* if it does not already exist.

    *definition* is a sequence of ``(column, sql_type, flags)`` triples. The
    table always gets an ``id INTEGER PRIMARY KEY AUTOINCREMENT`` column
    first. All columns whose *flags* contain ``"UNIQUE"`` are combined into a
    single composite ``UNIQUE(...)`` constraint. If no column is flagged, the
    constraint is omitted, because an empty ``UNIQUE()`` is invalid SQL.

    The table and column names are interpolated directly into the statement,
    so they must come from trusted, hard-coded definitions.
    """
    entries = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
    entries.extend(
        "%s %s" % (column, sql_type) for column, sql_type, _ in definition
    )

    unique_columns = [column for column, _, flags in definition if "UNIQUE" in flags]
    if unique_columns:
        entries.append("UNIQUE(%s)" % ", ".join(unique_columns))

    cursor.execute(
        "CREATE TABLE IF NOT EXISTS %s (\n    %s\n)" % (name, ",\n    ".join(entries))
    )
72
73
def map_user(row):
    """Build a ``User`` from a users-table row, passing ``None`` through.

    The row's ``permissions`` column holds a whitespace-separated string, which
    is split into a set of permission names.
    """
    if row is None:
        return None

    username = row["username"]
    permission_set = set(row["permissions"].split())
    return User(username=username, permissions=permission_set)
81
82
def _make_condition_statement(columns, condition):
    """Build a parameterised ``WHERE`` clause from *condition*.

    Only keys that appear in *columns* and whose value is not ``None`` are
    used, in the order given by *columns*. Returns ``(params, clause)``, where
    *params* maps each used column to its value and *clause* is the
    ``col=:col`` terms joined with ``AND``. The clause is an empty string when
    nothing applies.
    """
    params = {
        column: condition[column]
        for column in columns
        if column in condition and condition[column] is not None
    }
    clause = " AND ".join(f"{column}=:{column}" for column in params)
    return params, clause
90
91
def _get_sqlite_version(cursor):
    """Return the SQLite library version reported by ``sqlite_version()``.

    The dotted version string is split into a tuple. Numeric components
    become ints and any non-numeric component is kept as a string.
    """

    def as_component(part):
        try:
            return int(part)
        except ValueError:
            return part

    cursor.execute("SELECT sqlite_version()")
    version_text = cursor.fetchone()[0]
    return tuple(as_component(part) for part in version_text.split("."))
103
104
def _schema_table_name(version):
    """Return the name of the schema table for SQLite *version*.

    SQLite 3.33 introduced ``sqlite_schema`` as the preferred name. Older
    releases only provide ``sqlite_master``.
    """
    return "sqlite_schema" if version >= (3, 33) else "sqlite_master"
110
111
class DatabaseEngine(object):
    """Schema manager and connection factory for a SQLite hash database.

    ``create()`` builds or upgrades the schema in *dbname*. ``connect()``
    returns a new ``Database`` bound to the same file.
    """

    def __init__(self, dbname, sync):
        # dbname: path handed to sqlite3.connect().
        # sync: truthy -> "PRAGMA synchronous = NORMAL", falsy -> "OFF".
        self.dbname = dbname
        self.logger = logger
        self.sync = sync

    async def create(self):
        """Create or upgrade the database schema.

        This creates any missing tables, drops obsolete indexes and the
        ``tasks_v2`` table, creates the current indexes, and migrates rows
        from a legacy ``unihashes_v2`` table into ``unihashes_v3`` (with an
        empty ``gc_mark``) before dropping it.

        The connection used here is always closed before returning, including
        on error. Any uncommitted migration work is then discarded rather than
        left in an open transaction.
        """
        db = sqlite3.connect(self.dbname)
        db.row_factory = sqlite3.Row

        # closing(db) is entered first, so it exits last: the cursor is closed
        # before the connection.
        with closing(db), closing(db.cursor()) as cursor:
            _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION)
            _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
            _make_table(cursor, "users", USERS_TABLE_DEFINITION)
            _make_table(cursor, "config", CONFIG_TABLE_DEFINITION)

            cursor.execute("PRAGMA journal_mode = WAL")
            cursor.execute(
                "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
            )

            # Drop old indexes
            cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
            cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
            cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
            cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
            cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3")

            # TODO: Upgrade from tasks_v2?
            cursor.execute("DROP TABLE IF EXISTS tasks_v2")

            # Create new indexes
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)"
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)"
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
            )
            cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)")

            sqlite_version = _get_sqlite_version(cursor)

            cursor.execute(
                f"""
                SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2'
                """
            )
            if cursor.fetchone():
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                cursor.execute(
                    """
                    INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark)
                    SELECT id, method, unihash, taskhash, '' FROM unihashes_v2
                    """
                )
                cursor.execute("DROP TABLE unihashes_v2")
                db.commit()
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a new ``Database`` connection that logs to *logger*."""
        return Database(logger, self.dbname, self.sync)
176
177
class Database(object):
    """One connection to the SQLite hash-equivalence database.

    Rows are returned as ``sqlite3.Row`` objects, so results can be indexed by
    column name. Public methods that write to the database commit before
    returning. The methods are ``async``, but the sqlite3 calls they make
    are synchronous.
    """

    def __init__(self, logger, dbname, sync):
        """Open *dbname* and configure the connection.

        Enables WAL journalling and sets ``PRAGMA synchronous`` to ``NORMAL``
        when *sync* is truthy, otherwise ``OFF``. *logger* is stored as
        ``self.logger``.
        """
        self.dbname = dbname
        self.logger = logger

        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row

        with closing(self.db.cursor()) as cursor:
            cursor.execute("PRAGMA journal_mode = WAL")
            cursor.execute(
                "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
            )

            # Used by get_usage() to pick sqlite_schema vs. sqlite_master.
            self.sqlite_version = _get_sqlite_version(cursor)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def _set_config(self, cursor, name, value):
        """Insert or replace config entry *name* with *value* (may be None).

        The existing row id for *name*, if any, is reused. Does not commit;
        the caller is responsible for that.
        """
        cursor.execute(
            """
            INSERT OR REPLACE INTO config (id, name, value) VALUES
            ((SELECT id FROM config WHERE name=:name), :name, :value)
            """,
            {
                "name": name,
                "value": value,
            },
        )

    async def _get_config(self, cursor, name):
        """Return the stored value of config entry *name*, or None if absent."""
        cursor.execute(
            "SELECT value FROM config WHERE name=:name",
            {
                "name": name,
            },
        )
        row = cursor.fetchone()
        if row is None:
            return None
        return row["value"]

    async def close(self):
        """Close the underlying sqlite3 connection."""
        self.db.close()

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (*method*, *taskhash*) joined with
        its unihashes_v3 row, or None if there is no such pair."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (*method*, *outhash*) joined with
        its unihashes_v3 row, or None if there is no such pair."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                },
            )
            return cursor.fetchone()

    async def unihash_exists(self, unihash):
        """Return True if any unihashes_v3 row has the given *unihash*."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT * FROM unihashes_v3 WHERE unihash=:unihash
                LIMIT 1
                """,
                {
                    "unihash": unihash,
                },
            )
            return cursor.fetchone() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhashes_v2 row for (*method*, *outhash*), or None."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT * FROM outhashes_v2
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                },
            )
            return cursor.fetchone()

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find an equivalent task that produced the same output hash.

        Returns the oldest (``taskhash``, ``unihash``) row whose outhash matches
        (*method*, *outhash*) but whose taskhash differs from *taskhash*, or
        None if there is none.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
                -- Select any matching output hash except the one we just inserted
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
                -- Pick the oldest hash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def get_equivalent(self, method, taskhash):
        """Return the (taskhash, method, unihash) row for (*method*, *taskhash*),
        or None."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash",
                {
                    "method": method,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def remove(self, condition):
        """Delete outhash and unihash rows that match *condition*.

        *condition* maps column names to values. For each table, only keys that
        are columns of that table and whose value is not None are used,
        combined with AND. A table with no usable keys is left untouched.
        Returns the total number of rows deleted across both tables.
        """

        def do_remove(columns, table_name, cursor):
            where, clause = _make_condition_statement(columns, condition)
            if not where:
                return 0
            # Column names come from the fixed table definitions; the values
            # are bound as parameters.
            cursor.execute(f"DELETE FROM {table_name} WHERE {clause}", where)
            return cursor.rowcount

        count = 0
        with closing(self.db.cursor()) as cursor:
            count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
            count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor)
            self.db.commit()

        return count

    async def get_current_gc_mark(self):
        """Return the current ``gc-mark`` config value, or None if unset."""
        with closing(self.db.cursor()) as cursor:
            return await self._get_config(cursor, "gc-mark")

    async def gc_status(self):
        """Return ``(keep_rows, remove_rows, current_mark)``.

        *keep_rows* counts unihash rows whose gc_mark equals the current mark
        (or ``''`` when no mark is stored). *remove_rows* counts the rest.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT COUNT() FROM unihashes_v3 WHERE
                    gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
                """
            )
            keep_rows = cursor.fetchone()[0]

            cursor.execute(
                """
                SELECT COUNT() FROM unihashes_v3 WHERE
                    gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
                """
            )
            remove_rows = cursor.fetchone()[0]

            current_mark = await self._get_config(cursor, "gc-mark")

            return (keep_rows, remove_rows, current_mark)

    async def gc_mark(self, mark, condition):
        """Set the current GC mark to *mark* and apply it to matching rows.

        Unihash rows matching *condition* (see ``remove`` for how the condition
        is interpreted) get their gc_mark set to the new mark, or ``''`` if
        *mark* is None. Returns the number of rows updated, which is 0 when
        *condition* has no usable keys.
        """
        with closing(self.db.cursor()) as cursor:
            await self._set_config(cursor, "gc-mark", mark)

            where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition)

            new_rows = 0
            if where:
                cursor.execute(
                    f"""
                    UPDATE unihashes_v3 SET
                        gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
                    WHERE {clause}
                    """,
                    where,
                )
                new_rows = cursor.rowcount

            self.db.commit()
            return new_rows

    async def gc_sweep(self):
        """Delete unihash rows not carrying the current GC mark, then clear it.

        Returns the number of rows deleted. The stored mark is reset to NULL
        afterwards.
        """
        with closing(self.db.cursor()) as cursor:
            # NOTE: COALESCE is not used in this query so that if the current
            # mark is NULL, nothing will happen
            cursor.execute(
                """
                DELETE FROM unihashes_v3 WHERE
                    gc_mark!=(SELECT value FROM config WHERE name='gc-mark')
                """
            )
            count = cursor.rowcount
            await self._set_config(cursor, "gc-mark", None)

            self.db.commit()
            return count

    async def clean_unused(self, oldest):
        """Delete outhash rows created before *oldest* that no unihash row
        references (by method and taskhash). Returns the number deleted."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
                    SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1
                )
                """,
                {
                    "oldest": oldest,
                },
            )
            count = cursor.rowcount
            self.db.commit()
            return count

    async def insert_unihash(self, method, taskhash, unihash):
        """Record *unihash* for (*method*, *taskhash*) unless one already exists.

        A new row gets the current GC mark (``''`` if none is set). Returns
        True if a row was inserted, and False if INSERT OR IGNORE skipped it.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES
                    (
                    :method,
                    :taskhash,
                    :unihash,
                    COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
                    )
                """,
                {
                    "method": method,
                    "taskhash": taskhash,
                    "unihash": unihash,
                },
            )
            # Use rowcount, not lastrowid. A fresh cursor starts with
            # lastrowid None, and an ignored INSERT still sets lastrowid to the
            # connection's most recent successful insert rowid. Comparing the
            # two was therefore true even when nothing was inserted. rowcount
            # is the number of rows this statement inserted (0 or 1).
            inserted = cursor.rowcount > 0
            self.db.commit()
            return inserted

    async def insert_outhash(self, data):
        """Insert an outhash row built from *data* unless it would be ignored.

        Keys in *data* that are not outhashes_v2 columns are dropped. Returns
        True if a row was inserted, and False if INSERT OR IGNORE skipped it
        (for example because a row with the same unique key exists).
        """
        data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
        keys = sorted(data.keys())
        query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
            fields=", ".join(keys),
            values=", ".join(":" + k for k in keys),
        )
        with closing(self.db.cursor()) as cursor:
            cursor.execute(query, data)
            # See insert_unihash(): rowcount, unlike lastrowid, reflects
            # whether this particular statement inserted a row.
            inserted = cursor.rowcount > 0
            self.db.commit()
            return inserted

    def _get_user(self, username):
        """Return the (username, permissions, token) row for *username*, or None."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT username, permissions, token FROM users WHERE username=:username
                """,
                {
                    "username": username,
                },
            )
            return cursor.fetchone()

    async def lookup_user_token(self, username):
        """Return ``(user, token)`` for *username*, or ``(None, None)`` if absent."""
        row = self._get_user(username)
        if row is None:
            return None, None
        return map_user(row), row["token"]

    async def lookup_user(self, username):
        """Return the user named *username*, or None if absent."""
        return map_user(self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace *username*'s token. Returns True if the user exists."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                UPDATE users SET token=:token WHERE username=:username
                """,
                {
                    "username": username,
                    "token": token,
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace *username*'s permissions (stored space-separated).

        Returns True if the user exists.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                UPDATE users SET permissions=:permissions WHERE username=:username
                """,
                {
                    "username": username,
                    "permissions": " ".join(permissions),
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def get_all_users(self):
        """Return a list of all users (without their tokens)."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute("SELECT username, permissions FROM users")
            return [map_user(r) for r in cursor.fetchall()]

    async def new_user(self, username, permissions, token):
        """Create a user. Returns False if the insert violates a constraint
        (e.g. *username* already exists), True otherwise."""
        with closing(self.db.cursor()) as cursor:
            try:
                cursor.execute(
                    """
                    INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
                    """,
                    {
                        "username": username,
                        "token": token,
                        "permissions": " ".join(permissions),
                    },
                )
            except sqlite3.IntegrityError:
                # The failed INSERT aborts only the statement. The implicit
                # transaction sqlite3 opened for it, and its write lock, would
                # otherwise stay open until some later commit on this
                # connection.
                self.db.rollback()
                return False
            self.db.commit()
            return True

    async def delete_user(self, username):
        """Delete *username*. Returns True if a user was removed."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                DELETE FROM users WHERE username=:username
                """,
                {
                    "username": username,
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": row_count}}`` for every table whose
        name does not match ``sqlite_%``."""
        usage = {}
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                f"""
                SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
                """
            )
            # fetchall() materialises the table list, so the same cursor can
            # be reused for the per-table COUNT queries.
            for row in cursor.fetchall():
                table_name = row["name"]
                # Identifiers cannot be bound as parameters; quote the name
                # rather than splicing it in raw.
                quoted = '"%s"' % table_name.replace('"', '""')
                cursor.execute("SELECT COUNT() FROM %s" % quoted)
                usage[table_name] = {
                    "rows": cursor.fetchone()[0],
                }
        return usage

    async def get_query_columns(self):
        """Return the names of TEXT-typed unihash and outhash columns.

        Names shared by both tables appear once. The list follows definition
        order (unihash columns first), so it is the same on every run.
        """
        text_columns = (
            name
            for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION
            if typ.startswith("TEXT")
        )
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(text_columns))
563