1#!/bin/env python3
2# SPDX-License-Identifier: GPL-2.0
3# -*- coding: utf-8 -*-
4#
5# Copyright (c) 2017 Benjamin Tissoires <benjamin.tissoires@gmail.com>
6# Copyright (c) 2017 Red Hat, Inc.
7
8import libevdev
9import os
10import pytest
11import time
12
13import logging
14
15from hidtools.device.base_device import BaseDevice, EvdevMatch, SysfsFile
16from pathlib import Path
17from typing import Final
18
19logger = logging.getLogger("hidtools.test.base")
20
# HID application name -> EvdevMatch describing the matching evdev node
def _build_application_matches():
    """Build the table mapping each HID application name to its EvdevMatch.

    Every entry gets its own freshly created lists, and the entries are
    inserted in the same (alphabetical) order as they are listed below.

    NOTE(review): the meaning of the EvdevMatch keywords is defined in
    hidtools; presumably ``requires``/``excludes`` are event codes that
    must/must not be present and ``req_properties``/``excl_properties``
    are the same for input properties -- confirm against EvdevMatch.
    """
    ev_abs = libevdev.EV_ABS
    ev_key = libevdev.EV_KEY
    ev_rel = libevdev.EV_REL
    accelerometer = libevdev.INPUT_PROP_ACCELEROMETER
    direct = libevdev.INPUT_PROP_DIRECT
    pointer = libevdev.INPUT_PROP_POINTER

    matches = {}

    matches["Accelerometer"] = EvdevMatch(req_properties=[accelerometer])

    # in systemd, this is a lot more complex, but that will do
    matches["Game Pad"] = EvdevMatch(
        requires=[
            ev_abs.ABS_X,
            ev_abs.ABS_Y,
            ev_abs.ABS_RX,
            ev_abs.ABS_RY,
            ev_key.BTN_START,
        ],
        excl_properties=[accelerometer],
    )

    # in systemd, this is a lot more complex, but that will do
    matches["Joystick"] = EvdevMatch(
        requires=[ev_abs.ABS_RX, ev_abs.ABS_RY, ev_key.BTN_START],
        excl_properties=[accelerometer],
    )

    matches["Key"] = EvdevMatch(
        requires=[ev_key.KEY_A],
        excl_properties=[accelerometer, direct, pointer],
    )

    matches["Mouse"] = EvdevMatch(
        requires=[ev_rel.REL_X, ev_rel.REL_Y, ev_key.BTN_LEFT],
        excl_properties=[accelerometer],
    )

    matches["Pad"] = EvdevMatch(
        requires=[ev_key.BTN_0],
        excludes=[ev_key.BTN_TOOL_PEN, ev_key.BTN_TOUCH, ev_abs.ABS_DISTANCE],
        excl_properties=[accelerometer],
    )

    # "Pen" and "Stylus" use the very same criteria
    matches["Pen"] = EvdevMatch(
        requires=[ev_key.BTN_STYLUS, ev_abs.ABS_X, ev_abs.ABS_Y],
        excl_properties=[accelerometer],
    )
    matches["Stylus"] = EvdevMatch(
        requires=[ev_key.BTN_STYLUS, ev_abs.ABS_X, ev_abs.ABS_Y],
        excl_properties=[accelerometer],
    )

    # Touch Pad and Touch Screen differ by their button and by the
    # required property (POINTER vs DIRECT)
    matches["Touch Pad"] = EvdevMatch(
        requires=[ev_key.BTN_LEFT, ev_abs.ABS_X, ev_abs.ABS_Y],
        excludes=[ev_key.BTN_TOOL_PEN, ev_key.BTN_STYLUS],
        req_properties=[pointer],
        excl_properties=[accelerometer],
    )
    matches["Touch Screen"] = EvdevMatch(
        requires=[ev_key.BTN_TOUCH, ev_abs.ABS_X, ev_abs.ABS_Y],
        excludes=[ev_key.BTN_TOOL_PEN, ev_key.BTN_STYLUS],
        req_properties=[direct],
        excl_properties=[accelerometer],
    )

    return matches


application_matches: Final = _build_application_matches()  # pyright: ignore
133
134
class UHIDTestDevice(BaseDevice):
    """BaseDevice whose name always carries the "uhid test " prefix.

    The prefix is what the udev rules written by this module match on
    (``ATTRS{name}=="*uhid test *"``).
    """

    def __init__(self, name, application, rdesc_str=None, rdesc=None, input_info=None):
        """Create the device and normalize its name.

        When ``name`` is None a default built from the class name is used;
        a name that lacks the "uhid test " prefix gets it prepended.
        """
        super().__init__(name, application, rdesc_str, rdesc, input_info)
        self.application_matches = application_matches

        prefix = "uhid test "
        if name is None:
            final_name = f"uhid test {self.__class__.__name__}"
        elif name.startswith(prefix):
            final_name = name
        else:
            # the prefix is applied to the name stored by the base class
            final_name = prefix + self.name
        self.name = final_name
