import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.aqmp import ConnectError, Runstate
from qemu.aqmp.protocol import AsyncProtocol, StateError
from qemu.aqmp.util import asyncio_run, create_task


class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnection" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the _do_* hooks below skip the parent implementation.
        self.fake_session = False
        # Declared here; the Event itself is created in _establish_session.
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        """Create the 'trigger_input' event, then defer to the parent."""
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        """Fake the server start when fake_session is set."""
        if self.fake_session:
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        """Fake the accept step when fake_session is set."""
        if self.fake_session:
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        """Fake the outgoing connection when fake_session is set."""
        if self.fake_session:
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        """Block until 'trigger_input' is set, then re-arm it."""
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Discard the message; there is nothing to write."""

    async def send_msg(self) -> None:
        """Queue a single Null (None) message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()


class LineProtocol(AsyncProtocol[str]):
    """
    A newline-delimited text protocol used as a test peer.

    Every received message is also recorded in 'rx_history'.
    """
    def __init__(self, name=None):
        super().__init__(name)
        self.rx_history = []

    async def _do_recv(self) -> str:
        """Read one line, decode it, record it, and return it."""
        raw = await self._readline()
        msg = raw.decode()
        self.rx_history.append(msg)
        return msg

    def _do_send(self, msg: str) -> None:
        """Write the encoded message followed by a newline."""
        assert self._writer is not None
        self._writer.write(msg.encode() + b'\n')

    async def send_msg(self, msg: str) -> None:
        """Queue a message on the outgoing queue."""
        await self._outgoing.put(msg)


def run_as_task(coro, allow_cancellation=False):
    """
    Run a given coroutine as a task.

    Optionally, wrap it in a try..except block that allows this
    coroutine to be canceled gracefully.
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            if allow_cancellation:
                return
            raise
    return create_task(_runner())


@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the (host, port) address of the listening socket. Every
    socket created here is closed on exit, including when setting up
    the listener or one of the connections fails part-way through.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Register for cleanup *before* any call that may raise, so a
        # failing bind()/listen() does not leak the socket.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Same as above: track first, then connect.
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()


class Smoke(avocado.Test):
    """
    Synchronous checks on a freshly constructed NullProtocol.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        """Test repr() of an unnamed, idle protocol."""
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol runstate=IDLE>"
        )

    def testRunstate(self):
        """Test that a new protocol starts out IDLE."""
        self.assertEqual(
            self.proto.runstate,
            Runstate.IDLE
        )

    def testDefaultName(self):
        """Test that the name defaults to None."""
        self.assertEqual(
            self.proto.name,
            None
        )

    def testLogger(self):
        """Test the logger name of an unnamed protocol."""
        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol'
        )

    def testName(self):
        """Test that a given name shows up in name, logger and repr()."""
        self.proto = NullProtocol('Steve')

        self.assertEqual(
            self.proto.name,
            'Steve'
        )

        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol.Steve'
        )

        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )


class TestBase(avocado.Test):
    """
    Shared fixture for the asynchronous tests below.

    Creates a NullProtocol named after the concrete test class and
    asserts that it is IDLE both before and after every test.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Set by _watch_runstates(); awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Hook run inside the event loop before each async test."""

    async def _asyncTearDown(self):
        """Wait for the runstate watcher, if one was started."""
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.
        """
        # Preserve the test's own __name__/__doc__ on the wrapper.
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)


class State(TestBase):
    """
    Tests of runstate transitions that need no connection at all.
    """

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        await self._watch_runstates(
            Runstate.DISCONNECTING,
            Runstate.IDLE,
        )
        await self.proto.disconnect()


class Connect(TestBase):
    """
    Tests primarily related to calling Connect().
    """
    async def _bad_connection(self, family: str):
        """Attempt a connection that is expected to fail."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Attempt a connection to a jammed socket."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """Check the states and the error raised by a bad connection."""
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that accept() cannot be cancelled outright, as it isn't a task.
        # However, we can wrap it in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the unittest assertion (not a bare assert) so the check
        # still runs under "python -O", like every other check here.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task


class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """Attempt to start a server on an address expected to fail."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))
        elif family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        """Start a server on a fresh socket path that nobody connects to."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)


class FakeSession(TestBase):
    """
    Tests that run with NullProtocol's fake_session flag enabled.

    Every test is expected to pass through GOOD_CONNECTION_STATES; the
    async teardown disconnects the protocol to finish that sequence.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:<-- None"],
        )

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:--> None"],
        )

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Call accept/connect and check the StateError that it raises.

        :param current_state: The state the StateError should report.
        :param error_message: The expected error message.
        :param accept: Call start_server_and_accept() when True,
                       connect() when False.
        """
        with self.assertRaises(StateError) as context:
            if accept:
                await self.proto.start_server_and_accept('/not/a/real/path')
            else:
                await self.proto.connect('/not/a/real/path')

        self.assertEqual(context.exception.error_message, error_message)
        self.assertEqual(context.exception.state, current_state)
        self.assertEqual(context.exception.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )


class SimpleSession(TestBase):
    """
    Tests that pair self.proto with a LineProtocol server instance.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # NOTE(review): deliberately ignored; presumably raised
            # because the client side was disconnected first -- confirm.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect self.proto to the server over a temporary socket path."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = create_task(
                self.server.start_server_and_accept(sock)
            )

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)