# Copyright (C) 2018-2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from contextlib import closing
import re
import sqlite3
import itertools
import json

UNIX_PREFIX = "unix://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1

# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that would slow down the
# connection setup with a round-trip delay, so I'd rather not do that unless it
# is necessary.
DEFAULT_MAX_CHUNK = 32 * 1024


def setup_database(database, sync=True):
    """Open the SQLite database *database* and ensure its schema exists.

    Creates the ``tasks_v2`` table and its lookup indexes if they are missing,
    drops the obsolete ``taskhash_lookup``/``outhash_lookup`` indexes, and
    switches the database to WAL journaling.

    :param database: path (or any name accepted by ``sqlite3.connect``).
    :param sync: when true, use ``PRAGMA synchronous = NORMAL``; when false,
        use ``OFF``.
    :returns: the open ``sqlite3.Connection`` with ``row_factory`` set to
        ``sqlite3.Row``.
    :raises sqlite3.Error: if the database cannot be opened or initialised;
        the connection is closed before the error propagates.
    """
    db = sqlite3.connect(database)
    db.row_factory = sqlite3.Row

    try:
        with closing(db.cursor()) as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks_v2 (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    method TEXT NOT NULL,
                    outhash TEXT NOT NULL,
                    taskhash TEXT NOT NULL,
                    unihash TEXT NOT NULL,
                    created DATETIME,

                    -- Optional fields
                    owner TEXT,
                    PN TEXT,
                    PV TEXT,
                    PR TEXT,
                    task TEXT,
                    outhash_siginfo TEXT,

                    UNIQUE(method, outhash, taskhash)
                    )
                ''')
            cursor.execute('PRAGMA journal_mode = WAL')
            cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))

            # Drop old indexes
            cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
            cursor.execute('DROP INDEX IF EXISTS outhash_lookup')

            # Create new indexes
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
    except BaseException:
        # Don't leak the connection (and its file handle/locks) when
        # initialisation fails; the caller never receives it to close.
        db.close()
        raise

    return db


def parse_address(addr):
    """Parse a server address string.

    Accepted forms:

    * ``unix://PATH`` -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``[HOST]:PORT`` (bracketed, e.g. for IPv6) -> ``(ADDR_TYPE_TCP, (HOST, PORT))``
    * ``HOST:PORT`` -> ``(ADDR_TYPE_TCP, (HOST, PORT))``

    PORT is returned as an ``int``.

    :raises ValueError: if *addr* matches none of the forms above or the port
        is not an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        parts = addr.split(':')
        # Exactly one ':' is required here; bare IPv6 hosts must use the
        # bracketed form. Report the bad address instead of failing with an
        # opaque tuple-unpacking error.
        if len(parts) != 2:
            raise ValueError(
                "Invalid address %r: expected 'unix://PATH', 'HOST:PORT' "
                "or '[HOST]:PORT'" % (addr,))
        host, port = parts

    return (ADDR_TYPE_TCP, (host, int(port)))


def chunkify(msg, max_chunk):
    """Yield *msg* as newline-terminated pieces no longer than *max_chunk*.

    If *msg* is shorter than ``max_chunk - 1`` characters it is yielded as a
    single line. Otherwise a ``{"chunk-stream": null}`` header line is yielded
    first, then *msg* in slices of at most ``max_chunk - 1`` characters (each
    followed by ``"\\n"``), and finally a lone ``"\\n"`` that ends the stream.

    :param msg: the message text.
    :param max_chunk: maximum length of each yielded piece, including its
        trailing newline.
    :raises ValueError: if *max_chunk* is less than 2.
    """
    # One character of every chunk is reserved for the newline, so with
    # max_chunk < 2 no message text fits. zip_longest() with no iterators
    # would yield nothing and the message body would be silently dropped.
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield ''.join((msg, "\n"))
    else:
        yield ''.join((json.dumps({
            'chunk-stream': None
        }), "\n"))

        # The same iterator repeated max_chunk - 1 times makes zip_longest
        # group consecutive characters into slices of that length.
        args = [iter(msg)] * (max_chunk - 1)
        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
            yield ''.join(itertools.chain(m, "\n"))
        # Empty line terminates the chunk stream.
        yield "\n"


def create_server(addr, dbname, *, sync=True):
    """Create a ``server.Server`` backed by *dbname* and start it on *addr*.

    *addr* is parsed with ``parse_address``; a ``unix://`` address calls
    ``start_unix_server``, anything else calls ``start_tcp_server``. *sync* is
    passed to ``setup_database``.

    :returns: the server object.
    """
    from . import server
    db = setup_database(dbname, sync=sync)
    s = server.Server(db)

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        s.start_unix_server(*a)
    else:
        s.start_tcp_server(*a)

    return s


def create_client(addr):
    """Create a ``client.Client`` and connect it to *addr*.

    *addr* is parsed with ``parse_address``; a ``unix://`` address calls
    ``connect_unix``, anything else calls ``connect_tcp``.

    :returns: the client object.
    """
    from . import client
    c = client.Client()

    (typ, a) = parse_address(addr)
    if typ == ADDR_TYPE_UNIX:
        c.connect_unix(*a)
    else:
        c.connect_tcp(*a)

    return c