1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6import json
7import logging
8import socket
9import os
10
11
# Module-level logger for this client; Client._send_wrapper() uses it to warn
# about each failed attempt to talk to the server.
logger = logging.getLogger('hashserv.client')
13
14
class HashConnectionError(Exception):
    """Error raised by Client when talking to the server fails.

    Client raises it when the connection is closed, when a reply is
    truncated or unexpected, and (wrapping the underlying error text) when
    all retries of a request have been used up.
    """
17
18
class Client(object):
    """Blocking client for the 'OEHASHEQUIV 1.0' line-based protocol.

    Select a transport with connect_tcp() or connect_unix(); the socket is
    then opened lazily on first use and transparently re-opened (with the
    current mode restored) when a request fails.

    The connection is in one of two modes:

    * MODE_NORMAL: each request is one JSON document per line and the reply
      is one JSON document per line (see send_message()).
    * MODE_GET_STREAM: each request and reply is a plain text line (see
      send_stream()); the line 'END' switches back to normal mode.
    """
    MODE_NORMAL = 0
    MODE_GET_STREAM = 1

    def __init__(self):
        self._socket = None
        self.reader = None
        self.writer = None
        self.mode = self.MODE_NORMAL

    def connect_tcp(self, address, port):
        """Use TCP to (address, port) for all future connections.

        Nothing is connected here; the socket is created by connect().
        """
        def connect_sock():
            s = socket.create_connection((address, port))

            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            return s

        self._connect_sock = connect_sock

    def connect_unix(self, path):
        """Use the UNIX domain socket at path for all future connections.

        Nothing is connected here; the socket is created by connect().
        """
        def connect_sock():
            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            # AF_UNIX has path length issues so chdir here to workaround.
            # The working directory is always restored afterwards.
            cwd = os.getcwd()
            try:
                os.chdir(os.path.dirname(path))
                s.connect(os.path.basename(path))
            finally:
                os.chdir(cwd)
            return s

        self._connect_sock = connect_sock

    def connect(self):
        """Open the connection if it is not already open and return the socket.

        A new connection sends the protocol handshake and is then switched
        back into whatever mode the client was in before. If any of that
        fails, the socket is closed again and the previous mode is kept, so
        that a later retry starts from a clean state instead of reusing a
        half-initialised connection in the wrong mode.
        """
        if self._socket is None:
            self._socket = self._connect_sock()

            cur_mode = self.mode
            try:
                self.reader = self._socket.makefile('r', encoding='utf-8')
                self.writer = self._socket.makefile('w', encoding='utf-8')

                self.writer.write('OEHASHEQUIV 1.0\n\n')
                self.writer.flush()

                # Restore mode if the socket is being re-created. A fresh
                # connection always starts out in normal mode.
                self.mode = self.MODE_NORMAL
                self._set_mode(cur_mode)
            except Exception:
                self.close()
                self.mode = cur_mode
                raise

        return self._socket

    def close(self):
        """Close the connection, if any. Safe to call repeatedly."""
        if self._socket is not None:
            # The file objects returned by makefile() keep the underlying
            # descriptor alive, so close them explicitly instead of waiting
            # for garbage collection.
            for stream in (self.writer, self.reader):
                if stream is None:
                    continue
                try:
                    stream.close()
                except OSError:
                    # Best effort: flushing to an already broken connection
                    # may fail, and there is nothing useful to do about it.
                    pass
            self._socket.close()
            self._socket = None
            self.reader = None
            self.writer = None

    def _send_wrapper(self, proc):
        """Run proc() on an open connection, reconnecting on failure.

        proc is tried up to four times in total. After every failure the
        connection is closed so the next attempt (or the next request) starts
        on a fresh socket.

        Raises:
            HashConnectionError: when the last attempt fails as well. Errors
                of other caught types are converted, keeping their text.
        """
        count = 0
        while True:
            try:
                self.connect()
                return proc()
            except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
                logger.warning('Error talking to server: %s', e)
                # The stream may be out of sync after an error; never reuse it.
                self.close()
                if count >= 3:
                    if not isinstance(e, HashConnectionError):
                        raise HashConnectionError(str(e)) from e
                    raise
                count += 1

    def send_message(self, msg):
        """Send msg as one line of JSON and return the decoded JSON reply.

        Raises:
            HashConnectionError: if no complete reply could be obtained even
                after retrying.
        """
        def proc():
            self.writer.write('%s\n' % json.dumps(msg))
            self.writer.flush()

            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')

            # A reply without the trailing newline was cut short.
            if not line.endswith('\n'):
                raise HashConnectionError('Bad message %r' % line)

            return json.loads(line)

        return self._send_wrapper(proc)

    def send_stream(self, msg):
        """Send msg as one text line and return the reply line.

        The reply is returned with trailing whitespace removed.

        Raises:
            HashConnectionError: if no reply could be obtained even after
                retrying.
        """
        def proc():
            self.writer.write("%s\n" % msg)
            self.writer.flush()
            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')
            return line.rstrip()

        return self._send_wrapper(proc)

    def _set_mode(self, new_mode):
        """Switch the connection to new_mode; a no-op if already in it.

        Raises:
            HashConnectionError: if the server does not answer 'ok'.
            Exception: if the transition is not one of normal <-> stream.
        """
        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
            r = self.send_stream('END')
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
            r = self.send_message({'get-stream': None})
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode != self.mode:
            raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))

        self.mode = new_mode

    def get_unihash(self, method, taskhash):
        """Query the server in stream mode with '<method> <taskhash>'.

        Returns the reply line, or None when the server answers with an
        empty line.
        """
        self._set_mode(self.MODE_GET_STREAM)
        r = self.send_stream('%s %s' % (method, taskhash))
        if not r:
            return None
        return r

    def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a 'report' message and return the server's decoded reply.

        extra, if given, is a mapping of additional fields; it is copied and
        never modified. The named arguments override same-named keys in it.
        """
        self._set_mode(self.MODE_NORMAL)
        m = {} if extra is None else extra.copy()
        m['taskhash'] = taskhash
        m['method'] = method
        m['outhash'] = outhash
        m['unihash'] = unihash
        return self.send_message({'report': m})

    def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """Send a 'report-equiv' message and return the server's decoded reply.

        extra, if given, is a mapping of additional fields; it is copied and
        never modified. The named arguments override same-named keys in it.
        """
        self._set_mode(self.MODE_NORMAL)
        m = {} if extra is None else extra.copy()
        m['taskhash'] = taskhash
        m['method'] = method
        m['unihash'] = unihash
        return self.send_message({'report-equiv': m})

    def get_stats(self):
        """Send a 'get-stats' message and return the server's decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'get-stats': None})

    def reset_stats(self):
        """Send a 'reset-stats' message and return the server's decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'reset-stats': None})
165