"""
Tests for qemu.aqmp's AsyncProtocol, run under the avocado test framework.
"""

import asyncio
from contextlib import contextmanager
import functools
import os
import socket
from tempfile import TemporaryDirectory

import avocado

from qemu.aqmp import ConnectError, Runstate
from qemu.aqmp.protocol import AsyncProtocol, StateError
from qemu.aqmp.util import asyncio_run, create_task


class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        self.fake_session = False
        # Only declared here; a fresh Event is created each time
        # _establish_session() runs.
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_accept(self, address, ssl=None):
        # With fake_session set, skip the real accept entirely.
        if not self.fake_session:
            await super()._do_accept(address, ssl)

    async def _do_connect(self, address, ssl=None):
        # With fake_session set, skip the real connect entirely.
        if not self.fake_session:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it so each
        # poke yields exactly one (None) message.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        pass

    async def send_msg(self) -> None:
        """Queue a None message for the writer."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()


class LineProtocol(AsyncProtocol[str]):
    """
    A minimal text protocol: each message is one str.

    Outgoing messages are UTF-8 encoded and terminated with b'\\n'.
    Incoming messages are whatever _readline() returns, decoded to str,
    and every received message is also appended to rx_history.
    """
    def __init__(self, name=None):
        super().__init__(name)
        self.rx_history = []

    async def _do_recv(self) -> str:
        raw = await self._readline()
        msg = raw.decode()
        self.rx_history.append(msg)
        return msg

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        self._writer.write(msg.encode() + b'\n')

    async def send_msg(self, msg: str) -> None:
        """Queue msg for the writer."""
        await self._outgoing.put(msg)


def run_as_task(coro, allow_cancellation=False):
    """
    Run a given coroutine as a task.

    Optionally, wrap it in a try..except block that allows this
    coroutine to be canceled gracefully.

    :param coro: The coroutine to run.
    :param allow_cancellation: When True, a CancelledError raised by
        ``coro`` is swallowed and the task simply finishes; otherwise
        it propagates.
    :return: The task created via create_task().
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            if allow_cancellation:
                return
            raise
    return create_task(_runner())


@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the address (as returned by getsockname()) of the listening
    socket. Every socket created here is closed on exit, including any
    created before a failure part-way through setup.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket before anything else can fail, so the
        # finally clause below always closes it.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()


class Smoke(avocado.Test):
    """
    Synchronous checks on a freshly constructed, never-connected
    NullProtocol.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol runstate=IDLE>"
        )

    def testRunstate(self):
        self.assertEqual(
            self.proto.runstate,
            Runstate.IDLE
        )

    def testDefaultName(self):
        self.assertIsNone(self.proto.name)

    def testLogger(self):
        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol'
        )

    def testName(self):
        self.proto = NullProtocol('Steve')

        self.assertEqual(
            self.proto.name,
            'Steve'
        )

        self.assertEqual(
            self.proto.logger.name,
            'qemu.aqmp.protocol.Steve'
        )

        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )


class TestBase(avocado.Test):
    """
    Base class for the async tests below.

    Each test gets a fresh NullProtocol named after the test class; the
    protocol must be IDLE both before and after every test.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Wait for the runstate watcher (if any) so its assertions about
        # the observed state sequence are actually checked.
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The wrapper enables asyncio debug mode on the current event loop,
        then runs _asyncSetUp(), the test itself, and _asyncTearDown().
        functools.wraps keeps the test's name and docstring on the
        wrapper.
        """
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)


class State(TestBase):
    """
    Runstate tests that do not involve any connection attempt.
    """

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        await self._watch_runstates(
            Runstate.DISCONNECTING,
            Runstate.IDLE,
        )
        await self.proto.disconnect()


class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as
        # they aren't tasks. However, we can wrap them in a task and
        # cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test-framework assertion (not a bare assert, which is
        # stripped under -O).
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task


class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.accept(('example.com', 1))
        elif family == 'UNIX':
            await self.proto.accept('/dev/null')

    async def _hanging_connection(self):
        # Listen on a fresh socket path that nobody ever connects to.
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.accept(sock)


class FakeSession(TestBase):
    """
    Tests run with NullProtocol.fake_session enabled, so connect() and
    accept() skip the real connection logic. Every test is expected to
    walk through GOOD_CONNECTION_STATES, ending with the disconnect()
    performed in _asyncTearDown().
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:<-- None"],
        )

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as context:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(
            context.output,
            [f"DEBUG:{logname}:--> None"],
        )

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that accept() (or connect(), if accept is False) raises a
        StateError carrying error_message, current_state, and a required
        state of IDLE.
        """
        with self.assertRaises(StateError) as context:
            if accept:
                await self.proto.accept('/not/a/real/path')
            else:
                await self.proto.connect('/not/a/real/path')

        self.assertEqual(context.exception.error_message, error_message)
        self.assertEqual(context.exception.state, current_state)
        self.assertEqual(context.exception.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )


class SimpleSession(TestBase):
    """
    Tests with a real session: self.proto (a NullProtocol) connects to a
    LineProtocol server listening on a socket path in a temporary
    directory.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            # NOTE(review): server_task is never awaited here; presumably
            # the server side is quiesced by self.server.disconnect() in
            # _asyncTearDown() -- confirm.
            server_task = create_task(self.server.accept(sock))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)