1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
from http.server import BaseHTTPRequestHandler, HTTPServer
import contextlib
import urllib.parse
import sqlite3
import json
import traceback
import logging
import os
import socketserver
import queue
import threading
import signal
import socket
import struct
from datetime import datetime
20
# Module logger, used by the request handler (access logging and request
# processing) and by create_server().
logger = logging.getLogger('hashserv')
22
class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler for the hash equivalence API.

    Serves one endpoint, ``<prefix>/v1/equivalent``:

    * ``GET ?method=...&taskhash=...`` replies with the oldest stored
      ``{"taskhash", "method", "unihash"}`` row for that pair, or ``null``.
    * ``POST`` with a JSON object (``method``, ``outhash``, ``taskhash``,
      ``unihash`` plus optional metadata) records the task and replies with
      the unihash assigned to it.

    Any other path gets a 404. Any error while handling a request gets a 400
    whose explanation is the server-side traceback.

    The ``prefix``, ``db`` and ``dbname`` attributes are not defined here.
    The code creating the server must provide them (``create_server()`` sets
    them on a subclass).
    """

    # Columns returned to the client for a stored row.
    _RESULT_COLUMNS = ('taskhash', 'method', 'unihash')

    # Optional POST fields copied into a new row when present. This fixed
    # whitelist is also what makes building the INSERT column list with
    # string formatting safe: client-supplied keys never reach the SQL text.
    _OPTIONAL_COLUMNS = ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo')

    def log_message(self, f, *args):
        # Route http.server's request/error log lines to the module logger at
        # debug level instead of the default stderr output.
        logger.debug(f, *args)

    def opendb(self):
        """Open ``self.dbname`` and store the connection as ``self.db``."""
        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row
        # No fsync and an in-memory rollback journal (SQLite PRAGMA
        # synchronous / journal_mode): faster writes, less crash safety.
        self.db.execute("PRAGMA synchronous = OFF;")
        self.db.execute("PRAGMA journal_mode = MEMORY;")

    def finish(self):
        """Finish the request, then close any connection opendb() opened.

        Only a connection stored on this handler instance is closed. A
        connection supplied through a class attribute is left untouched.
        """
        try:
            super().finish()
        finally:
            db = self.__dict__.pop('db', None)
            if db is not None:
                db.close()

    def _parse_equivalent_path(self):
        """Return the parsed request URL if it targets the API endpoint.

        Otherwise reply with 404 and return None.
        """
        p = urllib.parse.urlparse(self.path)
        if p.path != self.prefix + '/v1/equivalent':
            self.send_error(404)
            return None
        return p

    def _send_json(self, d):
        """Send a 200 response whose body is ``d`` serialized as UTF-8 JSON."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(d).encode('utf-8'))

    def do_GET(self):
        """Reply with the oldest row matching the method/taskhash query.

        The body is ``null`` when no row matches. A missing or malformed
        query string results in a 400 response.
        """
        try:
            if not self.db:
                self.opendb()

            p = self._parse_equivalent_path()
            if p is None:
                return

            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                        {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in self._RESULT_COLUMNS}

            self._send_json(d)
        except Exception:
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())

    def do_POST(self):
        """Record a task's output hash and reply with its assigned unihash.

        If a task with the same method and outhash is already stored, the new
        taskhash inherits that task's unihash. Otherwise the client-supplied
        unihash is stored. An exact (method, outhash, taskhash) match is
        returned as-is without inserting anything.
        """
        try:
            if not self.db:
                self.opendb()

            p = self._parse_equivalent_path()
            if p is None:
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('''
                    -- Find tasks with a matching outhash (that is, tasks that
                    -- are equivalent)
                    SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                    -- If there is an exact match on the taskhash, return it.
                    -- Otherwise return the oldest matching outhash of any
                    -- taskhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC

                    -- Only return one row
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                # If no matching outhash was found, or one *was* found but it
                # wasn't an exact match on the taskhash, a new entry for this
                # taskhash should be added
                if row is None or row['taskhash'] != data['taskhash']:
                    # If a row matching the outhash was found, the unihash for
                    # the new taskhash should be the same as that one.
                    # Otherwise the caller provided unihash is used.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                            'method': data['method'],
                            'outhash': data['outhash'],
                            'taskhash': data['taskhash'],
                            'unihash': unihash,
                            'created': datetime.now()
                            }

                    for k in self._OPTIONAL_COLUMNS:
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed keys above, so
                    # formatting them into the statement is safe. All values
                    # are still bound as parameters.
                    columns = sorted(insert_data.keys())
                    cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                            ', '.join(columns),
                            ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)

                    self.db.commit()
                    d = {'taskhash': data['taskhash'], 'method': data['method'], 'unihash': unihash}
                else:
                    d = {k: row[k] for k in self._RESULT_COLUMNS}

            self._send_json(d)
        except Exception:
            logger.exception('Error in POST')
            self.send_error(400, explain=traceback.format_exc())
143
class ThreadedHTTPServer(HTTPServer):
    """HTTPServer that hands accepted connections to one worker thread.

    The thread calling serve_forever() runs the accept loop. Each accepted
    connection is queued and then handled, in order, by a single non-daemon
    handler thread. Every shutdown path (accept loop returning, SIGTERM, or
    the handler thread stopping) ends the process with ``os._exit(0)``.
    """

    # Set by server_close() to tell the handler thread to stop.
    quit = False

    def serve_forever(self):
        """Start the handler thread, install a SIGTERM handler and serve.

        Does not return normally: once the accept loop stops, the process
        exits via ``os._exit(0)``.
        """
        self.requestqueue = queue.Queue()
        self.handlerthread = threading.Thread(target=self.process_request_thread)
        self.handlerthread.daemon = False

        self.handlerthread.start()

        # NOTE: signal.signal() may only be called from the main thread.
        signal.signal(signal.SIGTERM, self.sigterm_exception)
        super().serve_forever()
        os._exit(0)

    def sigterm_exception(self, signum, stackframe):
        """SIGTERM handler: close the server, then exit immediately."""
        self.server_close()
        os._exit(0)

    def server_bind(self):
        """Bind as usual, then enable SO_LINGER with a zero linger time."""
        HTTPServer.server_bind(self)
        # l_onoff=1, l_linger=0: close() discards unsent data and resets the
        # connection instead of lingering (see socket(7) SO_LINGER).
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))

    def process_request_thread(self):
        """Handler-thread loop: serve queued connections until ``quit`` is set.

        A ``(None, None)`` entry is the wake-up sentinel posted by
        server_close(). When the loop ends, the whole process exits.
        """
        while not self.quit:
            # A blocking get() with no timeout never raises queue.Empty.
            request, client_address = self.requestqueue.get()
            if request is None:
                continue
            try:
                self.finish_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
            finally:
                self.shutdown_request(request)
        os._exit(0)

    def process_request(self, request, client_address):
        """Queue an accepted connection for the handler thread."""
        self.requestqueue.put((request, client_address))

    def server_close(self):
        """Close the listening socket and stop the handler thread, if any.

        TCPServer.__init__ calls server_close() when bind/activate fails, and
        ``with server:`` calls it on exit. Either can happen before
        serve_forever() has created the queue and the handler thread, so their
        absence must not raise and mask the original error.
        """
        super().server_close()
        self.quit = True
        handlerthread = getattr(self, 'handlerthread', None)
        if handlerthread is None:
            return
        self.requestqueue.put((None, None))
        handlerthread.join()
190
def create_server(addr, dbname, prefix=''):
    """Create (but do not start) a hash equivalence HTTP server.

    Ensures the ``tasks_v2`` table and its lookup indexes exist in the SQLite
    database ``dbname``, then binds a ThreadedHTTPServer whose handlers serve
    the API under ``prefix``.

    :param addr: server address passed to ThreadedHTTPServer (e.g. a
        ``(host, port)`` tuple).
    :param dbname: database name/path passed to ``sqlite3.connect()``.
    :param prefix: URL path prefix placed in front of ``/v1/equivalent``.
    :returns: the bound ThreadedHTTPServer; call serve_forever() to run it.
    """
    # Per-server subclass so the configuration below is set on it rather
    # than on HashEquivalenceServer itself.
    class Handler(HashEquivalenceServer):
        pass

    Handler.prefix = prefix
    # No shared connection: each handler opens its own lazily via opendb().
    Handler.db = None
    Handler.dbname = dbname

    # Short-lived connection used only to set up the schema. Close it
    # explicitly rather than leaving it to garbage collection.
    with contextlib.closing(sqlite3.connect(dbname)) as db:
        with contextlib.closing(db.cursor()) as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks_v2 (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    method TEXT NOT NULL,
                    outhash TEXT NOT NULL,
                    taskhash TEXT NOT NULL,
                    unihash TEXT NOT NULL,
                    created DATETIME,

                    -- Optional fields
                    owner TEXT,
                    PN TEXT,
                    PV TEXT,
                    PR TEXT,
                    task TEXT,
                    outhash_siginfo TEXT,

                    UNIQUE(method, outhash, taskhash)
                    )
                ''')
            cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)')
        # Make sure the schema is durable before handlers start connecting.
        db.commit()

    ret = ThreadedHTTPServer(addr, Handler)

    logger.info('Starting server on %s\n', ret.server_port)

    return ret
231