xref: /openbmc/qemu/python/tests/protocol.py (revision 6016b7b4)
import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.aqmp import ConnectError, Runstate
from qemu.aqmp.protocol import AsyncProtocol, StateError
from qemu.aqmp.util import asyncio_run, create_task
12
13
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that keeps the reader from yielding None
    in an endless tight loop; this event can be poked to simulate an
    incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, _do_accept/_do_connect skip the real connection work.
        self.fake_session = False
        # Declared only; the Event itself is created in _establish_session.
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        """Create a fresh trigger_input event, then start the session."""
        # NOTE(review): presumably created here rather than in __init__ so
        # the Event is made while the event loop is running — confirm.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_accept(self, address, ssl=None):
        """Accept normally, unless fake_session is set (then do nothing)."""
        if not self.fake_session:
            await super()._do_accept(address, ssl)

    async def _do_connect(self, address, ssl=None):
        """Connect normally, unless fake_session is set (then do nothing)."""
        if not self.fake_session:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        """
        "Receive" a None message.

        Blocks until trigger_input is set, then clears it again so that
        the next call blocks until the event is poked once more.
        """
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Discard the outgoing message; there is no real peer."""
        pass

    async def send_msg(self) -> None:
        """Enqueue a None message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
79
80
class LineProtocol(AsyncProtocol[str]):
    """
    A minimal newline-terminated text protocol, used as a real peer.

    Every message it receives is also recorded in ``rx_history``.
    """

    def __init__(self, name=None):
        super().__init__(name)
        # Decoded incoming messages, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        """Read one line, decode it, record it, and return it."""
        # NOTE(review): whether the trailing newline is kept depends on
        # _readline() in the base class — confirm before relying on it.
        text = (await self._readline()).decode()
        self.rx_history.append(text)
        return text

    def _do_send(self, msg: str) -> None:
        """Write ``msg`` encoded, followed by a single newline byte."""
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Enqueue ``msg`` on the outgoing queue."""
        await self._outgoing.put(msg)
97        await self._outgoing.put(msg)
98
99
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule ``coro`` to run as a task and return that task.

    :param coro: The coroutine to run.
    :param allow_cancellation: If True, a CancelledError raised while
        awaiting ``coro`` is swallowed, so cancelling the task ends it
        normally. Otherwise the CancelledError propagates.
    """
    async def _guarded():
        # NOTE(review): a cancel() delivered before this wrapper first
        # runs is raised before the try block is entered, so it is not
        # swallowed even when allow_cancellation is True.
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_guarded())
114    return create_task(_runner())
115
116
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listener is bound to an ephemeral port on 127.0.0.1 with a backlog
    of 1. It is then filled with client connections that are never
    accepted, so that further connection attempts to it hang.

    Yields the ``(host, port)`` address of the jammed listener. Every
    socket created here is closed on exit, including when setup fails
    part-way through.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket before anything else can fail, so the finally
        # clause always closes it (bind/listen may raise).
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        # NOTE(review): presumably the kernel admits backlog + 1 pending
        # connections (as Linux does) — confirm on other platforms.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Tracked before connect() so a failed connect cannot leak it.
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
144            sock.close()
145
146
class Smoke(avocado.Test):
    """Sanity checks on a freshly constructed, never-connected NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.aqmp.protocol')

    def testName(self):
        # A named instance: the name shows up in the attribute, the
        # logger hierarchy, and the repr.
        self.proto = NullProtocol('Steve')
        named = self.proto

        self.assertEqual(named.name, 'Steve')
        self.assertEqual(named.logger.name, 'qemu.aqmp.protocol.Steve')
        self.assertEqual(
            repr(named),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
192        )
193
194
class TestBase(avocado.Test):
    """
    Common fixture for the asynchronous protocol tests.

    Each test gets a NullProtocol named after the test class in
    ``self.proto``. The protocol must be IDLE both before and after the
    test. Async setup/teardown hooks are run by the ``async_test``
    decorator.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(), awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        # Every test must leave the protocol back in IDLE.
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Async per-test setup hook; subclasses extend it via super()."""

    async def _asyncTearDown(self):
        """Async per-test teardown: wait for the runstate watcher, if any."""
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables debug mode on the event
        loop, then runs ``_asyncSetUp``, the test body, and
        ``_asyncTearDown`` in that order. If the test body raises,
        ``_asyncTearDown`` is not run.

        ``functools.wraps`` keeps the test's name and docstring on the
        wrapper, so test reports still show them.
        """
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.

        The task is stored in ``self.runstate_watcher`` and awaited by
        ``_asyncTearDown``. Any mismatch therefore surfaces at teardown.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
263        await asyncio.sleep(0)
264
265
class State(TestBase):
    """Runstate tests that need no connection at all."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        # Even from IDLE, disconnect() passes through DISCONNECTING.
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
277        await self.proto.disconnect()
278
279
class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        """
        Make a connection attempt that is expected to fail immediately.

        :param family: 'INET' (TCP to 127.0.0.1 port 0) or 'UNIX'
            (/dev/null, which is not a socket).
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Connect to a jammed listener; the attempt is expected to hang."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """Check that a rejected attempt raises ConnectError wrapping OSError."""
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as
        # they aren't tasks. However, we can wrap the call in a task and
        # cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the framework assertion (not a bare assert) so the check
        # survives -O and reports the actual state on failure.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
405        await task
406
407
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """Attempt an accept() on an address that cannot be listened on."""
        assert family in ('INET', 'UNIX')

        targets = {
            'INET': ('example.com', 1),
            'UNIX': '/dev/null',
        }
        if family in targets:
            await self.proto.accept(targets[family])

    async def _hanging_connection(self):
        """Listen on a fresh UNIX socket that no client ever connects to."""
        with TemporaryDirectory(suffix='.aqmp') as scratch:
            sock_path = os.path.join(
                scratch, f"{type(self.proto).__name__}.sock"
            )
            await self.proto.accept(sock_path)
423            await self.proto.accept(sock)
424
425
class FakeSession(TestBase):
    """
    Lifecycle tests that run NullProtocol with ``fake_session`` enabled,
    so that connect()/accept() skip the real connection logic.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        # Every test here is expected to walk a complete, good session.
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Pulse the event: wake the waiting reader, then re-arm it.
            trigger = self.proto.trigger_input
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Let the reader run.

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Awaiting the queue put alone does not yield to the writer,
            # so explicitly let it run.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:--> None"])

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting another session is refused.

        Calls accept() (or connect() when ``accept`` is False) and checks
        that it raises StateError with ``error_message``, reporting
        ``current_state`` and requiring IDLE.
        """
        start = self.proto.accept if accept else self.proto.connect
        with self.assertRaises(StateError) as context:
            await start('/not/a/real/path')

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
554        )
555
556
class SimpleSession(TestBase):
    """
    Tests that connect the NullProtocol client to a real LineProtocol
    server over a UNIX socket.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(f"{type(self).__name__}-server")

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        # An EOFError from the server's disconnect is tolerated here
        # (presumably the server sees the client already gone).
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        with TemporaryDirectory(suffix='.aqmp') as scratch:
            sock_path = os.path.join(
                scratch, f"{type(self.proto).__name__}.sock"
            )
            # Keep a strong reference to the server task while the test
            # runs; asyncio's event loop holds only weak references.
            accept_task = create_task(self.server.accept(sock_path))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock_path)
            # NOTE(review): accept_task is never awaited, so a failure in
            # the server-side accept() is not reported by this test.
            # Consider awaiting it once it is confirmed that accept()
            # completes while the client connection is still open.
584