1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6import asyncio
7from contextlib import closing
8import re
9import sqlite3
10import itertools
11import json
12
# Addresses beginning with this prefix name a UNIX domain socket path rather
# than a TCP host:port pair.
UNIX_PREFIX = "unix://"

# Address kinds; parse_address() returns one of these as the first element of
# its result tuple.
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1

# The receive buffer of the Python async server defaults to 64K, so the maximum
# chunk size is fixed here at half of that. Having client and server exchange
# their maximum chunk sizes would be more flexible, but it would add a
# round-trip to every connection setup, so it is avoided until it is actually
# needed.
DEFAULT_MAX_CHUNK = 32 * 1024

# Column (name, SQL type) pairs for the tasks_v2 table created by
# setup_database().
TABLE_DEFINITION = (
    # Columns that must always be populated
    ("method", "TEXT NOT NULL"),
    ("outhash", "TEXT NOT NULL"),
    ("taskhash", "TEXT NOT NULL"),
    ("unihash", "TEXT NOT NULL"),
    ("created", "DATETIME"),

    # Columns that may be left NULL
    ("owner", "TEXT"),
    ("PN", "TEXT"),
    ("PV", "TEXT"),
    ("PR", "TEXT"),
    ("task", "TEXT"),
    ("outhash_siginfo", "TEXT"),
)

# Just the column names, in table order.
TABLE_COLUMNS = tuple(column for column, _sql_type in TABLE_DEFINITION)
42
def setup_database(database, sync=True):
    """Open the SQLite database and make sure its schema is up to date.

    Creates the ``tasks_v2`` table (columns taken from ``TABLE_DEFINITION``)
    if it does not exist yet, switches the journal to WAL mode, drops the
    legacy ``taskhash_lookup``/``outhash_lookup`` indexes and creates their
    ``_v2`` replacements.

    :param database: Name passed to :func:`sqlite3.connect` (a file path, or
        ``":memory:"``).
    :param sync: When true, set ``PRAGMA synchronous = NORMAL``; when false,
        set it to ``OFF``.
    :returns: The open :class:`sqlite3.Connection`, configured so queries
        return :class:`sqlite3.Row` objects.
    :raises sqlite3.Error: If any setup statement fails. The connection is
        closed before the exception propagates, so nothing is leaked.
    """
    db = sqlite3.connect(database)

    try:
        db.row_factory = sqlite3.Row

        with closing(db.cursor()) as cursor:
            # The column list comes from the module-level TABLE_DEFINITION
            # constant, never from caller input, so %-formatting the DDL is safe.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks_v2 (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    %s
                    UNIQUE(method, outhash, taskhash)
                    )
                ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION))
            cursor.execute('PRAGMA journal_mode = WAL')
            cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

            # Drop old indexes
            cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
            cursor.execute('DROP INDEX IF EXISTS outhash_lookup')

            # Create new indexes
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
    except BaseException:
        # Setup failed part-way: release the connection instead of leaking it,
        # then let the original error reach the caller unchanged.
        db.close()
        raise

    return db
