import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.qmp import ConnectError, Runstate
from qemu.qmp.protocol import AsyncProtocol, StateError


class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        self.fake_session = False
        # Declared here, but only created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it so the
        # next receive blocks again.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        pass

    async def send_msg(self) -> None:
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()


class LineProtocol(AsyncProtocol[str]):
    """
    A minimal line-oriented AsyncProtocol used as a real peer in tests.

    Received lines are decoded to str and recorded in ``rx_history``;
    outgoing messages are encoded and terminated with a newline.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Every message returned by _do_recv, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        raw = await self._readline()
        msg = raw.decode()
        self.rx_history.append(msg)
        return msg

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        self._writer.write(msg.encode() + b'\n')

    async def send_msg(self, msg: str) -> None:
        await self._outgoing.put(msg)


def run_as_task(coro, allow_cancellation=False):
    """
    Run a given coroutine as a task.

    Optionally, wrap it in a try..except block that allows this
    coroutine to be canceled gracefully.

    :param coro: The coroutine to schedule.
    :param allow_cancellation: When True, a CancelledError raised while
        awaiting ``coro`` is swallowed and the task finishes normally;
        otherwise it propagates.
    :return: The created ``asyncio.Task``.
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            if allow_cancellation:
                return
            raise
    return asyncio.create_task(_runner())


@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket (backlog 1) is bound to an ephemeral port on
    127.0.0.1, then two client connections are made to it and never
    accepted, so further connection attempts to that address hang.

    Yields the listening socket's address, as returned by getsockname().
    All sockets created here are closed on exit, including on error.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Register for cleanup *before* anything else can raise, so a
        # failing bind()/listen() does not leak the socket.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        # NOTE(review): presumably the kernel admits backlog + 1 pending
        # connections before stalling new ones; platform-dependent.
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()


class Smoke(avocado.Test):
    """
    Basic checks on a freshly constructed, never-connected NullProtocol.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol runstate=IDLE>"
        )

    def testRunstate(self):
        self.assertEqual(
            self.proto.runstate,
            Runstate.IDLE
        )

    def testDefaultName(self):
        self.assertEqual(
            self.proto.name,
            None
        )

    def testLogger(self):
        self.assertEqual(
            self.proto.logger.name,
            'qemu.qmp.protocol'
        )

    def testName(self):
        self.proto = NullProtocol('Steve')

        self.assertEqual(
            self.proto.name,
            'Steve'
        )

        self.assertEqual(
            self.proto.logger.name,
            'qemu.qmp.protocol.Steve'
        )

        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )


class TestBase(avocado.Test):
    """
    Shared fixtures for the async protocol tests.

    setUp()/tearDown() assert that the protocol starts and ends in the
    IDLE state. Methods decorated with async_test additionally run
    _asyncSetUp() before and _asyncTearDown() after the test body.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The wrapped coroutine enables asyncio debug mode on the running
        loop, then runs _asyncSetUp(), the test body, and
        _asyncTearDown() in sequence. The test's name and docstring are
        preserved on the wrapper.
        """
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_running_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            # Deliberately not in a try/finally: if the body raises,
            # _asyncTearDown() is skipped (it may await a runstate
            # watcher waiting on transitions that will never happen).
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = asyncio.create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)


class State(TestBase):

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        await self._watch_runstates(
            Runstate.DISCONNECTING,
            Runstate.IDLE,
        )
        await self.proto.disconnect()


class Connect(TestBase):
    """
    Tests primarily related to calling Connect().
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as
        # they aren't tasks. However, we can wrap them in a task and
        # cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        assert self.proto.runstate == Runstate.CONNECTING

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task


class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))
        elif family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)


class FakeSession(TestBase):
    """
    Lifecycle tests using NullProtocol's fake_session code path, which
    skips real socket I/O but still starts the reader/writer.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:<-- None"],
        )

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:--> None"],
        )

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting a new session is rejected with StateError.

        :param current_state: The runstate the error must report.
        :param error_message: The exact expected error message.
        :param accept: Use start_server_and_accept() when True,
            connect() otherwise.
        """
        with self.assertRaises(StateError) as context:
            if accept:
                await self.proto.start_server_and_accept('/not/a/real/path')
            else:
                await self.proto.connect('/not/a/real/path')

        self.assertEqual(context.exception.error_message, error_message)
        self.assertEqual(context.exception.state, current_state)
        self.assertEqual(context.exception.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )


class SimpleSession(TestBase):
    """
    Tests pairing the NullProtocol client with a LineProtocol server
    over a real UNIX socket in a temporary directory.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = asyncio.create_task(
                self.server.start_server_and_accept(sock))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)