# Copyright (C) 2018 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from http.server import BaseHTTPRequestHandler, HTTPServer
import contextlib
import urllib.parse
import sqlite3
import json
import traceback
import logging
from datetime import datetime

logger = logging.getLogger('hashserv')


class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler serving the ``<prefix>/v1/equivalent`` endpoint.

    Handler classes built by :func:`create_server` supply two class
    attributes read by the request methods:

    * ``prefix`` -- URL path prefix placed in front of ``/v1/equivalent``.
    * ``db`` -- an open :class:`sqlite3.Connection` (with ``row_factory`` set
      to :class:`sqlite3.Row`) containing the ``tasks_v1`` table.
    """

    def log_message(self, f, *args):
        # Route the stdlib's per-request log lines (normally written to
        # stderr) into the module logger at debug level.
        logger.debug(f, *args)

    def _send_json(self, d):
        """Send a 200 response whose body is *d* serialized as UTF-8 JSON."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(d).encode('utf-8'))

    def do_GET(self):
        """Look up the oldest entry for a ``(method, taskhash)`` pair.

        Query parameters ``method`` and ``taskhash`` are required. Responds
        with a JSON object holding ``taskhash``, ``method`` and ``unihash``,
        or JSON ``null`` when no entry exists. Any other path gets 404; any
        failure (malformed query, missing parameter, database error) gets 400
        with the traceback as the explanation text.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                        {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            # Exception, not a bare except: KeyboardInterrupt/SystemExit must
            # still propagate so the server can be stopped.
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())

    def do_POST(self):
        """Record a task's output hash and respond with its unified hash.

        The request body is a JSON object with required keys ``method``,
        ``outhash``, ``taskhash`` and ``unihash``, plus any of the optional
        keys ``owner``, ``PN``, ``PV``, ``PR``, ``task`` and
        ``outhash_siginfo``.

        If a row with the same method, outhash and taskhash already exists,
        the oldest such row is returned and nothing is inserted. Otherwise a
        new row is inserted for this taskhash; its unihash is copied from the
        oldest row sharing the method and outhash, or taken from the request
        when there is none. Responds with a JSON object holding ``taskhash``,
        ``method`` and ``unihash``. Any other path gets 404; any failure gets
        400 with the traceback as the explanation text, and the database
        transaction is rolled back.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                # Rows for this exact taskhash sort first; otherwise the
                # oldest row that produced the same output hash wins.
                cursor.execute('''
                    SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND outhash=:outhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                if row is None or row['taskhash'] != data['taskhash']:
                    # New taskhash for this output: reuse the existing
                    # unihash if the output was seen before, else accept
                    # the one supplied by the client.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                        # Same text sqlite3's default datetime adapter
                        # produces, without relying on that adapter
                        # (deprecated since Python 3.12).
                        'created': datetime.now().isoformat(' '),
                    }

                    for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed key set above, so
                    # interpolating them is safe; values are bound parameters.
                    columns = sorted(insert_data.keys())
                    cursor.execute('''INSERT INTO tasks_v1 (%s) VALUES (%s)''' % (
                            ', '.join(columns),
                            ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
                    cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE id=:id', {'id': cursor.lastrowid})
                    row = cursor.fetchone()

                    self.db.commit()

                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            tb = traceback.format_exc()
            logger.exception('Error in POST')
            # sqlite3 opens a transaction implicitly before the INSERT; if
            # anything failed afterwards it would stay open (holding the
            # write lock) and a later request's commit() would persist it.
            try:
                self.db.rollback()
            except sqlite3.Error:
                logger.exception('Rollback failed')
            self.send_error(400, explain=tb)


def create_server(addr, db, prefix=''):
    """Create an :class:`HTTPServer` bound to *addr* serving hash equivalence.

    :param addr: ``(host, port)`` address handed to :class:`HTTPServer`.
    :param db: open :class:`sqlite3.Connection`; its ``row_factory`` is set
        to :class:`sqlite3.Row`, and the ``tasks_v1`` table and its lookup
        indexes are created if missing.
    :param prefix: URL path prefix for the ``/v1/equivalent`` endpoint.
    :returns: the :class:`HTTPServer`; running it is up to the caller.
    """
    # A fresh subclass per call so prefix/db are not shared between servers.
    class Handler(HashEquivalenceServer):
        pass

    Handler.prefix = prefix
    Handler.db = db
    db.row_factory = sqlite3.Row

    with contextlib.closing(db.cursor()) as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks_v1 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                method TEXT NOT NULL,
                outhash TEXT NOT NULL,
                taskhash TEXT NOT NULL,
                unihash TEXT NOT NULL,
                created DATETIME,

                -- Optional fields
                owner TEXT,
                PN TEXT,
                PV TEXT,
                PR TEXT,
                task TEXT,
                outhash_siginfo TEXT
            )
            ''')
        # Cover the WHERE clauses of do_GET and do_POST; without these every
        # request scans the entire table.
        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v1 (method, taskhash)')
        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v1 (method, outhash)')

    logger.info('Starting server on %s', addr)
    return HTTPServer(addr, Handler)