# Copyright (C) 2018-2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import asyncio
from contextlib import closing
import re
import sqlite3
import itertools
import json

UNIX_PREFIX = "unix://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1

# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024

TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL"),
    ("outhash", "TEXT NOT NULL"),
    ("taskhash", "TEXT NOT NULL"),
    ("unihash", "TEXT NOT NULL"),
    ("created", "DATETIME"),

    # Optional fields
    ("owner", "TEXT"),
    ("PN", "TEXT"),
    ("PV", "TEXT"),
    ("PR", "TEXT"),
    ("task", "TEXT"),
    ("outhash_siginfo", "TEXT"),
)

TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION)

def setup_database(database, sync=True):
    db = sqlite3.connect(database)
    db.row_factory = sqlite3.Row

    with closing(db.cursor()) as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks_v2 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                %s
                UNIQUE(method, outhash, taskhash)
                )
            ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION))
        cursor.execute('PRAGMA journal_mode = WAL')
        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

        # Drop old indexes
        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')

        # Create new indexes
        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')

    return db


def parse_address(addr):
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
    else:
        m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
        if m is not None:
            host = m.group('host')
            port = m.group('port')
        else:
            host, port = addr.split(':')

        return (ADDR_TYPE_TCP, (host, int(port)))


def chunkify(msg, max_chunk):
    if len(msg) < max_chunk - 1:
        yield ''.join((msg, "\n"))
    else:
        yield ''.join((json.dumps({
                'chunk-stream': None
            }), "\n"))

        args = [iter(msg)] * (max_chunk - 1)
        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
            yield ''.join(itertools.chain(m, "\n"))
        yield "\n"


def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
    from . import server
    db = setup_database(dbname, sync=sync)
    s = server.Server(db, upstream=upstream, read_only=read_only)

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        s.start_unix_server(*a)
    else:
        s.start_tcp_server(*a)

    return s


def create_client(addr):
    from . import client
    c = client.Client()

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        c.connect_unix(*a)
    else:
        c.connect_tcp(*a)

    return c

async def create_async_client(addr):
    from . import client
    c = client.AsyncClient()

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        await c.connect_unix(*a)
    else:
        await c.connect_tcp(*a)

    return c
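
# Small illustrative example (not used by the module itself) of the three
# address forms parse_address() understands: a unix domain socket path, a
# host:port pair, and a bracketed IPv6 address with port. The socket path and
# port numbers below are arbitrary values chosen for illustration.
def _example_parse_address():
    assert parse_address("unix://./hashserve.sock") == (ADDR_TYPE_UNIX, ("./hashserve.sock",))
    assert parse_address("localhost:8686") == (ADDR_TYPE_TCP, ("localhost", 8686))
    assert parse_address("[::1]:8686") == (ADDR_TYPE_TCP, ("::1", 8686))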
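
# A minimal sketch of the framing chunkify() produces and of how a receiver
# could reassemble it. The reassembly loop here is illustrative only; the real
# consumers live in the sibling client/server modules, which are not shown in
# this file. The message content is made up.
def _example_chunkify_roundtrip():
    msg = json.dumps({"get": {"method": "example.method", "taskhash": "0123abcd"}})

    # With a deliberately tiny chunk size the message exceeds max_chunk - 1,
    # so chunkify() switches to chunk-stream framing.
    chunks = list(chunkify(msg, 16))

    # The first line announces the stream, the last line is the blank
    # terminator; everything in between is the payload, one newline-terminated
    # piece per line.
    assert json.loads(chunks[0]) == {"chunk-stream": None}
    assert chunks[-1] == "\n"

    reassembled = "".join(piece.rstrip("\n") for piece in chunks[1:-1])
    assert reassembled == msg
    return reassembled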
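
# A sketch of how TABLE_DEFINITION/TABLE_COLUMNS can drive an INSERT against
# the tasks_v2 table created by setup_database(). The hash values are made up
# and this helper is not part of the module's real insert path.
def _example_insert_row(db_path=":memory:"):
    db = setup_database(db_path, sync=False)

    # Start with every column set to NULL, then fill in the required fields.
    data = dict.fromkeys(TABLE_COLUMNS)
    data.update({
        "method": "example.method",
        "outhash": "00aa",
        "taskhash": "11bb",
        "unihash": "22cc",
    })

    with closing(db.cursor()) as cursor:
        cursor.execute(
            "INSERT INTO tasks_v2 (%s) VALUES (%s)" % (
                ", ".join(TABLE_COLUMNS),
                ", ".join(":" + col for col in TABLE_COLUMNS),
            ),
            data,
        )
    db.commit()
    return db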
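
# A minimal sketch of wiring this module together end to end. Only functions
# defined in this file are called directly; what you can do with the returned
# objects depends on the sibling server and client modules, which are not
# shown here, so the serve_forever() call is an assumption about that API and
# is left commented out. The database name and address are example values.
def _example_server_and_client(dbname="hashes.db", addr="unix://./hashserve.sock"):
    server = create_server(addr, dbname, sync=False)
    # server.serve_forever()  # assumed blocking call provided by server.Server

    client = create_client(addr)
    return server, client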