xref: /openbmc/qemu/python/tests/protocol.py (revision 1034cd169cfa55d7abf68a9de468e14ae15fd9f5)
1import asyncio
2from contextlib import contextmanager
3import os
4import socket
5from tempfile import TemporaryDirectory
6
7import avocado
8
9from qemu.qmp import ConnectError, Runstate
10from qemu.qmp.protocol import AsyncProtocol, StateError
11
12
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that keeps the reader from endlessly
    yielding None; setting this event simulates an incoming message.

    For testing symmetry with _do_recv, a send_msg() interface is added
    to "send" a Null message.

    For testing purposes, a simulate_disconnect() method is also added
    which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the _do_start_server/_do_accept/_do_connect overrides
        # below skip the real transport setup.
        self.fake_session = False
        # Created per-session in _establish_session(); declared here only
        # for type checkers.
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # A fresh event per session, so no stale "input" carries over.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Skip the real listener. NOTE(review): presumably this mirrors
            # the _accepted event the real implementation creates — confirm
            # against AsyncProtocol.
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            # Yield once so observers can see the CONNECTING transition.
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            # Nothing to wait for in a fake session; just drop the event.
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            self._set_state(Runstate.CONNECTING)
            # Yield once so observers can see the CONNECTING transition.
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # There is no transport; "sending" a Null message is a no-op.
        pass

    async def send_msg(self) -> None:
        """Enqueue a Null message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
91
92
class LineProtocol(AsyncProtocol[str]):
    """
    Minimal line-oriented AsyncProtocol where each message is one text line.

    Every line received is also appended to ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Decoded lines received so far, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        # NOTE(review): whether the line keeps its trailing newline depends
        # on _readline(); _do_send() below appends one, so received text may
        # not round-trip exactly — confirm before relying on rx_history.
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Enqueue *msg* for sending; the newline is added by _do_send()."""
        await self._outgoing.put(msg)
110
111
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule *coro* on the running event loop and return the new Task.

    The Task's result is always None. When *allow_cancellation* is true,
    cancelling the Task is swallowed and the Task finishes normally;
    otherwise the CancelledError propagates as usual.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return asyncio.create_task(_guarded())
127
128
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the ``(host, port)`` address of a listening socket whose
    backlog has been filled with connections that are never accepted, so
    further connection attempts to it are expected to hang. Every socket
    opened here is closed on exit, including when setup fails part-way.
    """
    socks = []

    try:
        # Track each socket as soon as it exists, so the finally clause
        # closes it even if bind()/listen()/connect() raises.
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        # NOTE(review): likely because Linux queues up to backlog + 1
        # pending connections before treating the backlog as full — confirm.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
