xref: /openbmc/qemu/python/tests/protocol.py (revision da5006445a92bb7801f54a93452fac63ca2f634c)
1a1f71b61SJohn Snowimport asyncio
2a1f71b61SJohn Snowfrom contextlib import contextmanager
3a1f71b61SJohn Snowimport os
4a1f71b61SJohn Snowimport socket
5a1f71b61SJohn Snowfrom tempfile import TemporaryDirectory
6a1f71b61SJohn Snow
7a1f71b61SJohn Snowimport avocado
8a1f71b61SJohn Snow
9*37094b6dSJohn Snowfrom qemu.qmp import ConnectError, Runstate
10*37094b6dSJohn Snowfrom qemu.qmp.protocol import AsyncProtocol, StateError
11*37094b6dSJohn Snowfrom qemu.qmp.util import asyncio_run, create_task
12a1f71b61SJohn Snow
13a1f71b61SJohn Snow
class NullProtocol(AsyncProtocol[None]):
    """
    A do-nothing AsyncProtocol implementation used as a test double.

    Setting the ``fake_session`` instance variable to True makes the
    connection hooks skip the real connection logic while still moving
    the protocol into the CONNECTING state, which still allows the
    reader/writers to start.

    The message type is None, so nothing would stop the reader from
    yielding None incessantly. An asyncio.Event named ``trigger_input``
    gates the reader instead; poking that event simulates a single
    incoming message.

    For symmetry with _do_recv, send_msg() queues a Null message for
    the writer.

    simulate_disconnect() triggers a bottom half disconnect without
    injecting any real errors into the reader/writer loops; it does
    only the scheduling half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the connect/accept hooks bypass the real logic.
        self.fake_session = False
        # Declared only; the event is created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # A fresh input gate is created each time a session is set up.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if not self.fake_session:
            await super()._do_start_server(address, ssl)
            return

        # Fake path: mimic the visible effects of starting a server.
        self._accepted = asyncio.Event()
        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_accept(self):
        if not self.fake_session:
            await super()._do_accept()
            return

        # Fake path: nothing to wait for, just drop the event.
        self._accepted = None

    async def _do_connect(self, address, ssl=None):
        if not self.fake_session:
            await super()._do_connect(address, ssl)
            return

        # Fake path: enter CONNECTING and yield to the loop once.
        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm the gate.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Discard the message; there is nobody to send it to."""

    async def send_msg(self) -> None:
        """Queue a Null message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulate a bottom-half disconnect.

        A disconnection is scheduled, but this method does not wait for
        it to finish. The point is to leave the loop in the
        DISCONNECTING state without quiescing it all the way back to
        IDLE. AsyncProtocol normally cannot be coaxed into doing this
        on purpose, but it resembles what happens after an unhandled
        Exception in the reader/writer.

        In normal use the library design requires callers to await
        disconnect(), which awaits the disconnect task and returns
        bottom half errors before the loop is allowed back to IDLE.
        """
        self._schedule_disconnect()
92a1f71b61SJohn Snow
93a1f71b61SJohn Snow
class LineProtocol(AsyncProtocol[str]):
    """
    A small text-based AsyncProtocol used by the session tests.

    Outgoing messages are encoded and terminated with a newline;
    incoming data is read with _readline(), decoded, and recorded in
    ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Every decoded message received so far, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        """Read one line, decode it, log it to rx_history, return it."""
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        """Write the encoded message followed by a newline."""
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Queue a message on the outgoing queue."""
        await self._outgoing.put(msg)
1118193b9d1SJohn Snow
1128193b9d1SJohn Snow
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule a coroutine as a task and return that task.

    When allow_cancellation is True, cancelling the task ends it
    quietly instead of propagating asyncio.CancelledError to whoever
    awaits it.
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            # Only swallow the cancellation when explicitly permitted.
            if not allow_cancellation:
                raise

    return create_task(_runner())
128a1f71b61SJohn Snow
129a1f71b61SJohn Snow
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to 127.0.0.1 on an OS-chosen port and
    two client connections are opened against it without ever being
    accepted. The listener's address, as returned by getsockname(), is
    yielded to the caller.

    Every socket created here is closed when the context exits, even
    if setup fails part-way through.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Register for cleanup *before* any call that may raise, so a
        # failed setsockopt/bind/listen cannot leak the descriptor.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        # NOTE(review): presumably the kernel admits backlog + 1 pending
        # connections for listen(1) -- confirm for the target platform.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Same reasoning: track it first, connect() may raise.
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
158a1f71b61SJohn Snow
159a1f71b61SJohn Snow
class Smoke(avocado.Test):
    """
    Synchronous sanity checks on a freshly constructed NullProtocol.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        """An unnamed protocol's repr shows only the runstate."""
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        """A new protocol starts out IDLE."""
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        """The name is None when none was given."""
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        """An unnamed protocol logs under the bare module logger name."""
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        """A name shows up in .name, the logger name, and the repr."""
        self.proto = NullProtocol('Steve')
        named = self.proto

        self.assertEqual(named.name, 'Steve')
        self.assertEqual(named.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(named),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )
