# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from contextlib import closing, contextmanager
from datetime import datetime
import asyncio
import enum
import logging
import math
import time
from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
import bb.asyncrpc


logger = logging.getLogger('hashserv.server')


class Measurement(object):
    def __init__(self, sample):
        self.sample = sample

    def start(self):
        self.start_time = time.perf_counter()

    def end(self):
        self.sample.add(time.perf_counter() - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()


class Sample(object):
    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0


class Stats(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        # Running mean and variance via Welford's online algorithm: 'm' is
        # the running mean and 's' accumulates the sum of squared deviations
        # consumed by the stdev property, so no list of samples is kept.
        self.num += 1
        if self.num == 1:
            self.m = elapsed
            self.s = 0
        else:
            last_m = self.m
            self.m = last_m + (elapsed - last_m) / self.num
            self.s = self.s + (elapsed - last_m) * (elapsed - self.m)

        self.total_time += elapsed

        if self.max_time < elapsed:
            self.max_time = elapsed

    def start_sample(self):
        return Sample(self)

    @property
    def average(self):
        if self.num == 0:
            return 0
        return self.total_time / self.num

    @property
    def stdev(self):
        if self.num <= 1:
            return 0
        return math.sqrt(self.s / (self.num - 1))

    def todict(self):
        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
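
# Illustrative usage sketch (not part of the server API): a Sample aggregates
# any number of Measurements into a single Stats entry, so nested timings
# within one request are counted once:
#
#   stats = Stats()
#   with stats.start_sample() as sample, sample.measure():
#       pass  # timed work goes here
#   # stats.num == 1; stats.average and stats.stdev are derived on access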


@enum.unique
class Resolve(enum.Enum):
    FAIL = enum.auto()
    IGNORE = enum.auto()
    REPLACE = enum.auto()


def insert_table(cursor, table, data, on_conflict):
    resolve = {
        Resolve.FAIL: "",
        Resolve.IGNORE: " OR IGNORE",
        Resolve.REPLACE: " OR REPLACE",
    }[on_conflict]

    keys = sorted(data.keys())
    query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
        resolve=resolve,
        table=table,
        fields=", ".join(keys),
        values=", ".join(":" + k for k in keys),
    )
    prevrowid = cursor.lastrowid
    cursor.execute(query, data)
    logger.debug(
        "Inserting %r into %s, %s",
        data,
        table,
        on_conflict
    )
    # lastrowid only advances when a row was actually inserted, which is how
    # an ignored conflict is detected
    return (cursor.lastrowid, cursor.lastrowid != prevrowid)
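
# For example (illustrative values only), calling insert_unihash() below with
# {'method': m, 'taskhash': t, 'unihash': u} and Resolve.IGNORE executes:
#
#   INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash)
#       VALUES(:method, :taskhash, :unihash)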


def insert_unihash(cursor, data, on_conflict):
    return insert_table(cursor, "unihashes_v2", data, on_conflict)


def insert_outhash(cursor, data, on_conflict):
    return insert_table(cursor, "outhashes_v2", data, on_conflict)


async def copy_unihash_from_upstream(client, db, method, taskhash):
    d = await client.get_taskhash(method, taskhash)
    if d is not None:
        with closing(db.cursor()) as cursor:
            insert_unihash(
                cursor,
                {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
                Resolve.IGNORE,
            )
            db.commit()
    return d


class ServerCursor(object):
    def __init__(self, db, cursor, upstream):
        self.db = db
        self.cursor = cursor
        self.upstream = upstream


class ServerClient(bb.asyncrpc.AsyncServerConnection):
    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
        super().__init__(reader, writer, 'OEHASHEQUIV', logger)
        self.db = db
        self.request_stats = request_stats
        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
        self.backfill_queue = backfill_queue
        self.upstream = upstream

        self.handlers.update({
            'get': self.handle_get,
            'get-outhash': self.handle_get_outhash,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
        })

        if not read_only:
            self.handlers.update({
                'report': self.handle_report,
                'report-equiv': self.handle_equivreport,
                'reset-stats': self.handle_reset_stats,
                'backfill-wait': self.handle_backfill_wait,
            })

    def validate_proto_version(self):
        return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))

    async def process_requests(self):
        if self.upstream is not None:
            self.upstream_client = await create_async_client(self.upstream)
        else:
            self.upstream_client = None

        await super().process_requests()

        if self.upstream_client is not None:
            await self.upstream_client.close()

    async def dispatch_message(self, msg):
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s', k)
                if 'stream' in k:
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
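
    # Messages are single-key dicts whose key selects the handler, e.g.
    # (illustrative values):
    #
    #   {'get': {'method': 'fn:do_compile', 'taskhash': 'abc123...', 'all': False}}
    #
    # dispatch_message() above routes on the key and passes the inner dict to
    # the handler.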

    async def handle_get(self, request):
        method = request['method']
        taskhash = request['taskhash']
        fetch_all = request.get('all', False)

        with closing(self.db.cursor()) as cursor:
            d = await self.get_unihash(cursor, method, taskhash, fetch_all)

        self.write_message(d)

    async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
        d = None

        if fetch_all:
            cursor.execute(
                '''
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                ''',
                {
                    'method': method,
                    'taskhash': taskhash,
                }
            )
            row = cursor.fetchone()

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash, True)
                self.update_unified(cursor, d)
                self.db.commit()
        else:
            row = self.query_equivalent(cursor, method, taskhash)

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash)
                d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
                insert_unihash(cursor, d, Resolve.IGNORE)
                self.db.commit()

        return d

    async def handle_get_outhash(self, request):
        method = request['method']
        outhash = request['outhash']
        taskhash = request['taskhash']

        with closing(self.db.cursor()) as cursor:
            d = await self.get_outhash(cursor, method, outhash, taskhash)

        self.write_message(d)

    async def get_outhash(self, cursor, method, outhash, taskhash):
        d = None
        cursor.execute(
            '''
            SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
            INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
            WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
            ORDER BY outhashes_v2.created ASC
            LIMIT 1
            ''',
            {
                'method': method,
                'outhash': outhash,
            }
        )
        row = cursor.fetchone()

        if row is not None:
            d = {k: row[k] for k in row.keys()}
        elif self.upstream_client is not None:
            d = await self.upstream_client.get_outhash(method, outhash, taskhash)
            self.update_unified(cursor, d)
            self.db.commit()

        return d

    def update_unified(self, cursor, data):
        if data is None:
            return

        insert_unihash(
            cursor,
            {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )
        insert_outhash(
            cursor,
            {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )
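
    # The get-stream protocol is line oriented (illustrative exchange): after
    # the initial 'ok' reply, each '<method> <taskhash>' line from the client
    # is answered with the unihash, or an empty line if none is known, until
    # a literal 'END' line terminates the stream:
    #
    #   C: fn:do_compile deadbeef...
    #   S: cafef00d...
    #   C: END
    #   S: ok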

    async def handle_get_stream(self, request):
        self.write_message('ok')

        while True:
            upstream = None

            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled
                # manually instead of using 'with', and also why logging
                # statements are commented out).
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s', method, taskhash)
                cursor = self.db.cursor()
                try:
                    row = self.query_equivalent(cursor, method, taskhash)
                finally:
                    cursor.close()

                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
                elif self.upstream_client is not None:
                    upstream = await self.upstream_client.get_unihash(method, taskhash)
                    if upstream:
                        msg = ("%s\n" % upstream).encode("utf-8")
                    else:
                        msg = "\n".encode("utf-8")
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

            # Post to the backfill queue after writing the result to minimize
            # the turnaround time on a request
            if upstream is not None:
                await self.backfill_queue.put((method, taskhash))

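    # A 'report' payload carries the required method/outhash/taskhash/unihash
    # keys plus optional metadata, e.g. (illustrative values):
    #
    #   {'method': 'fn:do_compile', 'outhash': '1a2b...', 'taskhash': 'abc1...',
    #    'unihash': 'abc1...', 'owner': 'user@host', 'PN': 'example',
    #    'PV': '1.0', 'PR': 'r0', 'task': 'do_compile'}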
    async def handle_report(self, data):
        with closing(self.db.cursor()) as cursor:
            outhash_data = {
                'method': data['method'],
                'outhash': data['outhash'],
                'taskhash': data['taskhash'],
                'created': datetime.now()
            }

            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    outhash_data[k] = data[k]

            # Insert the new entry, unless it already exists
            (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)

            if inserted:
                # If this row is new, check if it is equivalent to another
                # output hash
                cursor.execute(
                    '''
                    SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
                    INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                    -- Select any matching output hash except the one we just inserted
                    WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
                    -- Pick the oldest hash
                    ORDER BY outhashes_v2.created ASC
                    LIMIT 1
                    ''',
                    {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                    }
                )
                row = cursor.fetchone()

                if row is not None:
                    # A matching output hash was found. Set our taskhash to the
                    # same unihash since they are equivalent
                    unihash = row['unihash']
                    resolve = Resolve.IGNORE
                else:
                    # No matching output hash was found. This is probably the
                    # first outhash to be added.
                    unihash = data['unihash']
                    resolve = Resolve.IGNORE

                    # Query upstream to see if it has a unihash we can use
                    if self.upstream_client is not None:
                        upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
                        if upstream_data is not None:
                            unihash = upstream_data['unihash']

                insert_unihash(
                    cursor,
                    {
                        'method': data['method'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                    },
                    resolve
                )

            unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
            if unihash_data is not None:
                unihash = unihash_data['unihash']
            else:
                unihash = data['unihash']

            self.db.commit()

            d = {
                'taskhash': data['taskhash'],
                'method': data['method'],
                'unihash': unihash,
            }

        self.write_message(d)

    async def handle_equivreport(self, data):
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
            }
            insert_unihash(cursor, insert_data, Resolve.IGNORE)
            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(cursor, data['method'], data['taskhash'])

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    async def handle_backfill_wait(self, request):
        d = {
            'tasks': self.backfill_queue.qsize(),
        }
        await self.backfill_queue.join()
        self.write_message(d)

    def query_equivalent(self, cursor, method, taskhash):
        # This is part of the inner loop and must be as fast as possible
        cursor.execute(
            'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
            {
                'method': method,
                'taskhash': taskhash,
            }
        )
        return cursor.fetchone()


class Server(bb.asyncrpc.AsyncServer):
    def __init__(self, db, upstream=None, read_only=False):
        if upstream and read_only:
            raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")

        super().__init__(logger)

        self.request_stats = Stats()
        self.db = db
        self.upstream = upstream
        self.read_only = read_only

    def accept_client(self, reader, writer):
        return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)

    @contextmanager
    def _backfill_worker(self):
        async def backfill_worker_task():
            client = await create_async_client(self.upstream)
            try:
                while True:
                    item = await self.backfill_queue.get()
                    if item is None:
                        self.backfill_queue.task_done()
                        break
                    method, taskhash = item
                    await copy_unihash_from_upstream(client, self.db, method, taskhash)
                    self.backfill_queue.task_done()
            finally:
                await client.close()

        async def join_worker(worker):
            await self.backfill_queue.put(None)
            await worker

        if self.upstream is not None:
            worker = asyncio.ensure_future(backfill_worker_task())
            try:
                yield
            finally:
                self.loop.run_until_complete(join_worker(worker))
        else:
            yield

    def run_loop_forever(self):
        self.backfill_queue = asyncio.Queue()

        with self._backfill_worker():
            super().run_loop_forever()
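
# Illustrative construction sketch (the real entry point lives elsewhere in
# the hashserv package): the handlers above access result columns by name
# (row['unihash']), so the sqlite3 connection is assumed to use sqlite3.Row
# as its row_factory and to have the unihashes_v2/outhashes_v2 tables already
# created:
#
#   import sqlite3
#   db = sqlite3.connect('hashserv.db')
#   db.row_factory = sqlite3.Row
#   server = Server(db, upstream=None, read_only=False)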