1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6import asyncio
7from contextlib import closing
8import re
9import sqlite3
10import itertools
11import json
12
# Address prefix that selects a Unix domain socket; the rest of the address
# string is the socket path (see parse_address()).
UNIX_PREFIX = "unix://"

# Address kinds returned as the first element of parse_address()'s result.
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
17
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what their maximum chunk sizes were, but that would slow down
# connection setup with a round-trip delay, so I'd rather not do that unless it
# becomes necessary.
#
# 32 KiB is half of that 64K buffer.
# NOTE(review): not referenced in this file; presumably passed as the
# max_chunk argument of chunkify() by the client/server modules -- confirm.
DEFAULT_MAX_CHUNK = 32 * 1024
24
# Schema of the unihashes_v2 table as (column name, SQL type, flags) tuples,
# consumed by _make_table(). Columns whose flags contain "UNIQUE" are combined
# into one composite UNIQUE constraint, so each (method, taskhash) pair can
# appear in at most one row.
UNIHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("unihash", "TEXT NOT NULL", ""),
)

# Column names of UNIHASH_TABLE_DEFINITION, in definition order.
UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
32
# Schema of the outhashes_v2 table as (column name, SQL type, flags) tuples,
# consumed by _make_table(). The three "UNIQUE"-flagged columns form a single
# composite UNIQUE(method, taskhash, outhash) constraint.
OUTHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("outhash", "TEXT NOT NULL", "UNIQUE"),
    ("created", "DATETIME", ""),

    # Optional fields (no NOT NULL, so they may be left empty)
    ("owner", "TEXT", ""),
    ("PN", "TEXT", ""),
    ("PV", "TEXT", ""),
    ("PR", "TEXT", ""),
    ("task", "TEXT", ""),
    ("outhash_siginfo", "TEXT", ""),
)

# Column names of OUTHASH_TABLE_DEFINITION, in definition order.
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
49
50def _make_table(cursor, name, definition):
51    cursor.execute('''
52        CREATE TABLE IF NOT EXISTS {name} (
53            id INTEGER PRIMARY KEY AUTOINCREMENT,
54            {fields}
55            UNIQUE({unique})
56            )
57        '''.format(
58            name=name,
59            fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
60            unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
61    ))
62
63
def setup_database(database, sync=True):
    """Open the SQLite *database* and bring its schema up to date.

    Ensures the ``unihashes_v2`` and ``outhashes_v2`` tables exist and
    enables WAL journaling. Drops the indexes and the ``tasks_v2`` table
    left over from older schema versions, then creates the current
    ``*_lookup_v3`` indexes.

    :param database: database path or name, passed to ``sqlite3.connect``
    :param sync: if true, use ``PRAGMA synchronous = NORMAL``; if false,
        use ``OFF``, trading crash durability for speed
    :returns: the open ``sqlite3.Connection``, configured to return
        ``sqlite3.Row`` rows. The caller is responsible for closing it.

    If any setup statement raises, the connection is closed before the
    exception propagates, so no handle is leaked.
    """
    db = sqlite3.connect(database)
    db.row_factory = sqlite3.Row

    try:
        with closing(db.cursor()) as cursor:
            _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
            _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)

            cursor.execute('PRAGMA journal_mode = WAL')
            cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

            # Drop indexes from earlier schema versions
            for index in ('taskhash_lookup', 'outhash_lookup',
                          'taskhash_lookup_v2', 'outhash_lookup_v2'):
                cursor.execute('DROP INDEX IF EXISTS %s' % index)

            # TODO: Upgrade from tasks_v2? Until then, its contents are discarded.
            cursor.execute('DROP TABLE IF EXISTS tasks_v2')

            # Create new indexes
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
    except BaseException:
        # Don't leak the connection if schema setup fails part-way through
        db.close()
        raise

    return db
