# Copyright (C) 2018-2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Database setup, address parsing, message chunking and server/client
factory helpers."""

import asyncio
from contextlib import closing
import re
import sqlite3
import itertools
import json

UNIX_PREFIX = "unix://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1

# Matches a bracketed TCP address such as "[::1]:8686". Brackets let a host
# that itself contains colons (e.g. an IPv6 literal) be separated from the port.
_BRACKETED_ADDR_RE = re.compile(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$')

# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024

# Each entry is (column name, SQL type, flags). A column whose flags contain
# "UNIQUE" becomes part of the table's single composite UNIQUE constraint.
UNIHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("unihash", "TEXT NOT NULL", ""),
)

UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)

OUTHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("outhash", "TEXT NOT NULL", "UNIQUE"),
    ("created", "DATETIME", ""),

    # Optional fields
    ("owner", "TEXT", ""),
    ("PN", "TEXT", ""),
    ("PV", "TEXT", ""),
    ("PR", "TEXT", ""),
    ("task", "TEXT", ""),
    ("outhash_siginfo", "TEXT", ""),
)

OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)


def _make_table(cursor, name, definition):
    """Create table *name* from *definition* unless it already exists.

    :param cursor: sqlite3 cursor used to execute the DDL.
    :param name: table name.
    :param definition: sequence of ``(column, sql_type, flags)`` tuples.

    Every table gets an ``id INTEGER PRIMARY KEY AUTOINCREMENT`` column first.
    All columns whose flags contain ``"UNIQUE"`` are combined into one
    composite ``UNIQUE(...)`` constraint. The constraint is omitted when there
    are no such columns; emitting ``UNIQUE()`` would be invalid SQL.
    """
    columns = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
    columns.extend(
        "%s %s" % (col_name, col_type) for col_name, col_type, _ in definition
    )

    unique_columns = [
        col_name for col_name, _, flags in definition if "UNIQUE" in flags
    ]
    if unique_columns:
        columns.append("UNIQUE(%s)" % ", ".join(unique_columns))

    # Identifiers come from the module-level definitions above, not from
    # external input, so string formatting is acceptable here (SQLite cannot
    # bind identifiers as parameters).
    cursor.execute(
        "CREATE TABLE IF NOT EXISTS {name} (\n    {columns}\n)".format(
            name=name,
            columns=",\n    ".join(columns),
        )
    )


def setup_database(database, sync=True):
    """Open the SQLite database at *database* and ensure its schema exists.

    :param database: path passed to :func:`sqlite3.connect`.
    :param sync: when true use ``PRAGMA synchronous = NORMAL``, otherwise
        ``OFF``.
    :returns: the open :class:`sqlite3.Connection`, with ``row_factory`` set to
        :class:`sqlite3.Row`.

    Creates the ``unihashes_v2`` and ``outhashes_v2`` tables and switches to
    WAL journaling. It drops the older lookup indexes and the ``tasks_v2``
    table, then creates the ``*_lookup_v3`` indexes. If any step fails, the
    connection is closed before the exception propagates.
    """
    db = sqlite3.connect(database)
    try:
        db.row_factory = sqlite3.Row

        with closing(db.cursor()) as cursor:
            _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
            _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)

            cursor.execute('PRAGMA journal_mode = WAL')
            cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

            # Drop old indexes
            cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
            cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
            cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
            cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')

            # TODO: Upgrade from tasks_v2?
            cursor.execute('DROP TABLE IF EXISTS tasks_v2')

            # Create new indexes
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
    except Exception:
        # Don't leak the connection if schema setup fails.
        db.close()
        raise

    return db


def parse_address(addr):
    """Parse a server address string.

    Accepted forms:

    * ``unix://PATH``  -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``[HOST]:PORT``  -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``
    * ``HOST:PORT``    -> ``(ADDR_TYPE_TCP, (HOST, int(PORT)))``

    :raises ValueError: if *addr* matches none of these forms or the port is
        not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    m = _BRACKETED_ADDR_RE.match(addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        parts = addr.split(':')
        if len(parts) != 2:
            raise ValueError(
                "Invalid address %r: expected unix://PATH, HOST:PORT or "
                "[HOST]:PORT" % (addr,)
            )
        host, port = parts

    return (ADDR_TYPE_TCP, (host, int(port)))


def chunkify(msg, max_chunk):
    """Yield *msg* as newline-terminated pieces no longer than *max_chunk*.

    A message shorter than ``max_chunk - 1`` characters is yielded as one
    line. Otherwise the function yields a ``{"chunk-stream": null}`` JSON
    header line first. It then yields the message in slices of
    ``max_chunk - 1`` characters, each followed by a newline. Finally it
    yields a lone newline (an empty line) marking the end of the stream.
    """
    # NOTE(review): lengths are counted in str characters, not encoded bytes;
    # non-ASCII text may exceed max_chunk bytes on the wire — confirm the
    # receiver tolerates this.
    if len(msg) < max_chunk - 1:
        yield ''.join((msg, "\n"))
    else:
        yield ''.join((json.dumps({
            'chunk-stream': None
        }), "\n"))

        # The same iterator repeated N times makes zip_longest consume N
        # consecutive characters per tuple, i.e. fixed-size slices.
        args = [iter(msg)] * (max_chunk - 1)
        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
            yield ''.join(itertools.chain(m, "\n"))
        yield "\n"


def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
    """Set up the database *dbname* and start a server listening on *addr*.

    *addr* is parsed with :func:`parse_address`. *sync* is forwarded to
    :func:`setup_database`; *upstream* and *read_only* are forwarded to
    ``server.Server``. Returns the ``server.Server`` instance after calling
    its ``start_unix_server`` or ``start_tcp_server`` method.
    """
    from . import server
    db = setup_database(dbname, sync=sync)
    s = server.Server(db, upstream=upstream, read_only=read_only)

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        s.start_unix_server(*a)
    else:
        s.start_tcp_server(*a)

    return s


def create_client(addr):
    """Create a ``client.Client`` and connect it to *addr*.

    *addr* is parsed with :func:`parse_address`; the client's
    ``connect_unix`` or ``connect_tcp`` method is called accordingly.
    """
    from . import client
    c = client.Client()

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        c.connect_unix(*a)
    else:
        c.connect_tcp(*a)

    return c


async def create_async_client(addr):
    """Create a ``client.AsyncClient`` and connect it to *addr*.

    Async counterpart of :func:`create_client`: awaits the client's
    ``connect_unix`` or ``connect_tcp`` coroutine.
    """
    from . import client
    c = client.AsyncClient()

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        await c.connect_unix(*a)
    else:
        await c.connect_tcp(*a)

    return c