1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7import re
8import sqlite3
9
# Address prefix that selects a UNIX domain socket; the rest of the
# address string is the socket path.
UNIX_PREFIX = "unix://"

# Address kinds, used as the first element of parse_address()'s result.
ADDR_TYPE_UNIX, ADDR_TYPE_TCP = 0, 1
14
15
def setup_database(database, sync=True):
    """Open *database* with SQLite and make sure the hash tables exist.

    The ``tasks_v2`` table and its ``taskhash_lookup_v2`` /
    ``outhash_lookup_v2`` indexes are created if missing, the obsolete
    ``taskhash_lookup`` / ``outhash_lookup`` indexes are dropped, WAL journal
    mode is enabled and ``PRAGMA synchronous`` is set.

    :param database: Name passed to :func:`sqlite3.connect`.
    :param sync: ``PRAGMA synchronous`` is set to ``NORMAL`` when true and
        ``OFF`` when false.
    :returns: The open :class:`sqlite3.Connection`, with ``row_factory`` set
        to :class:`sqlite3.Row`.
    :raises sqlite3.Error: If the database cannot be initialised. The
        connection is closed before the error propagates.
    """
    db = sqlite3.connect(database)
    try:
        db.row_factory = sqlite3.Row
        _init_schema(db, sync)
    except BaseException:
        # The caller never receives the connection on failure, so close it
        # here rather than leaking the file handle.
        db.close()
        raise

    return db


def _init_schema(db, sync):
    """Create the tasks table/indexes on *db* and apply connection pragmas."""
    with closing(db.cursor()) as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks_v2 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                method TEXT NOT NULL,
                outhash TEXT NOT NULL,
                taskhash TEXT NOT NULL,
                unihash TEXT NOT NULL,
                created DATETIME,

                -- Optional fields
                owner TEXT,
                PN TEXT,
                PV TEXT,
                PR TEXT,
                task TEXT,
                outhash_siginfo TEXT,

                UNIQUE(method, outhash, taskhash)
                )
            ''')
        cursor.execute('PRAGMA journal_mode = WAL')
        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

        # Drop old indexes
        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')

        # Create new indexes
        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
53
54
def parse_address(addr):
    """Split a server address string into its kind and connection arguments.

    Accepted forms:

    * ``unix://PATH``  -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``[HOST]:PORT``  -> ``(ADDR_TYPE_TCP, (HOST, PORT))``; use the brackets
      when HOST itself contains ``:`` (e.g. an IPv6 literal)
    * ``HOST:PORT``    -> ``(ADDR_TYPE_TCP, (HOST, PORT))``

    PORT is converted with ``int()``.

    :raises ValueError: If *addr* is not in one of the forms above, or the
        port is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        parts = addr.split(':')
        if len(parts) != 2:
            # Without this check, a bare IPv6 literal or a missing port
            # surfaces as an opaque tuple-unpacking error.
            raise ValueError(
                "Invalid address %r: expected 'unix://PATH', 'HOST:PORT' "
                "or '[HOST]:PORT'" % (addr,))
        host, port = parts

    return (ADDR_TYPE_TCP, (host, int(port)))
67
68
def create_server(addr, dbname, *, sync=True):
    """Create a hash equivalence server bound to *addr*.

    :param addr: Address in any form accepted by ``parse_address``; a
        ``unix://`` address calls ``start_unix_server``, any other address
        calls ``start_tcp_server``.
    :param dbname: Database name passed to ``setup_database``.
    :param sync: Forwarded to ``setup_database``.
    :returns: The ``server.Server`` instance on which the matching
        ``start_*_server`` method has been called.
    :raises ValueError: If *addr* is malformed (raised before the database
        is opened).
    """
    from . import server

    # Validate the address first so a malformed address doesn't leave an
    # open database connection behind.
    (typ, a) = parse_address(addr)

    db = setup_database(dbname, sync=sync)
    try:
        s = server.Server(db)
        if typ == ADDR_TYPE_UNIX:
            s.start_unix_server(*a)
        else:
            s.start_tcp_server(*a)
    except BaseException:
        # The server is never handed to the caller on failure, so nothing
        # else would close the connection it was given.
        db.close()
        raise

    return s
81
82
def create_client(addr):
    """Create a ``client.Client`` and connect it to *addr*.

    *addr* is decoded by ``parse_address``. A ``unix://`` address uses
    ``connect_unix``; any other address uses ``connect_tcp``.
    """
    from . import client

    conn = client.Client()

    addr_type, connect_args = parse_address(addr)
    connect = conn.connect_unix if addr_type == ADDR_TYPE_UNIX else conn.connect_tcp
    connect(*connect_args)

    return conn
94