157
158
class Smoke(avocado.Test):
    """Synchronous sanity checks on a freshly constructed NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        # A named instance reports its name in the logger and repr too.
        self.proto = NullProtocol('Steve')

        self.assertEqual(self.proto.name, 'Steve')
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
205
206
class TestBase(avocado.Test):
    """
    Common scaffolding for the async protocol tests.

    Each test gets a fresh NullProtocol named after the test class. The
    ``async_test`` decorator runs async set-up/tear-down hooks around a
    coroutine test, and ``_watch_runstates`` asserts the exact sequence
    of runstate transitions a test produces.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        # Every test must leave the protocol fully quiesced.
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Awaiting the watcher re-raises any assertion failure it hit.
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The wrapper enables asyncio debug mode on the running loop, then
        awaits ``_asyncSetUp()``, the test itself, and ``_asyncTearDown()``
        in that order. If the test body raises, ``_asyncTearDown()`` is
        not run. The wrapper keeps the test's name and docstring.
        """
        import functools  # Only this decorator needs it.

        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_running_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = asyncio.create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
276
277
class State(TestBase):
    """Runstate behavior that does not require any connection."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        # Even from IDLE, disconnect() passes through DISCONNECTING.
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
290
291
class Connect(TestBase):
    """
    Tests primarily related to calling connect().

    The Accept subclass reruns every test here through
    start_server_and_accept() by overriding the two connection helpers.
    """
    async def _bad_connection(self, family: str):
        """
        Make a connection attempt that is expected to be rejected at once.

        :param family: Either 'INET' or 'UNIX'.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            # Nothing can be listening on port 0, so this should be refused.
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            # /dev/null is not a socket, so connecting to it should fail.
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Make a connection attempt that stalls until cancelled."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a rejected attempt raises ConnectError wrapping an
        OSError and walks through BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # connect() (or accept(), in the Accept subclass) is awaited as a
        # plain coroutine rather than a task, so it cannot be cancelled
        # outright. Wrap it in a task and cancel *that* instead.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test-case assertion (not a bare assert, which -O strips).
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
418
419
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """
        Start a server on an address that cannot be bound.

        :param family: Either 'INET' or 'UNIX'.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            # 192.0.2.1 is in TEST-NET-1 (RFC 5737): it needs no DNS lookup
            # and is not expected to be assigned to any local interface, so
            # binding to it should fail immediately, even offline.
            await self.proto.start_server_and_accept(('192.0.2.1', 1))
        elif family == 'UNIX':
            # /dev/null already exists, so no socket can be created there.
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        """Listen on a fresh UNIX socket that no client ever connects to."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)
436
437
class FakeSession(TestBase):
    """
    Session lifecycle tests driven by NullProtocol's fake_session mode.

    With fake_session set, NullProtocol skips the real connect/listen
    logic, so the addresses used below never have to exist.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        # Every test in this class runs one complete, successful session.
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        # Return to IDLE first, so the base class's runstate watcher sees
        # DISCONNECTING and IDLE and can finish.
        await self.proto.disconnect()
        await super()._asyncTearDown()

    async def _fake_accept(self):
        """Start a fake session through the accept() interface."""
        await self.proto.start_server_and_accept('/not/a/real/path')

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self._fake_accept()
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self._fake_accept()

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            # set() wakes the reader's pending wait(); clearing the flag
            # right away means only one message should be "received".
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Let the reader run.

        expected = [f"DEBUG:{log_name}:<-- None"]
        self.assertEqual(captured.output, expected)

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self._fake_accept()

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            # Cheat: the Null message goes nowhere.
            await self.proto.send_msg()
            # Awaiting queue.put() alone does not yield, so kick the writer.
            await asyncio.sleep(0)

        expected = [f"DEBUG:{log_name}:--> None"]
        self.assertEqual(captured.output, expected)

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting a session now fails with StateError.

        :param current_state: The runstate the error should report.
        :param error_message: The exact expected error message.
        :param accept: Use accept() when true, otherwise connect().
        """
        if accept:
            opener = self.proto.start_server_and_accept
        else:
            opener = self.proto.connect

        with self.assertRaises(StateError) as context:
            await opener('/not/a/real/path')

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()
        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()
        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()
        # Cheat: force a half-finished disconnect.
        await self.proto.simulate_disconnect()
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            "NullProtocol is disconnecting. Call disconnect() to return to IDLE state.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()
        # Cheat: force a half-finished disconnect.
        await self.proto.simulate_disconnect()
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            "NullProtocol is disconnecting. Call disconnect() to return to IDLE state.",
            accept=False,
        )
567
568
class SimpleSession(TestBase):
    """
    End-to-end smoke test: a NullProtocol client connecting to a
    LineProtocol server over a real UNIX socket.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # The client already disconnected; the server side may report
            # that as EOFError, which is tolerated here.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Test that a client can connect to a listening server."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = asyncio.create_task(
                self.server.start_server_and_accept(sock))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)

            # Collect the server's outcome, so a failure in
            # start_server_and_accept() fails this test instead of being
            # dropped as an un-retrieved task exception.
            await server_task
596            await self.proto.connect(sock)
597