xref: /openbmc/qemu/python/tests/protocol.py (revision 5e9902a0)
1import asyncio
2from contextlib import contextmanager
3import os
4import socket
5from tempfile import TemporaryDirectory
6
7import avocado
8
9from qemu.aqmp import ConnectError, Runstate
10from qemu.aqmp.protocol import AsyncProtocol, StateError
11from qemu.aqmp.util import asyncio_run, create_task
12
13
14class NullProtocol(AsyncProtocol[None]):
15    """
16    NullProtocol is a test mockup of an AsyncProtocol implementation.
17
18    It adds a fake_session instance variable that enables a code path
19    that bypasses the actual connection logic, but still allows the
20    reader/writers to start.
21
22    Because the message type is defined as None, an asyncio.Event named
23    'trigger_input' is created that prohibits the reader from
24    incessantly being able to yield None; this event can be poked to
25    simulate an incoming message.
26
27    For testing symmetry with do_recv, an interface is added to "send" a
28    Null message.
29
30    For testing purposes, a "simulate_disconnection" method is also
31    added which allows us to trigger a bottom half disconnect without
32    injecting any real errors into the reader/writer loops; in essence
33    it performs exactly half of what disconnect() normally does.
34    """
35    def __init__(self, name=None):
36        self.fake_session = False
37        self.trigger_input: asyncio.Event
38        super().__init__(name)
39
40    async def _establish_session(self):
41        self.trigger_input = asyncio.Event()
42        await super()._establish_session()
43
44    async def _do_start_server(self, address, ssl=None):
45        if self.fake_session:
46            self._set_state(Runstate.CONNECTING)
47            await asyncio.sleep(0)
48        else:
49            await super()._do_start_server(address, ssl)
50
51    async def _do_connect(self, address, ssl=None):
52        if self.fake_session:
53            self._set_state(Runstate.CONNECTING)
54            await asyncio.sleep(0)
55        else:
56            await super()._do_connect(address, ssl)
57
58    async def _do_recv(self) -> None:
59        await self.trigger_input.wait()
60        self.trigger_input.clear()
61
62    def _do_send(self, msg: None) -> None:
63        pass
64
65    async def send_msg(self) -> None:
66        await self._outgoing.put(None)
67
68    async def simulate_disconnect(self) -> None:
69        """
70        Simulates a bottom-half disconnect.
71
72        This method schedules a disconnection but does not wait for it
73        to complete. This is used to put the loop into the DISCONNECTING
74        state without fully quiescing it back to IDLE. This is normally
75        something you cannot coax AsyncProtocol to do on purpose, but it
76        will be similar to what happens with an unhandled Exception in
77        the reader/writer.
78
79        Under normal circumstances, the library design requires you to
80        await on disconnect(), which awaits the disconnect task and
81        returns bottom half errors as a pre-condition to allowing the
82        loop to return back to IDLE.
83        """
84        self._schedule_disconnect()
85
86
87class LineProtocol(AsyncProtocol[str]):
88    def __init__(self, name=None):
89        super().__init__(name)
90        self.rx_history = []
91
92    async def _do_recv(self) -> str:
93        raw = await self._readline()
94        msg = raw.decode()
95        self.rx_history.append(msg)
96        return msg
97
98    def _do_send(self, msg: str) -> None:
99        assert self._writer is not None
100        self._writer.write(msg.encode() + b'\n')
101
102    async def send_msg(self, msg: str) -> None:
103        await self._outgoing.put(msg)
104
105
106def run_as_task(coro, allow_cancellation=False):
107    """
108    Run a given coroutine as a task.
109
110    Optionally, wrap it in a try..except block that allows this
111    coroutine to be canceled gracefully.
112    """
113    async def _runner():
114        try:
115            await coro
116        except asyncio.CancelledError:
117            if allow_cancellation:
118                return
119            raise
120    return create_task(_runner())
121
122
123@contextmanager
124def jammed_socket():
125    """
126    Opens up a random unused TCP port on localhost, then jams it.
127    """
128    socks = []
129
130    try:
131        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
132        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
133        sock.bind(('127.0.0.1', 0))
134        sock.listen(1)
135        address = sock.getsockname()
136
137        socks.append(sock)
138
139        # I don't *fully* understand why, but it takes *two* un-accepted
140        # connections to start jamming the socket.
141        for _ in range(2):
142            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
143            sock.connect(address)
144            socks.append(sock)
145
146        yield address
147
148    finally:
149        for sock in socks:
150            sock.close()
151
152
153class Smoke(avocado.Test):
154
155    def setUp(self):
156        self.proto = NullProtocol()
157
158    def test__repr__(self):
159        self.assertEqual(
160            repr(self.proto),
161            "<NullProtocol runstate=IDLE>"
162        )
163
164    def testRunstate(self):
165        self.assertEqual(
166            self.proto.runstate,
167            Runstate.IDLE
168        )
169
170    def testDefaultName(self):
171        self.assertEqual(
172            self.proto.name,
173            None
174        )
175
176    def testLogger(self):
177        self.assertEqual(
178            self.proto.logger.name,
179            'qemu.aqmp.protocol'
180        )
181
182    def testName(self):
183        self.proto = NullProtocol('Steve')
184
185        self.assertEqual(
186            self.proto.name,
187            'Steve'
188        )
189
190        self.assertEqual(
191            self.proto.logger.name,
192            'qemu.aqmp.protocol.Steve'
193        )
194
195        self.assertEqual(
196            repr(self.proto),
197            "<NullProtocol name='Steve' runstate=IDLE>"
198        )
199
200
201class TestBase(avocado.Test):
202
203    def setUp(self):
204        self.proto = NullProtocol(type(self).__name__)
205        self.assertEqual(self.proto.runstate, Runstate.IDLE)
206        self.runstate_watcher = None
207
208    def tearDown(self):
209        self.assertEqual(self.proto.runstate, Runstate.IDLE)
210
211    async def _asyncSetUp(self):
212        pass
213
214    async def _asyncTearDown(self):
215        if self.runstate_watcher:
216            await self.runstate_watcher
217
218    @staticmethod
219    def async_test(async_test_method):
220        """
221        Decorator; adds SetUp and TearDown to async tests.
222        """
223        async def _wrapper(self, *args, **kwargs):
224            loop = asyncio.get_event_loop()
225            loop.set_debug(True)
226
227            await self._asyncSetUp()
228            await async_test_method(self, *args, **kwargs)
229            await self._asyncTearDown()
230
231        return _wrapper
232
233    # Definitions
234
235    # The states we expect a "bad" connect/accept attempt to transition through
236    BAD_CONNECTION_STATES = (
237        Runstate.CONNECTING,
238        Runstate.DISCONNECTING,
239        Runstate.IDLE,
240    )
241
242    # The states we expect a "good" session to transition through
243    GOOD_CONNECTION_STATES = (
244        Runstate.CONNECTING,
245        Runstate.RUNNING,
246        Runstate.DISCONNECTING,
247        Runstate.IDLE,
248    )
249
250    # Helpers
251
252    async def _watch_runstates(self, *states):
253        """
254        This launches a task alongside (most) tests below to confirm that
255        the sequence of runstate changes that occur is exactly as
256        anticipated.
257        """
258        async def _watcher():
259            for state in states:
260                new_state = await self.proto.runstate_changed()
261                self.assertEqual(
262                    new_state,
263                    state,
264                    msg=f"Expected state '{state.name}'",
265                )
266
267        self.runstate_watcher = create_task(_watcher())
268        # Kick the loop and force the task to block on the event.
269        await asyncio.sleep(0)
270
271
272class State(TestBase):
273
274    @TestBase.async_test
275    async def testSuperfluousDisconnect(self):
276        """
277        Test calling disconnect() while already disconnected.
278        """
279        await self._watch_runstates(
280            Runstate.DISCONNECTING,
281            Runstate.IDLE,
282        )
283        await self.proto.disconnect()
284
285
286class Connect(TestBase):
287    """
288    Tests primarily related to calling Connect().
289    """
290    async def _bad_connection(self, family: str):
291        assert family in ('INET', 'UNIX')
292
293        if family == 'INET':
294            await self.proto.connect(('127.0.0.1', 0))
295        elif family == 'UNIX':
296            await self.proto.connect('/dev/null')
297
298    async def _hanging_connection(self):
299        with jammed_socket() as addr:
300            await self.proto.connect(addr)
301
302    async def _bad_connection_test(self, family: str):
303        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
304
305        with self.assertRaises(ConnectError) as context:
306            await self._bad_connection(family)
307
308        self.assertIsInstance(context.exception.exc, OSError)
309        self.assertEqual(
310            context.exception.error_message,
311            "Failed to establish connection"
312        )
313
314    @TestBase.async_test
315    async def testBadINET(self):
316        """
317        Test an immediately rejected call to an IP target.
318        """
319        await self._bad_connection_test('INET')
320
321    @TestBase.async_test
322    async def testBadUNIX(self):
323        """
324        Test an immediately rejected call to a UNIX socket target.
325        """
326        await self._bad_connection_test('UNIX')
327
328    @TestBase.async_test
329    async def testCancellation(self):
330        """
331        Test what happens when a connection attempt is aborted.
332        """
333        # Note that accept() cannot be cancelled outright, as it isn't a task.
334        # However, we can wrap it in a task and cancel *that*.
335        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
336        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
337
338        state = await self.proto.runstate_changed()
339        self.assertEqual(state, Runstate.CONNECTING)
340
341        # This is insider baseball, but the connection attempt has
342        # yielded *just* before the actual connection attempt, so kick
343        # the loop to make sure it's truly wedged.
344        await asyncio.sleep(0)
345
346        task.cancel()
347        await task
348
349    @TestBase.async_test
350    async def testTimeout(self):
351        """
352        Test what happens when a connection attempt times out.
353        """
354        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
355        task = run_as_task(self._hanging_connection())
356
357        # More insider baseball: to improve the speed of this test while
358        # guaranteeing that the connection even gets a chance to start,
359        # verify that the connection hangs *first*, then await the
360        # result of the task with a nearly-zero timeout.
361
362        state = await self.proto.runstate_changed()
363        self.assertEqual(state, Runstate.CONNECTING)
364        await asyncio.sleep(0)
365
366        with self.assertRaises(asyncio.TimeoutError):
367            await asyncio.wait_for(task, timeout=0)
368
369    @TestBase.async_test
370    async def testRequire(self):
371        """
372        Test what happens when a connection attempt is made while CONNECTING.
373        """
374        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
375        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
376
377        state = await self.proto.runstate_changed()
378        self.assertEqual(state, Runstate.CONNECTING)
379
380        with self.assertRaises(StateError) as context:
381            await self._bad_connection('UNIX')
382
383        self.assertEqual(
384            context.exception.error_message,
385            "NullProtocol is currently connecting."
386        )
387        self.assertEqual(context.exception.state, Runstate.CONNECTING)
388        self.assertEqual(context.exception.required, Runstate.IDLE)
389
390        task.cancel()
391        await task
392
393    @TestBase.async_test
394    async def testImplicitRunstateInit(self):
395        """
396        Test what happens if we do not wait on the runstate event until
397        AFTER a connection is made, i.e., connect()/accept() themselves
398        initialize the runstate event. All of the above tests force the
399        initialization by waiting on the runstate *first*.
400        """
401        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
402
403        # Kick the loop to coerce the state change
404        await asyncio.sleep(0)
405        assert self.proto.runstate == Runstate.CONNECTING
406
407        # We already missed the transition to CONNECTING
408        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)
409
410        task.cancel()
411        await task
412
413
414class Accept(Connect):
415    """
416    All of the same tests as Connect, but using the accept() interface.
417    """
418    async def _bad_connection(self, family: str):
419        assert family in ('INET', 'UNIX')
420
421        if family == 'INET':
422            await self.proto.start_server_and_accept(('example.com', 1))
423        elif family == 'UNIX':
424            await self.proto.start_server_and_accept('/dev/null')
425
426    async def _hanging_connection(self):
427        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
428            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
429            await self.proto.start_server_and_accept(sock)
430
431
432class FakeSession(TestBase):
433
434    def setUp(self):
435        super().setUp()
436        self.proto.fake_session = True
437
438    async def _asyncSetUp(self):
439        await super()._asyncSetUp()
440        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
441
442    async def _asyncTearDown(self):
443        await self.proto.disconnect()
444        await super()._asyncTearDown()
445
446    ####
447
448    @TestBase.async_test
449    async def testFakeConnect(self):
450
451        """Test the full state lifecycle (via connect) with a no-op session."""
452        await self.proto.connect('/not/a/real/path')
453        self.assertEqual(self.proto.runstate, Runstate.RUNNING)
454
455    @TestBase.async_test
456    async def testFakeAccept(self):
457        """Test the full state lifecycle (via accept) with a no-op session."""
458        await self.proto.start_server_and_accept('/not/a/real/path')
459        self.assertEqual(self.proto.runstate, Runstate.RUNNING)
460
461    @TestBase.async_test
462    async def testFakeRecv(self):
463        """Test receiving a fake/null message."""
464        await self.proto.start_server_and_accept('/not/a/real/path')
465
466        logname = self.proto.logger.name
467        with self.assertLogs(logname, level='DEBUG') as context:
468            self.proto.trigger_input.set()
469            self.proto.trigger_input.clear()
470            await asyncio.sleep(0)  # Kick reader.
471
472        self.assertEqual(
473            context.output,
474            [f"DEBUG:{logname}:<-- None"],
475        )
476
477    @TestBase.async_test
478    async def testFakeSend(self):
479        """Test sending a fake/null message."""
480        await self.proto.start_server_and_accept('/not/a/real/path')
481
482        logname = self.proto.logger.name
483        with self.assertLogs(logname, level='DEBUG') as context:
484            # Cheat: Send a Null message to nobody.
485            await self.proto.send_msg()
486            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
487            await asyncio.sleep(0)
488
489        self.assertEqual(
490            context.output,
491            [f"DEBUG:{logname}:--> None"],
492        )
493
494    async def _prod_session_api(
495            self,
496            current_state: Runstate,
497            error_message: str,
498            accept: bool = True
499    ):
500        with self.assertRaises(StateError) as context:
501            if accept:
502                await self.proto.start_server_and_accept('/not/a/real/path')
503            else:
504                await self.proto.connect('/not/a/real/path')
505
506        self.assertEqual(context.exception.error_message, error_message)
507        self.assertEqual(context.exception.state, current_state)
508        self.assertEqual(context.exception.required, Runstate.IDLE)
509
510    @TestBase.async_test
511    async def testAcceptRequireRunning(self):
512        """Test that accept() cannot be called when Runstate=RUNNING"""
513        await self.proto.start_server_and_accept('/not/a/real/path')
514
515        await self._prod_session_api(
516            Runstate.RUNNING,
517            "NullProtocol is already connected and running.",
518            accept=True,
519        )
520
521    @TestBase.async_test
522    async def testConnectRequireRunning(self):
523        """Test that connect() cannot be called when Runstate=RUNNING"""
524        await self.proto.start_server_and_accept('/not/a/real/path')
525
526        await self._prod_session_api(
527            Runstate.RUNNING,
528            "NullProtocol is already connected and running.",
529            accept=False,
530        )
531
532    @TestBase.async_test
533    async def testAcceptRequireDisconnecting(self):
534        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
535        await self.proto.start_server_and_accept('/not/a/real/path')
536
537        # Cheat: force a disconnect.
538        await self.proto.simulate_disconnect()
539
540        await self._prod_session_api(
541            Runstate.DISCONNECTING,
542            ("NullProtocol is disconnecting."
543             " Call disconnect() to return to IDLE state."),
544            accept=True,
545        )
546
547    @TestBase.async_test
548    async def testConnectRequireDisconnecting(self):
549        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
550        await self.proto.start_server_and_accept('/not/a/real/path')
551
552        # Cheat: force a disconnect.
553        await self.proto.simulate_disconnect()
554
555        await self._prod_session_api(
556            Runstate.DISCONNECTING,
557            ("NullProtocol is disconnecting."
558             " Call disconnect() to return to IDLE state."),
559            accept=False,
560        )
561
562
563class SimpleSession(TestBase):
564
565    def setUp(self):
566        super().setUp()
567        self.server = LineProtocol(type(self).__name__ + '-server')
568
569    async def _asyncSetUp(self):
570        await super()._asyncSetUp()
571        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
572
573    async def _asyncTearDown(self):
574        await self.proto.disconnect()
575        try:
576            await self.server.disconnect()
577        except EOFError:
578            pass
579        await super()._asyncTearDown()
580
581    @TestBase.async_test
582    async def testSmoke(self):
583        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
584            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
585            server_task = create_task(self.server.start_server_and_accept(sock))
586
587            # give the server a chance to start listening [...]
588            await asyncio.sleep(0)
589            await self.proto.connect(sock)
590