import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.aqmp import ConnectError, Runstate
from qemu.aqmp.protocol import AsyncProtocol, StateError
from qemu.aqmp.util import asyncio_run, create_task


class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, connect/start_server skip real socket I/O.
        self.fake_session = False
        # Declared here, but only created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # Create the event per-session, right before the reader/writer start.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Pretend to listen: transition state and yield once.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Pretend to connect: transition state and yield once.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        pass

    async def send_msg(self) -> None:
        """Queue a Null message for the writer."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()


class LineProtocol(AsyncProtocol[str]):
    """
    Minimal newline-delimited text protocol.

    Every received line is also appended to rx_history.
    """
    def __init__(self, name=None):
        super().__init__(name)
        self.rx_history = []

    async def _do_recv(self) -> str:
        raw = await self._readline()
        msg = raw.decode()
        self.rx_history.append(msg)
        return msg

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        self._writer.write(msg.encode() + b'\n')

    async def send_msg(self, msg: str) -> None:
        """Queue a text message for the writer."""
        await self._outgoing.put(msg)


def run_as_task(coro, allow_cancellation=False):
    """
    Run a given coroutine as a task.

    Optionally, wrap it in a try..except block that allows this
    coroutine to be canceled gracefully.

    :param coro: The coroutine to schedule.
    :param allow_cancellation: When True, a CancelledError raised by
        ``coro`` is swallowed and the task finishes normally.
    :return: The created task.
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            if allow_cancellation:
                return
            raise
    return create_task(_runner())


@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the listener's address as returned by getsockname().
    Every socket created here is closed on exit, including any created
    before an error interrupted setup.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket right away so it is closed even if
        # setsockopt/bind/listen raises below.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        # NOTE(review): presumably the kernel queues backlog + 1 pending
        # connections before stalling new ones (Linux does); confirm.
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track before connect() so a failed connect doesn't leak it.
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()


class Smoke(avocado.Test):

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol runstate=IDLE>"
        )

    def testRunstate(self):
        self.assertEqual(
            self.proto.runstate,
            Runstate.IDLE
        )

    def testDefaultName(self):
        self.assertEqual(
            self.proto.name,
            None
        )

    def testLogger(self):
        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol'
        )

    def testName(self):
        self.proto = NullProtocol('Steve')

        self.assertEqual(
            self.proto.name,
            'Steve'
        )

        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol.Steve'
        )

        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )


class TestBase(avocado.Test):

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function keeps the wrapped test's name and
        docstring (via functools.wraps). It enables asyncio debug mode,
        then runs _asyncSetUp, the test, and _asyncTearDown in sequence.
        """
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            # get_event_loop() rather than get_running_loop(), presumably
            # for Python 3.6 support (cf. the util.create_task shim).
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)


class State(TestBase):

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        await self._watch_runstates(
            Runstate.DISCONNECTING,
            Runstate.IDLE,
        )
        await self.proto.disconnect()


class Connect(TestBase):
    """
    Tests primarily related to calling Connect().
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as they
        # aren't tasks. However, we can wrap them in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the unittest assertion (not a bare assert, which -O strips).
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task


class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))
        elif family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)


class FakeSession(TestBase):

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:<-- None"],
        )

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:--> None"],
        )

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting a new session from ``current_state`` fails.

        :param current_state: The runstate the StateError must report.
        :param error_message: The expected StateError message.
        :param accept: Use start_server_and_accept() when True,
            connect() otherwise.
        """
        with self.assertRaises(StateError) as context:
            if accept:
                await self.proto.start_server_and_accept('/not/a/real/path')
            else:
                await self.proto.connect('/not/a/real/path')

        self.assertEqual(context.exception.error_message, error_message)
        self.assertEqual(context.exception.state, current_state)
        self.assertEqual(context.exception.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )


class SimpleSession(TestBase):

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect a NullProtocol client to a LineProtocol server."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = create_task(self.server.start_server_and_accept(sock))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)

            # Wait for the server side to finish accepting, so the test
            # confirms it succeeded and any server-side exception is
            # surfaced instead of being left unretrieved on the task.
            await server_task