1# Copyright (C) 2018-2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6import asyncio
7from contextlib import closing
8import itertools
9import json
10from collections import namedtuple
11from urllib.parse import urlparse
12from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS
13
# Lightweight immutable record describing a user: their name and the
# collection of permissions granted to them.
User = namedtuple("User", "username permissions")
15
16
def create_server(
    addr,
    dbname,
    *,
    sync=True,
    upstream=None,
    read_only=False,
    db_username=None,
    db_password=None,
    anon_perms=None,
    admin_username=None,
    admin_password=None,
    reuseport=False,
):
    """Build a hash equivalence server backed by *dbname* and start it on *addr*.

    The database backend is chosen from *dbname*: a name containing "://" is
    handed to the SQLAlchemy engine together with *db_username* and
    *db_password*; any other name goes to the SQLite engine along with *sync*.

    *addr* is decoded with ``parse_address``. UNIX-socket addresses call
    ``start_unix_server``; websocket addresses are URL-parsed and call
    ``start_websocket_server`` with the URL's hostname and port; anything
    else calls ``start_tcp_server``. *reuseport* is forwarded only to the
    websocket and TCP variants.

    When *anon_perms* is None, ``server.DEFAULT_ANON_PERMS`` is used instead.
    Returns the ``server.Server`` instance.
    """
    from . import server

    # Choose the storage backend. Backend modules are imported lazily so that
    # only the one actually needed gets loaded.
    if "://" in dbname:
        from .sqlalchemy import DatabaseEngine

        engine = DatabaseEngine(dbname, db_username, db_password)
    else:
        from .sqlite import DatabaseEngine

        engine = DatabaseEngine(dbname, sync)

    effective_perms = (
        server.DEFAULT_ANON_PERMS if anon_perms is None else anon_perms
    )

    srv = server.Server(
        engine,
        upstream=upstream,
        read_only=read_only,
        anon_perms=effective_perms,
        admin_username=admin_username,
        admin_password=admin_password,
    )

    addr_type, addr_args = parse_address(addr)
    if addr_type == ADDR_TYPE_UNIX:
        srv.start_unix_server(*addr_args)
    elif addr_type == ADDR_TYPE_WS:
        # The websocket starter takes host and port separately, so split the URL.
        parsed = urlparse(addr_args[0])
        srv.start_websocket_server(
            parsed.hostname, parsed.port, reuseport=reuseport
        )
    else:
        srv.start_tcp_server(*addr_args, reuseport=reuseport)

    return srv
70
71
def create_client(addr, username=None, password=None):
    """Create a ``client.Client`` and connect it to the server at *addr*.

    *addr* is decoded with ``parse_address``. UNIX-socket addresses use
    ``connect_unix``, websocket addresses use ``connect_websocket``, and
    anything else uses ``connect_tcp``.

    :param addr: Server address string understood by ``parse_address``.
    :param username: Optional username handed to the client constructor.
    :param password: Optional password handed to the client constructor.
    :returns: The connected client.
    :raises: Whatever ``parse_address`` or the connect call raises. The client
        is closed before the exception propagates.
    """
    from . import client

    c = client.Client(username, password)

    try:
        (typ, a) = parse_address(addr)
        if typ == ADDR_TYPE_UNIX:
            c.connect_unix(*a)
        elif typ == ADDR_TYPE_WS:
            c.connect_websocket(*a)
        else:
            c.connect_tcp(*a)
    except BaseException:
        # Catch BaseException, not just Exception, so that an interruption
        # such as KeyboardInterrupt during a blocking connect still releases
        # the client. A bare ``raise`` re-raises with the original traceback.
        c.close()
        raise

    return c
89
90
async def create_async_client(addr, username=None, password=None):
    """Create a ``client.AsyncClient`` and connect it to the server at *addr*.

    The async counterpart of ``create_client``. *addr* is decoded with
    ``parse_address``. UNIX-socket addresses await ``connect_unix``, websocket
    addresses await ``connect_websocket``, and anything else awaits
    ``connect_tcp``.

    :param addr: Server address string understood by ``parse_address``.
    :param username: Optional username handed to the client constructor.
    :param password: Optional password handed to the client constructor.
    :returns: The connected async client.
    :raises: Whatever ``parse_address`` or the connect coroutine raises,
        including ``asyncio.CancelledError``. The client is closed before
        the exception propagates.
    """
    from . import client

    c = client.AsyncClient(username, password)

    try:
        (typ, a) = parse_address(addr)
        if typ == ADDR_TYPE_UNIX:
            await c.connect_unix(*a)
        elif typ == ADDR_TYPE_WS:
            await c.connect_websocket(*a)
        else:
            await c.connect_tcp(*a)
    except BaseException:
        # asyncio.CancelledError derives from BaseException (Python 3.8+), so
        # ``except Exception`` would leak the half-open client when the connect
        # is cancelled, e.g. by an asyncio.wait_for timeout. A bare ``raise``
        # preserves the original exception and traceback.
        await c.close()
        raise

    return c
109