89
90
def parse_address(addr):
    """Parse a server address string into an ``(addr_type, args)`` pair.

    Accepted forms:

    * ``unix://PATH``: a Unix domain socket at PATH
    * ``[HOST]:PORT``: a TCP address; the brackets allow a HOST (such as
      an IPv6 address) that itself contains colons
    * ``HOST:PORT``: a TCP address whose HOST contains no colon

    :returns: ``(ADDR_TYPE_UNIX, (path,))`` or
        ``(ADDR_TYPE_TCP, (host, port))`` with *port* as an ``int``. The
        args tuple is meant to be unpacked into the matching
        ``*_unix`` / ``*_tcp`` call.
    :raises ValueError: if *addr* matches none of the forms above or its
        port is not an integer
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    # Try the bracketed form first so hosts containing ':' are supported
    m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        parts = addr.split(':')
        if len(parts) != 2:
            raise ValueError(
                "Invalid address %r: expected 'unix://PATH', 'HOST:PORT' "
                "or '[HOST]:PORT'" % (addr,))
        host, port = parts

    try:
        port = int(port)
    except ValueError:
        raise ValueError("Invalid port in address %r" % (addr,)) from None

    return (ADDR_TYPE_TCP, (host, port))
103
104
def chunkify(msg, max_chunk):
    """Split *msg* into newline-terminated lines of at most *max_chunk* chars.

    A message shorter than ``max_chunk - 1`` characters is yielded as a
    single line, followed by a newline. Longer messages are sent as a
    chunk stream:

    1. a ``{"chunk-stream": null}`` JSON header line,
    2. the message in pieces of at most ``max_chunk - 1`` characters, each
       followed by a newline,
    3. an empty line that terminates the stream.

    NOTE(review): the framing assumes *msg* itself contains no newline
    characters (e.g. it is compact JSON); confirm at call sites.

    :param msg: the string to send
    :param max_chunk: maximum length of each yielded line, including its
        trailing newline; must be at least 2
    :raises ValueError: if *max_chunk* is less than 2 (raised when the
        generator is first advanced)
    """
    step = max_chunk - 1
    if step < 1:
        # With no room for payload, no data chunks would be produced and
        # the message would be silently dropped from the stream.
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < step:
        yield msg + "\n"
        return

    yield json.dumps({'chunk-stream': None}) + "\n"
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"
    yield "\n"
117
118
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
    """Build a server.Server on database *dbname* and start it on *addr*.

    The database is prepared with setup_database(dbname, sync=sync).
    *upstream* and *read_only* are passed through to server.Server
    unchanged. *addr* is any form accepted by parse_address(). Unix socket
    addresses go to start_unix_server() and all others to
    start_tcp_server().

    Returns the Server instance.
    """
    from . import server

    database = setup_database(dbname, sync=sync)
    srv = server.Server(database, upstream=upstream, read_only=read_only)

    addr_type, addr_args = parse_address(addr)
    if addr_type != ADDR_TYPE_UNIX:
        srv.start_tcp_server(*addr_args)
    else:
        srv.start_unix_server(*addr_args)

    return srv
131
132
def create_client(addr):
    """Return a client.Client connected to *addr*.

    *addr* is any form accepted by parse_address(). Unix socket addresses
    are passed to connect_unix() and all others to connect_tcp().
    """
    from . import client

    conn = client.Client()

    addr_type, addr_args = parse_address(addr)
    connect = (conn.connect_unix if addr_type == ADDR_TYPE_UNIX
               else conn.connect_tcp)
    connect(*addr_args)

    return conn
144
async def create_async_client(addr):
    """Coroutine returning a client.AsyncClient connected to *addr*.

    *addr* is any form accepted by parse_address(). Unix socket addresses
    are awaited via connect_unix() and all others via connect_tcp().
    """
    from . import client

    conn = client.AsyncClient()

    addr_type, addr_args = parse_address(addr)
    if addr_type != ADDR_TYPE_UNIX:
        await conn.connect_tcp(*addr_args)
    else:
        await conn.connect_unix(*addr_args)

    return conn
156