1# Copyright (C) 2018 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from http.server import BaseHTTPRequestHandler, HTTPServer
7import contextlib
8import urllib.parse
9import sqlite3
10import json
11import traceback
12import logging
13from datetime import datetime
14
15logger = logging.getLogger('hashserv')
16
class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler for the v1 hash equivalence API.

    Serves GET and POST on ``<prefix>/v1/equivalent``; any other path gets a
    404. The handler reads two class attributes that are not defined here:
    ``prefix`` (URL path prefix) and ``db`` (a sqlite3 connection whose rows
    can be indexed by column name). ``create_server`` sets both on a
    per-server subclass.
    """

    def log_message(self, f, *args):
        # http.server reports every request/error through log_message();
        # route it to the module logger at debug level instead of stderr.
        logger.debug(f, *args)

    def _send_json(self, obj):
        """Send a 200 response whose body is *obj* serialized as UTF-8 JSON."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(obj).encode('utf-8'))

    def do_GET(self):
        """Return the unihash recorded for a (method, taskhash) pair.

        Query parameters ``method`` and ``taskhash`` are required. The body
        is the oldest matching row as a JSON object with keys ``taskhash``,
        ``method`` and ``unihash``, or JSON ``null`` if nothing matches.
        Malformed requests get a 400 whose explanation is the traceback.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                        {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
        except Exception:
            # Exception rather than a bare except: KeyboardInterrupt and
            # SystemExit must propagate so the server can be stopped.
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())
            return

        # Sent outside the try: once the 200 status line has been written, a
        # later failure must not be answered with a second (error) response.
        self._send_json(d)

    def do_POST(self):
        """Report a task's output hash and return the unihash to use for it.

        The JSON body must contain ``method``, ``outhash``, ``taskhash`` and
        ``unihash``; ``owner``, ``PN``, ``PV``, ``PR``, ``task`` and
        ``outhash_siginfo`` are stored when present. When no row with this
        (method, outhash) already has this taskhash, a new row is inserted.
        Its unihash is copied from the oldest row with the same
        (method, outhash), or taken from the request if there is none. The
        response is the matching or inserted row as a JSON object with keys
        ``taskhash``, ``method`` and ``unihash``. Malformed requests get a
        400 whose explanation is the traceback.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            # The connection context manager commits on success and rolls
            # back on error, so a failed request never leaves an open
            # transaction that a later request's commit would persist.
            with self.db, contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('''
                    SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND outhash=:outhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                if row is None or row['taskhash'] != data['taskhash']:
                    # Reuse the unihash of an equivalent task (same output
                    # hash) when one exists; otherwise accept the client's.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                            'method': data['method'],
                            'outhash': data['outhash'],
                            'taskhash': data['taskhash'],
                            'unihash': unihash,
                            # Same text sqlite3's (deprecated) default
                            # datetime adapter stores, made explicit.
                            'created': datetime.now().isoformat(' '),
                            }

                    for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed keys above, so
                    # interpolating them is safe; the values are bound.
                    columns = sorted(insert_data)
                    cursor.execute('INSERT INTO tasks_v1 (%s) VALUES (%s)' % (
                            ', '.join(columns),
                            ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
                    cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE id=:id', {'id': cursor.lastrowid})
                    row = cursor.fetchone()

                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
        except Exception:
            logger.exception('Error in POST')
            self.send_error(400, explain=traceback.format_exc())
            return

        self._send_json(d)
112
def create_server(addr, db, prefix=''):
    """Create an HTTP server that answers hash equivalence requests.

    :param addr: server address, passed unchanged to ``HTTPServer``.
    :param db: sqlite3 connection backing the server. Its ``row_factory`` is
        set to ``sqlite3.Row``. The ``tasks_v1`` table and its lookup
        indexes are created if they do not already exist.
    :param prefix: URL path prefix placed in front of ``/v1/equivalent``.
    :returns: an ``HTTPServer`` for *addr*; the caller is responsible for
        running it.
    """
    # A fresh subclass per call keeps prefix/db as per-server class
    # attributes, so several servers in one process don't clobber each other.
    class Handler(HashEquivalenceServer):
        pass

    Handler.prefix = prefix
    Handler.db = db
    db.row_factory = sqlite3.Row

    with contextlib.closing(db.cursor()) as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks_v1 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                method TEXT NOT NULL,
                outhash TEXT NOT NULL,
                taskhash TEXT NOT NULL,
                unihash TEXT NOT NULL,
                created DATETIME,

                -- Optional fields
                owner TEXT,
                PN TEXT,
                PV TEXT,
                PR TEXT,
                task TEXT,
                outhash_siginfo TEXT
                )
            ''')
        # The request handlers look rows up by (method, taskhash) and
        # (method, outhash); without indexes SQLite must scan the whole
        # table for every request.
        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v1 (method, taskhash)')
        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v1 (method, outhash)')

    # Make the schema durable even if the connection wraps DDL in a
    # transaction (e.g. sqlite3 opened with autocommit=False); a no-op
    # when no transaction is open.
    db.commit()

    logger.info('Starting server on %s', addr)
    return HTTPServer(addr, Handler)
143