205a1f71b61SJohn Snow        )
206a1f71b61SJohn Snow
207a1f71b61SJohn Snow
class TestBase(avocado.Test):
    """
    Shared fixture for the asynchronous protocol tests.

    Each test gets a NullProtocol named after the test class, and the
    protocol is required to be IDLE both before and after the test.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(), if a test starts one.
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Hook for subclasses; runs inside the loop before the test."""

    async def _asyncTearDown(self):
        """Wait for the runstate watcher, if one was started."""
        if not self.runstate_watcher:
            return
        await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; wraps an async test with _asyncSetUp/_asyncTearDown.
        """
        async def _wrapper(self, *args, **kwargs):
            asyncio.get_event_loop().set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        Start a background task that checks each runstate change.

        The task awaits one runstate change per entry in ``states`` and
        asserts that the changes arrive in exactly the given order. It
        is stored in ``self.runstate_watcher``.
        """
        async def _watcher():
            for expected in states:
                observed = await self.proto.runstate_changed()
                self.assertEqual(
                    observed,
                    expected,
                    msg=f"Expected state '{expected.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
276a1f71b61SJohn Snow        await asyncio.sleep(0)
277a1f71b61SJohn Snow
278a1f71b61SJohn Snow
class State(TestBase):
    """
    Tests of runstate handling that need no connection at all.
    """

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
290a1f71b61SJohn Snow        await self.proto.disconnect()
291a1f71b61SJohn Snow
292a1f71b61SJohn Snow
class Connect(TestBase):
    """
    Tests primarily related to calling Connect().

    Subclasses may override _bad_connection() and _hanging_connection()
    to run the same tests against a different connection interface.
    """
    async def _bad_connection(self, family: str):
        """Attempt a connection to a target chosen by ``family``."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Attempt a connection to a jammed socket."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a bad connection raises ConnectError wrapping an
        OSError, with the expected error message and state sequence.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that accept() cannot be cancelled outright, as it isn't a task.
        # However, we can wrap it in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test framework's assertion rather than a bare assert:
        # it is not stripped under -O and matches the rest of the file.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
418a1f71b61SJohn Snow        await task
419a1f71b61SJohn Snow
420a1f71b61SJohn Snow
class Accept(Connect):
    """
    Re-runs every Connect test through start_server_and_accept().
    """
    async def _bad_connection(self, family: str):
        """Try to start a server on an address chosen by ``family``."""
        assert family in ('INET', 'UNIX')

        if family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')
        elif family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))

    async def _hanging_connection(self):
        """Listen on a socket path that no client will ever connect to."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = type(self.proto).__name__ + ".sock"
            sock = os.path.join(tmpdir, sock_name)
            await self.proto.start_server_and_accept(sock)
437a1f71b61SJohn Snow
438a1f71b61SJohn Snow
class FakeSession(TestBase):
    """
    Tests that run NullProtocol with ``fake_session`` enabled.

    Every test is expected to pass through the full "good" sequence of
    runstates; the async teardown calls disconnect() to finish it.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            gate = self.proto.trigger_input
            gate.set()
            gate.clear()
            await asyncio.sleep(0)  # Kick reader.

        expected = [f"DEBUG:{logname}:<-- None"]
        self.assertEqual(context.output, expected)

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        expected = [f"DEBUG:{logname}:--> None"]
        self.assertEqual(context.output, expected)

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Call accept or connect and require that it raises StateError.

        The raised error must carry the given message and current
        state, and must name IDLE as the required state.
        """
        if accept:
            attempt = self.proto.start_server_and_accept
        else:
            attempt = self.proto.connect

        with self.assertRaises(StateError) as context:
            await attempt('/not/a/real/path')

        raised = context.exception
        self.assertEqual(raised.error_message, error_message)
        self.assertEqual(raised.state, current_state)
        self.assertEqual(raised.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        message = ("NullProtocol is disconnecting."
                   " Call disconnect() to return to IDLE state.")
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            message,
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        message = ("NullProtocol is disconnecting."
                   " Call disconnect() to return to IDLE state.")
        await self._prod_session_api(
            Runstate.DISCONNECTING,
            message,
            accept=False,
        )
567a1f71b61SJohn Snow        )
5688193b9d1SJohn Snow
5698193b9d1SJohn Snow
class SimpleSession(TestBase):
    """
    Tests that pair the NullProtocol client with a LineProtocol server.
    """

    def setUp(self):
        super().setUp()
        server_name = type(self).__name__ + '-server'
        self.server = LineProtocol(server_name)

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # An EOFError from the server's disconnect is tolerated here.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect the client to a server listening on a temporary path."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = type(self.proto).__name__ + ".sock"
            sock = os.path.join(tmpdir, sock_name)
            # The task handle is held in a local for the rest of the test.
            server_task = create_task(
                self.server.start_server_and_accept(sock)
            )

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)
597