144
145
class BaseTestCase:
    """Container for the generic uhid test class.

    NOTE(review): TestUhid is presumably nested in here so that pytest does
    not collect it directly from this module -- confirm.
    """

    class TestUhid(object):
        """Common fixtures, helpers and the creation test for a uhid device.

        Subclasses must override create_device() and may list the kernel
        modules they depend on in ``kernel_modules``.
        """

        # Pre-built input events, one per event type/code.
        # NOTE(review): not used within this class; presumably templates
        # that subclasses compare received events against -- confirm.
        syn_event = libevdev.InputEvent(libevdev.EV_SYN.SYN_REPORT)  # type: ignore
        key_event = libevdev.InputEvent(libevdev.EV_KEY)  # type: ignore
        abs_event = libevdev.InputEvent(libevdev.EV_ABS)  # type: ignore
        rel_event = libevdev.InputEvent(libevdev.EV_REL)  # type: ignore
        msc_event = libevdev.InputEvent(libevdev.EV_MSC.MSC_SCAN)  # type: ignore

        # List of kernel modules to load before starting the test
        # if any module is not available (not compiled), the test will skip.
        # Each element is a tuple '(kernel driver name, kernel module)',
        # for example ("playstation", "hid-playstation")
        kernel_modules = []

        def assertInputEventsIn(self, expected_events, effective_events):
            """Assert that every expected event is part of effective_events.

            Each expected event consumes one matching entry, so an event
            expected twice must also have been received twice. The caller's
            list is left untouched.

            Returns the effective events that were not expected.
            """
            effective_events = effective_events.copy()
            for ev in expected_events:
                assert ev in effective_events
                effective_events.remove(ev)
            return effective_events

        def assertInputEvents(self, expected_events, effective_events):
            """Assert that effective_events holds the expected events and
            nothing else (order is not checked)."""
            remaining = self.assertInputEventsIn(expected_events, effective_events)
            assert remaining == []

        @classmethod
        def debug_reports(cls, reports, uhdev=None, events=None):
            """Print the given reports, and optionally the received events.

            Each report is printed as space separated hex bytes. When
            ``uhdev`` is given, the report as formatted by its parsed report
            descriptor is appended. When ``events`` is given, it is printed
            after the reports.
            """
            data = [" ".join([f"{v:02x}" for v in r]) for r in reports]

            if uhdev is not None:
                human_data = [
                    uhdev.parsed_rdesc.format_report(r, split_lines=True)
                    for r in reports
                ]
                try:
                    # indent the continuation lines by the position of the
                    # first '/' so they line up below the first line
                    human_data = [
                        f'\n\t       {" " * h.index("/")}'.join(h.split("\n"))
                        for h in human_data
                    ]
                except ValueError:
                    # '/' not found: not a numbered report
                    human_data = ["\n\t      ".join(h.split("\n")) for h in human_data]
                data = [f"{d}\n\t ====> {h}" for d, h in zip(data, human_data)]

            reports = data

            if len(reports) == 1:
                print("sending 1 report:")
            else:
                print(f"sending {len(reports)} reports:")
            for report in reports:
                print("\t", report)

            if events is not None:
                print("events received:", events)

        def create_device(self):
            """Return the uhid device under test; subclasses must override.

            Raises NotImplementedError (an Exception subclass) when called
            on a class that did not override it.
            """
            raise NotImplementedError("please reimplement me in subclasses")

        def _load_kernel_module(self, kernel_driver, kernel_module):
            """Load ``kernel_module`` with modprobe unless it is already there.

            The module counts as present when /sys/bus/hid/drivers/<driver>
            exists or, when ``kernel_driver`` is None, when
            /sys/module/<module> exists. The calling test is skipped when
            modprobe returns a non-zero exit code.
            """
            sysfs_path = Path("/sys/bus/hid/drivers")
            if kernel_driver is not None:
                sysfs_path /= kernel_driver
            else:
                # special case for when testing all available modules:
                # we don't know beforehand the name of the module from modinfo
                sysfs_path = Path("/sys/module") / kernel_module.replace("-", "_")
            if not sysfs_path.exists():
                import subprocess

                ret = subprocess.run(["/usr/sbin/modprobe", kernel_module])
                if ret.returncode != 0:
                    pytest.skip(
                        f"module {kernel_module} could not be loaded, skipping the test"
                    )

        @pytest.fixture()
        def load_kernel_module(self):
            """Fixture loading every module listed in ``kernel_modules``."""
            for kernel_driver, kernel_module in self.kernel_modules:
                self._load_kernel_module(kernel_driver, kernel_module)
            yield

        @pytest.fixture()
        def new_uhdev(self, load_kernel_module):
            """Fixture returning a new device from create_device().

            Depends on load_kernel_module so the modules are loaded first.
            """
            return self.create_device()

        def assertName(self, uhdev):
            """Assert that the evdev node name contains the uhid device name."""
            evdev = uhdev.get_evdev()
            assert uhdev.name in evdev.name

        @pytest.fixture(autouse=True)
        def context(self, new_uhdev, request):
            """Autouse fixture exposing the device as ``self.uhdev``.

            The device is used within the shared HIDTestUdevRule context.
            The kernel device is created and events are dispatched until the
            device reports ready or 5 seconds have elapsed. A test marked
            with ``skip_if_uhdev(test, message)`` is skipped when
            ``test(self.uhdev)`` is true. A PermissionError raised anywhere
            in there skips the test.
            """
            try:
                with HIDTestUdevRule.instance():
                    with new_uhdev as self.uhdev:
                        skip_cond = request.node.get_closest_marker("skip_if_uhdev")
                        if skip_cond:
                            test, message, *_ = skip_cond.args

                            if test(self.uhdev):
                                pytest.skip(message)

                        self.uhdev.create_kernel_device()
                        # monotonic clock: the timeout must not be affected
                        # by adjustments of the system time
                        start = time.monotonic()
                        while (
                            not self.uhdev.is_ready() and time.monotonic() - start < 5
                        ):
                            self.uhdev.dispatch(1)
                        if self.uhdev.get_evdev() is None:
                            logger.warning(
                                f"available list of input nodes: (default application is '{self.uhdev.application}')"
                            )
                            logger.warning(self.uhdev.input_nodes)
                        yield
                        self.uhdev = None
            except PermissionError:
                pytest.skip("Insufficient permissions, run me as root")

        @pytest.fixture(autouse=True)
        def check_taint(self):
            """Autouse fixture asserting that the test did not change the
            kernel taint value."""
            # we are abusing SysfsFile here, it's in /proc, but meh
            taint_file = SysfsFile("/proc/sys/kernel/tainted")
            taint = taint_file.int_value

            yield

            assert taint_file.int_value == taint

        def test_creation(self):
            """Make sure the device gets processed by the kernel and creates
            the expected application input node.

            If this fails, there is something wrong in the device report
            descriptors."""
            uhdev = self.uhdev
            assert uhdev is not None
            assert uhdev.get_evdev() is not None
            self.assertName(uhdev)
            assert len(uhdev.next_sync_events()) == 0
            assert uhdev.get_evdev() is not None
