# Copyright (C) 2018-2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from http.server import BaseHTTPRequestHandler, HTTPServer
import contextlib
import urllib.parse
import sqlite3
import json
import traceback
import logging
import os
import socketserver
import queue
import threading
import signal
import socket
import struct
from datetime import datetime

logger = logging.getLogger('hashserv')


class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler implementing the hash equivalence v1 API.

    ``create_server`` subclasses this handler and sets the class attributes
    ``prefix`` (URL path prefix), ``dbname`` (SQLite database path) and
    ``db`` (initially ``None``).

    Both endpoints live at ``prefix + '/v1/equivalent'``:

    * ``GET ?method=...&taskhash=...`` -- reply with the oldest stored row
      for that method/taskhash as JSON, or ``null`` if there is none.
    * ``POST`` with a JSON object body -- record a task result and reply
      with the unihash the task should use.

    Any error is logged and reported to the client as a 400 response whose
    explanation is the Python traceback.
    """

    def log_message(self, f, *args):
        # Route http.server's per-request logging through our logger at
        # debug level instead of letting it write to stderr.
        logger.debug(f, *args)

    def opendb(self):
        """Open a SQLite connection to ``self.dbname`` for this handler."""
        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row
        # Trade durability for speed on this database.
        self.db.execute("PRAGMA synchronous = OFF;")
        self.db.execute("PRAGMA journal_mode = MEMORY;")

    def finish(self):
        """Flush the response, then close any connection opened by opendb().

        socketserver creates a new handler instance for every request, so a
        connection opened in ``opendb`` would otherwise only be released
        whenever the instance happens to be garbage collected.
        """
        try:
            super().finish()
        finally:
            # Only close a connection this instance opened itself; a
            # class-level ``db`` (if one is ever configured) is left alone.
            db = vars(self).pop('db', None)
            if db is not None:
                db.close()

    def _send_json(self, d):
        """Send ``d`` serialized as JSON in a 200 response."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(d).encode('utf-8'))

    def do_GET(self):
        """Look up the unihash recorded for a method/taskhash pair.

        The reply is a JSON object with ``taskhash``, ``method`` and
        ``unihash`` taken from the oldest matching row, or ``null`` when no
        row matches.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            if not self.db:
                self.opendb()

            # strict_parsing makes malformed query strings raise, which is
            # reported to the client as a 400 below.
            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                               {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())

    def do_POST(self):
        """Record a task result and reply with the unihash it should use.

        The request body is a JSON object with ``method``, ``outhash``,
        ``taskhash`` and ``unihash`` keys, optionally plus any of ``owner``,
        ``PN``, ``PV``, ``PR``, ``task`` and ``outhash_siginfo``.

        If a row with the same method, outhash and taskhash exists, it is
        returned unchanged. Otherwise a new row is inserted. It reuses the
        unihash of the oldest row with the same method/outhash if there is
        one, else the caller's unihash. The reply is a JSON object with
        ``taskhash``, ``method`` and ``unihash``.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            if not self.db:
                self.opendb()

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('''
                    -- Find tasks with a matching outhash (that is, tasks that
                    -- are equivalent)
                    SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                    -- If there is an exact match on the taskhash, return it.
                    -- Otherwise return the oldest matching outhash of any
                    -- taskhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC

                    -- Only return one row
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                # If no matching outhash was found, or one *was* found but it
                # wasn't an exact match on the taskhash, a new entry for this
                # taskhash should be added
                if row is None or row['taskhash'] != data['taskhash']:
                    # If a row matching the outhash was found, the unihash for
                    # the new taskhash should be the same as that one.
                    # Otherwise the caller provided unihash is used.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                        'created': datetime.now()
                    }

                    for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names are interpolated into the SQL text, so
                    # they must only ever come from the fixed keys above;
                    # the values are bound as parameters.
                    cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                            ', '.join(sorted(insert_data.keys())),
                            ', '.join(':' + k for k in sorted(insert_data.keys()))),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)

                    self.db.commit()
                    d = {'taskhash': data['taskhash'], 'method': data['method'], 'unihash': unihash}
                else:
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            logger.exception('Error in POST')
            self.send_error(400, explain=traceback.format_exc())


class ThreadedHTTPServer(HTTPServer):
    """HTTPServer that hands accepted requests to a single worker thread.

    ``serve_forever`` accepts connections and queues them. One non-daemon
    thread running ``process_request_thread`` handles them one at a time,
    in the order they were queued.
    """
    quit = False

    def serve_forever(self):
        """Start the worker thread, install a SIGTERM handler and serve.

        Must be called from the main thread, because ``signal.signal``
        only works there. The process exits via ``os._exit(0)`` once
        serving stops.
        """
        self.requestqueue = queue.Queue()
        self.handlerthread = threading.Thread(target=self.process_request_thread)
        self.handlerthread.daemon = False

        self.handlerthread.start()

        signal.signal(signal.SIGTERM, self.sigterm_exception)
        super().serve_forever()
        os._exit(0)

    def sigterm_exception(self, signum, stackframe):
        """SIGTERM handler: shut the server down and exit the process."""
        self.server_close()
        os._exit(0)

    def server_bind(self):
        HTTPServer.server_bind(self)
        # SO_LINGER enabled with a zero timeout: closing a socket resets the
        # connection immediately instead of lingering.
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))

    def process_request_thread(self):
        """Worker loop: handle queued requests until ``quit`` is set.

        Exits the whole process via ``os._exit(0)`` once the loop ends.
        """
        while not self.quit:
            # A blocking get() without a timeout never raises queue.Empty.
            request, client_address = self.requestqueue.get()
            if request is None:
                # Wake-up sentinel queued by server_close().
                continue
            try:
                self.finish_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
            finally:
                self.shutdown_request(request)
        os._exit(0)

    def process_request(self, request, client_address):
        # Defer the request to the worker thread instead of handling it
        # inline on the accepting thread.
        self.requestqueue.put((request, client_address))

    def server_close(self):
        """Close the listening socket and stop the worker thread, if any."""
        super().server_close()
        self.quit = True
        # TCPServer.__init__ calls server_close() when binding fails, i.e.
        # before serve_forever() has created the queue and worker thread;
        # in that case there is nothing more to stop.
        handlerthread = getattr(self, 'handlerthread', None)
        if handlerthread is None:
            return
        self.requestqueue.put((None, None))
        handlerthread.join()


def create_server(addr, dbname, prefix=''):
    """Create (but do not start) a hash equivalence server.

    :param addr: ``(host, port)`` to bind the HTTP server to.
    :param dbname: path of the SQLite database; the ``tasks_v2`` table and
        its lookup indexes are created if they do not exist yet.
    :param prefix: URL path prefix placed before ``/v1/equivalent``.
    :returns: a bound ``ThreadedHTTPServer``; call ``serve_forever()`` on it
        from the main thread to start serving.
    """
    class Handler(HashEquivalenceServer):
        pass

    Handler.prefix = prefix
    Handler.db = None
    Handler.dbname = dbname

    # Short-lived connection used only to make sure the schema exists;
    # request handlers open their own connections.
    with contextlib.closing(sqlite3.connect(dbname)) as db:
        with contextlib.closing(db.cursor()) as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks_v2 (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    method TEXT NOT NULL,
                    outhash TEXT NOT NULL,
                    taskhash TEXT NOT NULL,
                    unihash TEXT NOT NULL,
                    created DATETIME,

                    -- Optional fields
                    owner TEXT,
                    PN TEXT,
                    PV TEXT,
                    PR TEXT,
                    task TEXT,
                    outhash_siginfo TEXT,

                    UNIQUE(method, outhash, taskhash)
                    )
                ''')
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)')
        db.commit()

    ret = ThreadedHTTPServer(addr, Handler)

    logger.info('Starting server on %s\n', ret.server_port)

    return ret