1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7import re
8import sqlite3
9import itertools
10import json
11
# Addresses starting with this prefix name a UNIX domain socket path; any other
# address is treated as TCP (see parse_address()).
UNIX_PREFIX = "unix://"

# Address kinds returned as the first element of parse_address()'s result.
ADDR_TYPE_UNIX, ADDR_TYPE_TCP = 0, 1

# The Python async server defaults to a 64K receive buffer, so the maximum
# chunk size is hardcoded at half of that. Having the client and server report
# their limits to each other would be cleaner, but it would add a round-trip
# delay to connection setup, so that is deferred until it proves necessary.
DEFAULT_MAX_CHUNK = 32 << 10  # 32 KiB == 32768
23
def _init_database(db, sync):
    """Create/upgrade the schema and apply connection PRAGMAs on *db*."""
    with closing(db.cursor()) as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks_v2 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                method TEXT NOT NULL,
                outhash TEXT NOT NULL,
                taskhash TEXT NOT NULL,
                unihash TEXT NOT NULL,
                created DATETIME,

                -- Optional fields
                owner TEXT,
                PN TEXT,
                PV TEXT,
                PR TEXT,
                task TEXT,
                outhash_siginfo TEXT,

                UNIQUE(method, outhash, taskhash)
                )
            ''')
        # WAL journaling lets readers proceed concurrently with a writer.
        cursor.execute('PRAGMA journal_mode = WAL')
        # OFF skips fsync entirely: faster, but not durable across a crash or
        # power loss. NORMAL is the usual choice together with WAL.
        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

        # Drop the pre-v2 indexes
        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')

        # Create new indexes
        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')


def setup_database(database, sync=True):
    """Open the SQLite database *database* and bring its schema up to date.

    Creates the ``tasks_v2`` table and its ``*_lookup_v2`` indexes if they do
    not exist yet, drops the old ``taskhash_lookup``/``outhash_lookup``
    indexes, and switches the database to WAL journaling.

    :param database: Path (or any name accepted by :func:`sqlite3.connect`,
        e.g. ``":memory:"``) of the database to open.
    :param sync: When true, use ``PRAGMA synchronous = NORMAL``; when false,
        use ``OFF``.
    :returns: An open :class:`sqlite3.Connection` whose ``row_factory`` is
        :class:`sqlite3.Row`. The caller owns it and is responsible for
        closing it.
    :raises sqlite3.Error: If the database cannot be opened or initialized.
        If initialization fails after the connection was opened, the
        connection is closed before the error propagates.
    """
    db = sqlite3.connect(database)
    try:
        db.row_factory = sqlite3.Row
        _init_database(db, sync)
    except BaseException:
        # Don't leak the connection (and its file handle) when setup fails
        # part-way, e.g. because *database* is not an SQLite file.
        db.close()
        raise

    return db
61
62
def parse_address(addr):
    """Classify an address string and split it into connection arguments.

    ``unix://PATH`` gives ``(ADDR_TYPE_UNIX, (PATH,))``. Any other string is
    treated as a TCP address, written either as ``[HOST]:PORT`` (the brackets
    allow a host containing colons, such as an IPv6 literal) or as
    ``HOST:PORT``, and gives ``(ADDR_TYPE_TCP, (HOST, PORT))`` with PORT
    converted to ``int``.

    :raises ValueError: If a TCP address does not split into exactly one host
        and one port, or the port is not an integer.
    """
    if not addr.startswith(UNIX_PREFIX):
        bracketed = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
        if bracketed is None:
            host, port = addr.split(':')
        else:
            host, port = bracketed.group('host', 'port')
        return (ADDR_TYPE_TCP, (host, int(port)))

    socket_path = addr[len(UNIX_PREFIX):]
    return (ADDR_TYPE_UNIX, (socket_path,))
75
76
def chunkify(msg, max_chunk):
    r"""Split *msg* into newline-terminated lines of at most *max_chunk* chars.

    A message shorter than ``max_chunk - 1`` characters is yielded as a single
    line. A longer message is sent as a chunk stream:

    1. a ``{"chunk-stream": null}`` JSON header line,
    2. *msg* in pieces of ``max_chunk - 1`` characters, each followed by
       ``"\n"``,
    3. an empty line ``"\n"`` marking the end of the stream.

    NOTE(review): presumably *msg* must not contain newlines itself, since the
    framing is newline-based. Also, *max_chunk* counts characters, not encoded
    bytes. Confirm both against callers.

    :param msg: Message text to send.
    :param max_chunk: Maximum length of each yielded line, including its
        trailing newline. Must be at least 2.
    :raises ValueError: If *max_chunk* is less than 2. Because this is a
        generator, the error surfaces on the first iteration.
    """
    if max_chunk < 2:
        # Each chunk line needs room for at least one character plus "\n";
        # with less, the body could never be emitted.
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield ''.join((msg, "\n"))
        return

    yield ''.join((json.dumps({
            'chunk-stream': None
        }), "\n"))

    # Slice the string directly rather than regrouping it character by
    # character; the final slice simply holds whatever remains.
    step = max_chunk - 1
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"
    yield "\n"
89
90
def create_server(addr, dbname, *, sync=True):
    """Create a ``server.Server`` backed by *dbname* and start it on *addr*.

    :param addr: ``unix://PATH`` for a UNIX socket, otherwise a TCP address in
        a form accepted by :func:`parse_address`.
    :param dbname: Database passed to :func:`setup_database`.
    :param sync: Forwarded to :func:`setup_database`.
    :returns: The started ``server.Server`` instance.
    """
    from . import server

    srv = server.Server(setup_database(dbname, sync=sync))

    kind, start_args = parse_address(addr)
    start = srv.start_unix_server if kind == ADDR_TYPE_UNIX else srv.start_tcp_server
    start(*start_args)

    return srv
103
104
def create_client(addr):
    """Create a ``client.Client`` and connect it to *addr*.

    :param addr: ``unix://PATH`` for a UNIX socket, otherwise a TCP address in
        a form accepted by :func:`parse_address`.
    :returns: The connected ``client.Client`` instance.
    """
    from . import client

    cli = client.Client()

    kind, connect_args = parse_address(addr)
    if kind == ADDR_TYPE_UNIX:
        connect = cli.connect_unix
    else:
        connect = cli.connect_tcp
    connect(*connect_args)

    return cli
116