# xref: /openbmc/qemu/python/tests/protocol.py (revision b14df228)
import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.qmp import ConnectError, Runstate
from qemu.qmp.protocol import AsyncProtocol, StateError
from qemu.qmp.util import asyncio_run, create_task
12
13
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prevents the reader from endlessly
    yielding None; this event can be poked to simulate an incoming
    message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also added
    which allows us to trigger a bottom half disconnect without injecting
    any real errors into the reader/writer loops; in essence it performs
    exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the _do_* connection hooks below skip all real I/O.
        self.fake_session = False
        # Only declared here; the Event is created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # Create a fresh trigger for every session before handing off to
        # the base implementation.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Fake path: no listening socket is opened. An accept event is
            # created (presumably mirroring what the real implementation
            # sets up -- TODO confirm) and the state moves to CONNECTING.
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            # Yield to the event loop once, as a real attempt would.
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            # Nothing to wait for; pretend the peer was accepted.
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Fake path: enter CONNECTING without touching the network.
            self._set_state(Runstate.CONNECTING)
            # Yield to the event loop once, as a real attempt would.
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test sets trigger_input, then re-arm it so the
        # next receive blocks again; the "message" received is None.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # The null message has nowhere to go.
        pass

    async def send_msg(self) -> None:
        """Queue a null message for the writer."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
92
93
class LineProtocol(AsyncProtocol[str]):
    """
    A minimal newline-delimited text protocol, used as a real peer in tests.

    Outgoing messages are UTF-8 encoded and terminated with a newline.
    Incoming lines are decoded as-is (nothing is stripped) and appended to
    ``rx_history`` in arrival order.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Every decoded message received so far, oldest first.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Queue *msg* for the writer."""
        await self._outgoing.put(msg)
111
112
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule *coro* to run as a task, and return that task.

    :param coro: The coroutine to run.
    :param allow_cancellation: If true, a CancelledError raised while
        awaiting *coro* is swallowed, so cancelling the task simply ends
        it (awaiting the task then returns None). Otherwise the
        CancelledError propagates normally.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_guarded())
128
129
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to an ephemeral port on 127.0.0.1 with a
    backlog of 1. Two client connections are then made to it and never
    accepted, so further connection attempts to that address are expected
    to hang. This lets tests exercise timeouts and cancellation.

    All sockets opened here are closed on exit, including when setup
    itself fails partway through.

    :yield: The ``(host, port)`` address of the jammed listening socket.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket right away so that it is still closed if
        # bind() or listen() below raise.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket; presumably the kernel admits backlog + 1 pending
        # connections before it stops completing handshakes (observed on
        # Linux -- TODO confirm on other platforms).
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # As above: track before connect() so a failure cannot leak it.
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
158
159
class Smoke(avocado.Test):
    """
    Basic sanity checks on a freshly constructed, unconnected NullProtocol.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        """An unnamed, idle protocol's repr omits the name."""
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        """A new protocol starts out IDLE."""
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        """Without an explicit name, the name is None."""
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        """An unnamed protocol logs to the base protocol logger."""
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        """A name appears in the attribute, the logger name and the repr."""
        steve = NullProtocol('Steve')
        self.proto = steve

        self.assertEqual(steve.name, 'Steve')
        self.assertEqual(steve.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(steve),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
206
207
class TestBase(avocado.Test):
    """
    Common fixture for the asynchronous protocol tests.

    Each test gets a fresh, IDLE NullProtocol (named after the test class)
    in ``self.proto``, and must leave it IDLE again. Coroutine tests
    decorated with ``async_test`` additionally run ``_asyncSetUp`` before
    and ``_asyncTearDown`` after the test body; subclasses may override
    those hooks.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(); awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        # Every test must leave the protocol fully quiesced.
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Awaiting the watcher task re-raises any assertion it failed.
        if self.runstate_watcher is not None:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables debug mode on the current
        event loop, then runs ``_asyncSetUp``, the decorated test method
        and ``_asyncTearDown``, in that order. The teardown hook is not run
        if the test method raises.

        ``functools.wraps`` keeps the test method's ``__name__`` and
        docstring, so anything that reports them (e.g. unittest's
        ``shortDescription()``) sees the real test rather than the wrapper.
        """
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
277
278
class State(TestBase):
    """Runstate tests that do not involve any connection attempt."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected_states = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected_states)
        await self.proto.disconnect()
291
292
class Connect(TestBase):
    """
    Tests primarily related to calling connect().

    Subclasses may override _bad_connection() and _hanging_connection() to
    run these same tests against a different connection API.
    """
    async def _bad_connection(self, family: str):
        """
        Make a connection attempt that is expected to be rejected at once.

        :param family: 'INET' for a TCP target, 'UNIX' for a path target.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Make a connection attempt to a jammed socket that never completes."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a rejected attempt raises ConnectError wrapping an
        OSError, while passing through BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as they
        # aren't tasks. However, we can wrap the attempt in a task and
        # cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # A unittest assertion, unlike a bare assert, is not stripped
        # when Python runs with -O.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
419
420
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.

    Only the two connection helpers are overridden; every inherited test
    drives start_server_and_accept() instead of connect().
    """
    async def _bad_connection(self, family: str):
        """Start a server on an address that is expected to be rejected."""
        assert family in ('INET', 'UNIX')

        address = {
            'INET': ('example.com', 1),
            'UNIX': '/dev/null',
        }.get(family)
        if address is not None:
            await self.proto.start_server_and_accept(address)

    async def _hanging_connection(self):
        """Listen on a fresh socket path that no client ever connects to."""
        with TemporaryDirectory(suffix='.qmp') as tempdir:
            filename = type(self.proto).__name__ + ".sock"
            sock_path = os.path.join(tempdir, filename)
            await self.proto.start_server_and_accept(sock_path)
437
438
class FakeSession(TestBase):
    """
    Session lifecycle tests that use NullProtocol's fake_session mode, so
    the connection hooks skip all real I/O.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        expected = [f"DEBUG:{logger_name}:<-- None"]
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # set() wakes the reader's pending wait(); clearing right away
            # keeps the event from staying set afterwards.
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(captured.output, expected)

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logger_name = self.proto.logger.name
        expected = [f"DEBUG:{logger_name}:--> None"]
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, expected)

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that a new accept() (or connect(), if *accept* is false)
        attempt is refused with a StateError that reports *current_state*,
        requires IDLE, and carries *error_message*.
        """
        if accept:
            attempt = self.proto.start_server_and_accept
        else:
            attempt = self.proto.connect

        with self.assertRaises(StateError) as context:
            await attempt('/not/a/real/path')

        error = context.exception
        self.assertEqual(error.error_message, error_message)
        self.assertEqual(error.state, current_state)
        self.assertEqual(error.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        expected_message = ("NullProtocol is disconnecting."
                            " Call disconnect() to return to IDLE state.")
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            expected_message,
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        expected_message = ("NullProtocol is disconnecting."
                            " Call disconnect() to return to IDLE state.")
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            expected_message,
            accept=False,
        )
568
569
class SimpleSession(TestBase):
    """
    Tests that connect the NullProtocol client (with fake_session left off)
    to a real LineProtocol server over a socket file in a temporary
    directory.
    """

    def setUp(self):
        super().setUp()
        server_name = f"{type(self).__name__}-server"
        self.server = LineProtocol(server_name)

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        # An EOFError from the server's disconnect() is tolerated,
        # presumably because the client has already hung up by now.
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect the client to a freshly started server."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, f"{type(self.proto).__name__}.sock")
            accept_coro = self.server.start_server_and_accept(sock)
            # Hold a reference to the task; presumably the event loop only
            # keeps weak references to tasks (TODO confirm for util's
            # create_task).
            server_task = create_task(accept_coro)

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)
597