xref: /openbmc/openbmc/poky/bitbake/lib/prserv/serv.py (revision f1e5d6968976c2341c6d554bfcc8895f1b33c26b)
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import os, sys, logging
import signal, time
import socket
import io
import sqlite3
import prserv
import prserv.db
import errno
from . import create_async_client, revision_smaller, increase_revision
import bb.asyncrpc
import bb.msg

logger = logging.getLogger("BitBake.PRserv")

PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None

class PRServerClient(bb.asyncrpc.AsyncServerConnection):
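    """Per-connection handler for the PR server.

    Dispatches incoming requests to the PR database and, when an upstream
    server is configured, reconciles local PR values with it.
    """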
    def __init__(self, socket, server):
        super().__init__(socket, "PRSERVICE", server.logger)
        self.server = server

        self.handlers.update({
            "get-pr": self.handle_get_pr,
            "test-pr": self.handle_test_pr,
            "test-package": self.handle_test_package,
            "max-package-pr": self.handle_max_package_pr,
            "import-one": self.handle_import_one,
            "export": self.handle_export,
            "is-readonly": self.handle_is_readonly,
        })

    def validate_proto_version(self):
        return (self.proto_version == (1, 0))

    async def dispatch_message(self, msg):
        return await super().dispatch_message(msg)

    async def handle_test_pr(self, request):
        '''Finds the PR value corresponding to the request. If not found, returns None and doesn't insert a new value'''
        version = request["version"]
        pkgarch = request["pkgarch"]
        checksum = request["checksum"]
        history = request["history"]

        value = self.server.table.find_value(version, pkgarch, checksum, history)
        return {"value": value}

    async def handle_test_package(self, request):
        '''Tells whether there are entries for (version, pkgarch) in the db. Returns True or False'''
        version = request["version"]
        pkgarch = request["pkgarch"]

        value = self.server.table.test_package(version, pkgarch)
        return {"value": value}

    async def handle_max_package_pr(self, request):
        '''Finds the greatest PR value for (version, pkgarch) in the db. Returns None if no entry was found'''
        version = request["version"]
        pkgarch = request["pkgarch"]

        value = self.server.table.find_package_max_value(version, pkgarch)
        return {"value": value}

    async def handle_get_pr(self, request):
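        '''Returns the PR value for the request, creating one if needed.
        When an upstream server is configured, the value is reconciled with
        the upstream database so local revisions stay consistent with it'''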
        version = request["version"]
        pkgarch = request["pkgarch"]
        checksum = request["checksum"]
        history = request["history"]

        if self.upstream_client is None:
            value = self.server.table.get_value(version, pkgarch, checksum, history)
            return {"value": value}

        # We have an upstream server.
        # Check whether the local server already knows the requested configuration.
        # If the configuration is a new one, the generated value we will add will
        # depend on what's on the upstream server. That's why we're calling find_value()
        # instead of get_value() directly.

        value = self.server.table.find_value(version, pkgarch, checksum, history)
        upstream_max = await self.upstream_client.max_package_pr(version, pkgarch)

        if value is not None:

            # The configuration is already known locally.

            if history:
                value = self.server.table.get_value(version, pkgarch, checksum, history)
            else:
                existing_value = value
                # In "no history" mode, make sure the value doesn't decrease:
                # it must be at least as large as the maximum local value
                # and the maximum upstream value

                local_max = self.server.table.find_package_max_value(version, pkgarch)
                if revision_smaller(value, local_max):
                    value = increase_revision(local_max)

                if revision_smaller(value, upstream_max):
                    # Ask upstream whether it knows the checksum
                    upstream_value = await self.upstream_client.test_pr(version, pkgarch, checksum)
                    if upstream_value is None:
                        # Upstream doesn't have our checksum, let's create a new one
                        value = upstream_max + ".0"
                    else:
                        # Fine to take the same value as upstream
                        value = upstream_max

                if value != existing_value and not self.server.read_only:
                    self.server.table.store_value(version, pkgarch, checksum, value)

            return {"value": value}

        # The configuration is a new one for the local server
        # Let's ask the upstream server whether it knows it

        known_upstream = await self.upstream_client.test_package(version, pkgarch)

        if not known_upstream:

            # The package is not known upstream, must be a local-only package
            # Let's compute the PR number using the local-only method

            value = self.server.table.get_value(version, pkgarch, checksum, history)
            return {"value": value}

        # The package is known upstream, let's ask the upstream server
        # whether it knows our new output hash

        value = await self.upstream_client.test_pr(version, pkgarch, checksum)

        if value is not None:

            # Upstream knows this output hash, let's store it and use it too.

            if not self.server.read_only:
                self.server.table.store_value(version, pkgarch, checksum, value)
            # If the local server is read-only, it won't be able to store the new
            # value in the database and will have to keep asking the upstream server
            return {"value": value}

        # The output hash doesn't exist upstream, get the most recent number from upstream (x)
        # Then, we want to have a new PR value for the local server: x.y

        upstream_max = await self.upstream_client.max_package_pr(version, pkgarch)
        # Here we know that the package is known upstream, so upstream_max can't be None
        subvalue = self.server.table.find_new_subvalue(version, pkgarch, upstream_max)

        if not self.server.read_only:
            self.server.table.store_value(version, pkgarch, checksum, subvalue)

        return {"value": subvalue}

    async def process_requests(self):
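        '''Connects to the upstream PR server (if one is configured) for the
        lifetime of this client connection, then runs the base class request loop'''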
        if self.server.upstream is not None:
            self.upstream_client = await create_async_client(self.server.upstream)
        else:
            self.upstream_client = None

        try:
            await super().process_requests()
        finally:
            if self.upstream_client is not None:
                await self.upstream_client.close()

    async def handle_import_one(self, request):
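        '''Imports a single (version, pkgarch, checksum, value) entry into the db.
        Returns the stored value, or None if the server is read-only or the import failed'''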
        response = None
        if not self.server.read_only:
            version = request["version"]
            pkgarch = request["pkgarch"]
            checksum = request["checksum"]
            value = request["value"]

            value = self.server.table.importone(version, pkgarch, checksum, value)
            if value is not None:
                response = {"value": value}

        return response

    async def handle_export(self, request):
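        '''Exports database contents for the given filters.
        Returns (metainfo, datainfo), or None for both on a database error'''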
        version = request["version"]
        pkgarch = request["pkgarch"]
        checksum = request["checksum"]
        colinfo = request["colinfo"]
        history = request["history"]

        try:
            (metainfo, datainfo) = self.server.table.export(version, pkgarch, checksum, colinfo, history)
        except sqlite3.Error as exc:
            self.logger.error(str(exc))
            metainfo = datainfo = None

        return {"metainfo": metainfo, "datainfo": datainfo}

    async def handle_is_readonly(self, request):
        return {"readonly": self.server.read_only}

class PRServer(bb.asyncrpc.AsyncServer):
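    """Asynchronous PR server backed by an sqlite database, optionally
    reconciling PR values with an upstream PR server."""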
    def __init__(self, dbfile, read_only=False, upstream=None):
        super().__init__(logger)
        self.dbfile = dbfile
        self.table = None
        self.read_only = read_only
        self.upstream = upstream

    def accept_client(self, socket):
        return PRServerClient(socket, self)

    def start(self):
        tasks = super().start()
        self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
        self.table = self.db["PRMAIN"]

        self.logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
                     (self.dbfile, self.address, str(os.getpid())))

        if self.upstream is not None:
            self.logger.info("And upstream PRServer: %s " % (self.upstream))

        return tasks

    async def stop(self):
        self.db.disconnect()
        await super().stop()

class PRServSingleton(object):
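    """Wraps a PRServer that is run in its own process for the auto-started
    local case, recording the port the server actually bound to."""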
    def __init__(self, dbfile, logfile, host, port, upstream):
        self.dbfile = dbfile
        self.logfile = logfile
        self.host = host
        self.port = port
        self.upstream = upstream

    def start(self):
        self.prserv = PRServer(self.dbfile, upstream=self.upstream)
        self.prserv.start_tcp_server(socket.gethostbyname(self.host), self.port)
        self.process = self.prserv.serve_as_process(log_level=logging.WARNING)

        if not self.prserv.address:
            raise PRServiceConfigError
        if not self.port:
            self.port = int(self.prserv.address.rsplit(":", 1)[1])

def run_as_daemon(func, pidfile, logfile):
    """
    Daemonize func(); see Advanced Programming in the UNIX Environment, Sec 13.3
    """
    try:
        pid = os.fork()
        if pid > 0:
            os.waitpid(pid, 0)
            # parent returns instead of exiting, to hand control back to the caller
            return pid
    except OSError as e:
        raise Exception("%s [%d]" % (e.strerror, e.errno))

    os.setsid()
    # Fork again to make sure the daemon is not a session leader,
    # which prevents it from acquiring a controlling terminal
    try:
        pid = os.fork()
        if pid > 0: # parent
            os._exit(0)
    except OSError as e:
        raise Exception("%s [%d]" % (e.strerror, e.errno))

    os.chdir("/")

    sys.stdout.flush()
    sys.stderr.flush()

    # We could be called from a python thread with io.StringIO as
    # stdout/stderr or it could be 'real' unix fd forking where we need
    # to physically close the fds to prevent the program launching us from
    # potentially hanging on a pipe. Handle both cases.
    si = open("/dev/null", "r")
    try:
        os.dup2(si.fileno(), sys.stdin.fileno())
    except (AttributeError, io.UnsupportedOperation):
        sys.stdin = si
    so = open(logfile, "a+")
    try:
        os.dup2(so.fileno(), sys.stdout.fileno())
    except (AttributeError, io.UnsupportedOperation):
        sys.stdout = so
    try:
        os.dup2(so.fileno(), sys.stderr.fileno())
    except (AttributeError, io.UnsupportedOperation):
        sys.stderr = so

    # Clear out all log handlers prior to the fork() to avoid calling
    # event handlers not part of the PRserver
    for logger_iter in logging.Logger.manager.loggerDict.keys():
        logging.getLogger(logger_iter).handlers = []

    # Ensure logging makes it to the logfile
    streamhandler = logging.StreamHandler()
    streamhandler.setLevel(logging.DEBUG)
    formatter = bb.msg.BBLogFormatter("%(levelname)s: %(message)s")
    streamhandler.setFormatter(formatter)
    logger.addHandler(streamhandler)

    # write pidfile
    pid = str(os.getpid())
    with open(pidfile, "w") as pf:
        pf.write("%s\n" % pid)

    func()
    os.remove(pidfile)
    os._exit(0)

def start_daemon(dbfile, host, port, logfile, read_only=False, upstream=None):
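    """Start a PRServer daemon listening on host:port, recording its pid in a
    pidfile under /tmp. Returns 0 on success, 1 if a pidfile already exists."""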
    ip = socket.gethostbyname(host)
    pidfile = PIDPREFIX % (ip, port)
    try:
        with open(pidfile) as pf:
            pid = int(pf.readline().strip())
    except IOError:
        pid = None

    if pid:
        sys.stderr.write("pidfile %s already exists. Daemon already running?\n"
                            % pidfile)
        return 1

    dbfile = os.path.abspath(dbfile)
    def daemon_main():
        server = PRServer(dbfile, read_only=read_only, upstream=upstream)
        server.start_tcp_server(ip, port)
        server.serve_forever()

    run_as_daemon(daemon_main, pidfile, os.path.abspath(logfile))
    return 0

def stop_daemon(host, port):
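    """Stop the PRServer daemon for host:port using its pidfile: send SIGTERM
    and remove the pidfile. Returns 0 on success, 1 if no pidfile was found."""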
    import glob
    ip = socket.gethostbyname(host)
    pidfile = PIDPREFIX % (ip, port)
    try:
        with open(pidfile) as pf:
            pid = int(pf.readline().strip())
    except IOError:
        pid = None

    if not pid:
        # When the server is started with port=0 (i.e. localhost:0), it actually binds
        # to another port, so at least advise the user which ports servers are listening on
        ports = []
        portstr = ""
        for pf in glob.glob(PIDPREFIX % (ip, "*")):
            bn = os.path.basename(pf)
            root, _ = os.path.splitext(bn)
            ports.append(root.split("_")[-1])
        if ports:
            portstr = "Wrong port? Other ports listening at %s: %s" % (host, " ".join(ports))

        sys.stderr.write("pidfile %s does not exist. Daemon not running? %s\n"
                         % (pidfile, portstr))
        return 1

    try:
        if is_running(pid):
            print("Sending SIGTERM to pr-server.")
            os.kill(pid, signal.SIGTERM)
            time.sleep(0.1)

        try:
            os.remove(pidfile)
        except FileNotFoundError:
            # The PID file might have been removed by the exiting process
            pass

    except OSError as e:
        # Ignore the case where the process exited before it could be signalled
        if e.errno != errno.ESRCH:
            raise

    return 0

def is_running(pid):
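    """Return True if a process with the given pid exists (checked with signal 0)."""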
    try:
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            return False
    return True

def is_local_special(host, port):
    return host in ("localhost", "127.0.0.1") and not port

class PRServiceConfigError(Exception):
    pass

def auto_start(d):
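    """Start (or reuse) a PR service according to PRSERV_HOST and return its
    "host:port" string, or None if no PR service is configured.
    Raises PRServiceConfigError on misconfiguration or if the service is unreachable."""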
    global singleton

    host_params = list(filter(None, (d.getVar("PRSERV_HOST") or "").split(":")))
    if not host_params:
        # Shutdown any existing PR Server
        auto_shutdown()
        return None

    if len(host_params) != 2:
        # Shutdown any existing PR Server
        auto_shutdown()
        logger.critical("\n".join(["PRSERV_HOST: incorrect format",
                'Usage: PRSERV_HOST = "<hostname>:<port>"']))
        raise PRServiceConfigError

    host = host_params[0].strip().lower()
    port = int(host_params[1])

    upstream = d.getVar("PRSERV_UPSTREAM") or None

    if is_local_special(host, port):
        import bb.utils
        cachedir = (d.getVar("PERSISTENT_DIR") or d.getVar("CACHE"))
        if not cachedir:
            logger.critical("Please set the 'PERSISTENT_DIR' or 'CACHE' variable")
            raise PRServiceConfigError
        dbfile = os.path.join(cachedir, "prserv.sqlite3")
        logfile = os.path.join(cachedir, "prserv.log")
        if singleton:
            if singleton.dbfile != dbfile:
                # Shutdown any existing PR Server as it doesn't match the config
                auto_shutdown()
        if not singleton:
            bb.utils.mkdirhier(cachedir)
            singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), host, port, upstream)
            singleton.start()
    if singleton:
        host = singleton.host
        port = singleton.port

    try:
        ping(host, port)
        return str(host) + ":" + str(port)

    except Exception:
        logger.critical("PRservice %s:%d not available" % (host, port))
        raise PRServiceConfigError

def auto_shutdown():
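    """Terminate the auto-started PR server process, if any."""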
    global singleton
    if singleton and singleton.process:
        singleton.process.terminate()
        singleton.process.join()
        singleton = None

def ping(host, port):
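    """Check that a PR server is reachable at host:port."""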
    from . import client

    with client.PRClient() as conn:
        conn.connect_tcp(host, port)
        return conn.ping()

def connect(host, port):
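    """Return a PRClient connected to host:port; "localhost" with no port
    refers to the auto-started singleton server."""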
    from . import client

    global singleton

    if host.strip().lower() == "localhost" and not port:
        host = "localhost"
        port = singleton.port

    conn = client.PRClient()
    conn.connect_tcp(host, port)
    return conn