284
285
class HIDTestUdevRule(object):
    """
    A context-manager compatible class that sets up our udev rules file and
    deletes it on context exit.

    This class is tailored to our test setup: it only sets up the udev rule
    on the **second** context and it cleans it up again on the last context
    removed. This matches the expected pytest setup: we enter a context for
    the session once, then once for each test (the first of which will
    trigger the udev rule) and once the last test exited and the session
    exited, we clean up after ourselves.
    """

    # shared object handed out by instance()
    _instance = None

    def __init__(self):
        # number of contexts currently entered
        self.refs = 0
        # file object of the installed rules file, None while not installed
        self.rulesfile = None

    def __enter__(self):
        """Take a reference; install the udev rule on the second one.

        Returns this object so that ``with ... as rule`` is usable.
        """
        self.refs += 1
        if self.refs == 2 and self.rulesfile is None:
            self.create_udev_rule()
            self.reload_udev_rules()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop a reference; remove the udev rule with the last one.

        Exceptions raised inside the context are not suppressed.
        """
        self.refs -= 1
        if self.refs == 0 and self.rulesfile:
            rules_path = self.rulesfile.name
            # Forget the file first: a later enter/exit cycle can then
            # install the rule again instead of removing a missing file.
            self.rulesfile = None
            os.remove(rules_path)
            self.reload_udev_rules()

    def reload_udev_rules(self):
        """Run ``udevadm control --reload-rules`` and ``systemd-hwdb update``.

        The exit codes of both commands are ignored.
        """
        import subprocess

        subprocess.run("udevadm control --reload-rules".split())
        subprocess.run("systemd-hwdb update".split())

    def create_udev_rule(self):
        """Write the rules file into /run/udev/rules.d.

        The directory is created when missing. The (closed) file object is
        kept in ``self.rulesfile`` so that its name is known for removal.
        """
        import tempfile

        os.makedirs("/run/udev/rules.d", exist_ok=True)
        with tempfile.NamedTemporaryFile(
            prefix="91-uhid-test-device-REMOVEME-",
            suffix=".rules",
            mode="w+",
            dir="/run/udev/rules.d",
            delete=False,
        ) as f:
            f.write(
                'KERNELS=="*input*", ATTRS{name}=="*uhid test *", ENV{LIBINPUT_IGNORE_DEVICE}="1"\n'
            )
            f.write(
                'KERNELS=="*input*", ATTRS{name}=="*uhid test * System Multi Axis", ENV{ID_INPUT_TOUCHSCREEN}="", ENV{ID_INPUT_SYSTEM_MULTIAXIS}="1"\n'
            )
            self.rulesfile = f

    @classmethod
    def instance(cls):
        """Return the shared HIDTestUdevRule, creating it on first use."""
        if cls._instance is None:
            cls._instance = HIDTestUdevRule()
        return cls._instance
346