67
68
def parse_address(addr):
    """Split an address string into its kind and connection arguments.

    Accepted forms:

    * ``unix://PATH``  -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``[HOST]:PORT``  -> ``(ADDR_TYPE_TCP, (HOST, PORT))``; the brackets allow
      a host containing ``:``, such as the IPv6 literal in ``[::1]:8686``
    * ``HOST:PORT``    -> ``(ADDR_TYPE_TCP, (HOST, PORT))``

    In the TCP forms PORT is converted to ``int``.

    :raises ValueError: If ``addr`` matches none of the forms above or the
        port is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        # Without brackets there must be exactly one ':' separating host and
        # port; zero or several (e.g. a bare IPv6 literal) is rejected with a
        # descriptive message instead of an opaque unpacking error.
        parts = addr.split(':')
        if len(parts) != 2:
            raise ValueError(
                "Invalid address %r: expected 'unix://PATH', 'HOST:PORT' "
                "or '[HOST]:PORT'" % (addr,))
        host, port = parts

    try:
        port = int(port)
    except ValueError as err:
        raise ValueError("Invalid port in address %r" % (addr,)) from err

    return (ADDR_TYPE_TCP, (host, port))
81
82
def chunkify(msg, max_chunk):
    """Yield ``msg`` as newline-terminated lines of at most ``max_chunk`` characters.

    If ``len(msg) < max_chunk - 1`` the message is yielded as one line
    (``msg + "\\n"``). Otherwise it is sent as a chunk stream:

    1. a header line ``{"chunk-stream": null}\\n``,
    2. the message cut into slices of ``max_chunk - 1`` characters, each
       followed by ``"\\n"``,
    3. a final bare ``"\\n"`` (empty line) after the last slice.

    Lengths are counted in characters of ``msg``, not encoded bytes.
    NOTE(review): ``msg`` is assumed not to contain newlines itself, since
    each yielded piece is newline-terminated; confirm against callers.

    :param msg: Message text to split.
    :param max_chunk: Maximum length, including the trailing newline, of each
        message/data line. Must be at least 2 so that every slice can hold at
        least one character of ``msg``.
    :raises ValueError: If ``max_chunk`` is less than 2 (raised when the
        generator is first advanced). Previously such values silently dropped
        the message content.
    """
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield ''.join((msg, "\n"))
        return

    yield ''.join((json.dumps({
            'chunk-stream': None
        }), "\n"))

    # One character of every line is reserved for the terminating newline.
    size = max_chunk - 1
    for start in range(0, len(msg), size):
        yield ''.join((msg[start:start + size], "\n"))
    yield "\n"
95
96
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
    """Create a ``server.Server`` backed by ``dbname`` and start it on ``addr``.

    :param addr: Address in any form accepted by :func:`parse_address`
        (``unix://PATH``, ``HOST:PORT`` or ``[HOST]:PORT``).
    :param dbname: Database name passed to :func:`setup_database`.
    :param sync: Passed to :func:`setup_database` (``PRAGMA synchronous`` is
        ``NORMAL`` when true, ``OFF`` when false).
    :param upstream: Forwarded unchanged to ``server.Server``.
    :param read_only: Forwarded unchanged to ``server.Server``.
    :returns: The server, after ``start_unix_server`` (for ``unix://``
        addresses) or ``start_tcp_server`` (otherwise) has been called on it.
    :raises ValueError: If ``addr`` cannot be parsed; in that case the
        database is not opened.
    """
    from . import server

    # Validate the address before opening the database, so a malformed address
    # does not leave an open (and possibly freshly created) SQLite file behind.
    (typ, a) = parse_address(addr)

    db = setup_database(dbname, sync=sync)
    s = server.Server(db, upstream=upstream, read_only=read_only)

    if typ == ADDR_TYPE_UNIX:
        s.start_unix_server(*a)
    else:
        s.start_tcp_server(*a)

    return s
109
110
def create_client(addr):
    """Build a ``client.Client`` and connect it to ``addr``.

    ``addr`` may be any form understood by :func:`parse_address`. A
    ``unix://`` address is connected with ``connect_unix``; every other
    address is connected with ``connect_tcp``.

    :returns: The connected client instance.
    """
    from . import client

    conn = client.Client()
    kind, params = parse_address(addr)

    connect = conn.connect_unix if kind == ADDR_TYPE_UNIX else conn.connect_tcp
    connect(*params)

    return conn
122
async def create_async_client(addr):
    """Coroutine that builds a ``client.AsyncClient`` and connects it to ``addr``.

    ``addr`` may be any form understood by :func:`parse_address`. TCP
    addresses are connected by awaiting ``connect_tcp``; ``unix://``
    addresses by awaiting ``connect_unix``.

    :returns: The connected async client instance.
    """
    from . import client

    conn = client.AsyncClient()
    kind, params = parse_address(addr)

    if kind != ADDR_TYPE_UNIX:
        await conn.connect_tcp(*params)
    else:
        await conn.connect_unix(*params)

    return conn
134