# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import json
import logging
import socket
import os


logger = logging.getLogger('hashserv.client')


class HashConnectionError(Exception):
    """Raised when talking to the hash equivalence server fails."""


class Client(object):
    """Line-oriented client for the hash equivalence server.

    The connection operates in one of two modes:

    * ``MODE_NORMAL``: each request is one JSON document on a line and the
      reply is one JSON document on a line.
    * ``MODE_GET_STREAM``: each request is a plain text line and the reply is
      a plain text line. Sending ``END`` switches back to normal mode.

    Call :meth:`connect_tcp` or :meth:`connect_unix` first to select how the
    socket is created; the socket itself is opened lazily by :meth:`connect`
    and transparently re-opened by the request methods after a failure.
    """

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1

    def __init__(self):
        self._socket = None
        self.reader = None
        self.writer = None
        self.mode = self.MODE_NORMAL

    def connect_tcp(self, address, port):
        """Configure the client to connect to ``address``:``port`` over TCP.

        No connection is made here; only the socket factory is stored.
        """
        def connect_sock():
            s = socket.create_connection((address, port))

            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            return s

        self._connect_sock = connect_sock

    def connect_unix(self, path):
        """Configure the client to connect to the UNIX domain socket ``path``.

        No connection is made here; only the socket factory is stored.
        """
        def connect_sock():
            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            try:
                # dirname() is '' for a bare filename, which chdir() rejects;
                # treat that as "the current directory".
                os.chdir(os.path.dirname(path) or '.')
                s.connect(os.path.basename(path))
            except Exception:
                # Do not leak the descriptor when the connection fails
                s.close()
                raise
            finally:
                os.chdir(cwd)
            return s

        self._connect_sock = connect_sock

    def connect(self):
        """Open the socket if it is not already open and return it.

        On a fresh connection the protocol handshake is sent and the mode
        that was active before the (re)connect is re-established.
        """
        if self._socket is None:
            self._socket = self._connect_sock()

            self.reader = self._socket.makefile('r', encoding='utf-8')
            self.writer = self._socket.makefile('w', encoding='utf-8')

            self.writer.write('OEHASHEQUIV 1.0\n\n')
            self.writer.flush()

            # Restore mode if the socket is being re-created
            cur_mode = self.mode
            self.mode = self.MODE_NORMAL
            self._set_mode(cur_mode)

        return self._socket

    def close(self):
        """Close the socket (if open) and drop the reader/writer objects."""
        if self._socket is not None:
            self._socket.close()
            self._socket = None
            self.reader = None
            self.writer = None

    def _send_wrapper(self, proc):
        """Run ``proc()`` on a live connection, reconnecting on failure.

        After a failed attempt the connection is closed and re-opened; up to
        three retries are made. Returns whatever ``proc`` returns.

        Raises:
            HashConnectionError: if the final attempt fails. Other caught
                error types are converted, with the original exception
                attached as the cause.
        """
        count = 0
        while True:
            try:
                self.connect()
                return proc()
            except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
                logger.warning('Error talking to server: %s', e)
                if count >= 3:
                    if not isinstance(e, HashConnectionError):
                        raise HashConnectionError(str(e)) from e
                    raise
                self.close()
                count += 1

    def send_message(self, msg):
        """Send ``msg`` as one JSON line and return the decoded JSON reply.

        Raises:
            HashConnectionError: if the server closes the connection, sends
                a truncated line, or the request keeps failing after retries.
        """
        def proc():
            self.writer.write('%s\n' % json.dumps(msg))
            self.writer.flush()

            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')

            # A line without the terminator means the reply was cut short
            if not line.endswith('\n'):
                raise HashConnectionError('Bad message %r' % line)

            return json.loads(line)

        return self._send_wrapper(proc)

    def send_stream(self, msg):
        """Send the text line ``msg`` and return the reply line, right-stripped.

        Raises:
            HashConnectionError: if the server closes the connection or the
                request keeps failing after retries.
        """
        def proc():
            self.writer.write("%s\n" % msg)
            self.writer.flush()
            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')
            return line.rstrip()

        return self._send_wrapper(proc)

    def _set_mode(self, new_mode):
        """Switch the connection to ``new_mode``, telling the server if needed.

        Raises:
            HashConnectionError: if the server does not answer ``ok``.
            Exception: if the requested transition is not defined.
        """
        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
            r = self.send_stream('END')
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
            r = self.send_message({'get-stream': None})
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode != self.mode:
            raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))

        self.mode = new_mode

    def get_unihash(self, method, taskhash):
        """Query the server for ``method``/``taskhash`` in stream mode.

        Returns the server's reply line, or ``None`` if the reply is empty.
        """
        self._set_mode(self.MODE_GET_STREAM)
        r = self.send_stream('%s %s' % (method, taskhash))
        if not r:
            return None
        return r

    def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a ``report`` message and return the server's decoded reply.

        ``extra`` is an optional dict of additional fields; it is copied, not
        modified. The named arguments override same-named keys in ``extra``.
        """
        self._set_mode(self.MODE_NORMAL)
        m = extra.copy() if extra is not None else {}
        m['taskhash'] = taskhash
        m['method'] = method
        m['outhash'] = outhash
        m['unihash'] = unihash
        return self.send_message({'report': m})

    def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """Send a ``report-equiv`` message and return the server's decoded reply.

        ``extra`` is an optional dict of additional fields; it is copied, not
        modified. The named arguments override same-named keys in ``extra``.
        """
        self._set_mode(self.MODE_NORMAL)
        m = extra.copy() if extra is not None else {}
        m['taskhash'] = taskhash
        m['method'] = method
        m['unihash'] = unihash
        return self.send_message({'report-equiv': m})

    def get_stats(self):
        """Send a ``get-stats`` message and return the server's decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'get-stats': None})

    def reset_stats(self):
        """Send a ``reset-stats`` message and return the server's decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'reset-stats': None})