"""
Generic Asynchronous Message-based Protocol Support

This module provides a generic framework for sending and receiving
messages over an asyncio stream. `AsyncProtocol` is an abstract class
that implements the core mechanisms of a simple send/receive protocol,
and is designed to be extended.

In this package, it is used as the implementation for the `QMPClient`
class.
"""

# It's all the docstrings ... ! It's long for a good reason ^_^;
# pylint: disable=too-many-lines

import asyncio
from asyncio import StreamReader, StreamWriter
from contextlib import asynccontextmanager
from enum import Enum
from functools import wraps
import logging
import socket
from ssl import SSLContext
from typing import (
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from .error import QMPError
from .util import (
    bottom_half,
    exception_summary,
    flush,
    pretty_traceback,
    upper_half,
)


T = TypeVar('T')
_U = TypeVar('_U')
_TaskFN = Callable[[], Awaitable[None]]  # aka ``async def func() -> None``

InternetAddrT = Tuple[str, int]
UnixAddrT = str
SocketAddrT = Union[UnixAddrT, InternetAddrT]


class Runstate(Enum):
    """Protocol session runstate."""

    #: Fully quiesced and disconnected.
    IDLE = 0
    #: In the process of connecting or establishing a session.
    CONNECTING = 1
    #: Fully connected and active session.
    RUNNING = 2
    #: In the process of disconnecting.
    #: Runstate may be returned to `IDLE` by calling `disconnect()`.
    DISCONNECTING = 3


class ConnectError(QMPError):
    """
    Raised when the initial connection process has failed.

    This Exception always wraps a "root cause" exception that can be
    interrogated for additional information.

    :param error_message: Human-readable string describing the error.
    :param exc: The root-cause exception.
    """
    def __init__(self, error_message: str, exc: Exception):
        super().__init__(error_message)
        #: Human-readable error string
        self.error_message: str = error_message
        #: Wrapped root cause exception
        self.exc: Exception = exc

    def __str__(self) -> str:
        cause = str(self.exc)
        if not cause:
            # If there's no error string, use the exception name.
            cause = exception_summary(self.exc)
        return f"{self.error_message}: {cause}"


class StateError(QMPError):
    """
    An API command (connect, execute, etc) was issued at an inappropriate time.

    This error is raised when a command like
    :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
    time.

    :param error_message: Human-readable string describing the state violation.
    :param state: The actual `Runstate` seen at the time of the violation.
    :param required: The `Runstate` required to process this command.
    """
    def __init__(self, error_message: str,
                 state: Runstate, required: Runstate):
        super().__init__(error_message)
        #: Human-readable error string
        self.error_message: str = error_message
        #: The `Runstate` observed when the command was issued
        self.state: Runstate = state
        #: The `Runstate` the command required
        self.required: Runstate = required


F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name


# Don't Panic.
def require(required_state: Runstate) -> Callable[[F], F]:
    """
    Decorator: protect a method so it can only be run in a certain `Runstate`.

    The check happens at call time, before the wrapped method runs; for
    ``async def`` methods this means `StateError` is raised synchronously
    when the method is called, before any coroutine is created.

    :param required_state: The `Runstate` required to invoke this method.
    :raise StateError: When the required `Runstate` is not met.
    """
    def _decorator(func: F) -> F:
        # _decorator is the decorator that is built by calling the
        # require() decorator factory; e.g.:
        #
        # @require(Runstate.IDLE) def foo(): ...
        # will replace 'foo' with the result of '_decorator(foo)'.

        @wraps(func)
        def _wrapper(proto: 'AsyncProtocol[Any]',
                     *args: Any, **kwargs: Any) -> Any:
            # _wrapper is the function that gets executed prior to the
            # decorated method.

            name = type(proto).__name__

            if proto.runstate != required_state:
                if proto.runstate == Runstate.CONNECTING:
                    emsg = f"{name} is currently connecting."
                elif proto.runstate == Runstate.DISCONNECTING:
                    emsg = (f"{name} is disconnecting."
                            " Call disconnect() to return to IDLE state.")
                elif proto.runstate == Runstate.RUNNING:
                    emsg = f"{name} is already connected and running."
                elif proto.runstate == Runstate.IDLE:
                    emsg = f"{name} is disconnected and idle."
                else:
                    # Unreachable while Runstate has exactly these four
                    # members. Raise explicitly instead of `assert False`:
                    # asserts are stripped under `python -O`, which would
                    # fall through to an UnboundLocalError on 'emsg'.
                    raise AssertionError(
                        f"Unhandled runstate: {proto.runstate!r}")
                raise StateError(emsg, proto.runstate, required_state)
            # No StateError, so call the wrapped method.
            return func(proto, *args, **kwargs)

        # Return the decorated method;
        # Transforming Func to Decorated[Func].
        return cast(F, _wrapper)

    # Return the decorator instance from the decorator factory. Phew!
    return _decorator


class AsyncProtocol(Generic[T]):
    """
    AsyncProtocol implements a generic async message-based protocol.

    This protocol assumes the basic unit of information transfer between
    client and server is a "message", the details of which are left up
    to the implementation. It assumes the sending and receiving of these
    messages is full-duplex and not necessarily correlated; i.e. it
    supports asynchronous inbound messages.

    It is designed to be extended by a specific protocol which provides
    the implementations for how to read and send messages. These must be
    defined in `_do_recv()` and `_do_send()`, respectively.

    Other callbacks have a default implementation, but are intended to be
    either extended or overridden:

     - `_establish_session`:
         The base implementation starts the reader/writer tasks.
         A protocol implementation can override this call, inserting
         actions to be taken prior to starting the reader/writer tasks
         before the super() call; actions needing to occur afterwards
         can be written after the super() call.
     - `_on_message`:
         Actions to be performed when a message is received.
     - `_cb_outbound`:
         Logging/Filtering hook for all outbound messages.
     - `_cb_inbound`:
         Logging/Filtering hook for all inbound messages.
         This hook runs *before* `_on_message()`.

    :param name:
        Name used for logging messages, if any. By default, messages
        will log to 'qemu.qmp.protocol', but each individual connection
        can be given its own logger by giving it a name; messages will
        then log to 'qemu.qmp.protocol.${name}'.
    """
    # pylint: disable=too-many-instance-attributes

    #: Logger object for debugging messages from this connection.
    logger = logging.getLogger(__name__)

    # Maximum allowable size of read buffer
    _limit = 64 * 1024

    # -------------------------
    # Section: Public interface
    # -------------------------

    def __init__(self, name: Optional[str] = None) -> None:
        #: The nickname for this connection, if any.
        self.name: Optional[str] = name
        if self.name is not None:
            self.logger = self.logger.getChild(self.name)

        # stream I/O
        self._reader: Optional[StreamReader] = None
        self._writer: Optional[StreamWriter] = None

        # Outbound Message queue. Declared only; the queue itself is
        # created in _establish_session(), inside the running event loop.
        self._outgoing: asyncio.Queue[T]

        # Special, long-running tasks:
        self._reader_task: Optional[asyncio.Future[None]] = None
        self._writer_task: Optional[asyncio.Future[None]] = None

        # Aggregate of the above two tasks, used for Exception management.
        self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None

        #: Disconnect task. The disconnect implementation runs in a task
        #: so that asynchronous disconnects (initiated by the
        #: reader/writer) are allowed to wait for the reader/writers to
        #: exit.
        self._dc_task: Optional[asyncio.Future[None]] = None

        self._runstate = Runstate.IDLE
        self._runstate_changed: Optional[asyncio.Event] = None

        # Server state for start_server() and _incoming()
        self._server: Optional[asyncio.AbstractServer] = None
        self._accepted: Optional[asyncio.Event] = None

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        tokens = []
        if self.name is not None:
            tokens.append(f"name={self.name!r}")
        tokens.append(f"runstate={self.runstate.name}")
        return f"<{cls_name} {' '.join(tokens)}>"

    @property  # @upper_half
    def runstate(self) -> Runstate:
        """The current `Runstate` of the connection."""
        return self._runstate

    @upper_half
    async def runstate_changed(self) -> Runstate:
        """
        Wait for the `runstate` to change, then return that runstate.
        """
        await self._runstate_event.wait()
        return self.runstate

    @upper_half
    @require(Runstate.IDLE)
    async def start_server_and_accept(
            self, address: SocketAddrT,
            ssl: Optional[SSLContext] = None
    ) -> None:
        """
        Accept a connection and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
        This method is precisely equivalent to calling `start_server()`
        followed by `accept()`.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be a `QMPError`.
        """
        await self.start_server(address, ssl)
        await self.accept()
        assert self.runstate == Runstate.RUNNING

    @upper_half
    @require(Runstate.IDLE)
    async def start_server(self, address: SocketAddrT,
                           ssl: Optional[SSLContext] = None) -> None:
        """
        Start listening for an incoming connection, but do not wait for a peer.

        This method starts listening for an incoming connection, but
        does not block waiting for a peer. This call will return
        immediately after binding and listening on a socket. A later
        call to `accept()` must be made in order to finalize the
        incoming connection.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When the server could not start listening on this address.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError`.
        """
        async with self._session_guard('Failed to establish connection'):
            await self._do_start_server(address, ssl)
        assert self.runstate == Runstate.CONNECTING

    @upper_half
    @require(Runstate.CONNECTING)
    async def accept(self) -> None:
        """
        Accept an incoming connection and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.

        :raise StateError: When the `Runstate` is not `CONNECTING`.
        :raise QMPError: When `start_server()` was not called yet.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be a `QMPError`.
        """
        if self._accepted is None:
            raise QMPError("Cannot call accept() before start_server().")
        async with self._session_guard('Failed to establish connection'):
            await self._do_accept()
        async with self._session_guard('Failed to establish session'):
            await self._establish_session()
        assert self.runstate == Runstate.RUNNING

    @upper_half
    @require(Runstate.IDLE)
    async def connect(self, address: Union[SocketAddrT, socket.socket],
                      ssl: Optional[SSLContext] = None) -> None:
        """
        Connect to the server and begin processing message queues.

        If this call fails, `runstate` is guaranteed to be set back to `IDLE`.

        :param address:
            Address to connect to; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise StateError: When the `Runstate` is not `IDLE`.
        :raise ConnectError:
            When a connection or session cannot be established.

            This exception will wrap a more concrete one. In most cases,
            the wrapped exception will be `OSError` or `EOFError`. If a
            protocol-level failure occurs while establishing a new
            session, the wrapped error may also be a `QMPError`.
        """
        async with self._session_guard('Failed to establish connection'):
            await self._do_connect(address, ssl)
        async with self._session_guard('Failed to establish session'):
            await self._establish_session()
        assert self.runstate == Runstate.RUNNING

    @upper_half
    async def disconnect(self) -> None:
        """
        Disconnect and wait for all tasks to fully stop.

        If there was an exception that caused the reader/writers to
        terminate prematurely, it will be raised here.

        :raise Exception: When the reader or writer terminate unexpectedly.
        """
        self.logger.debug("disconnect() called.")
        self._schedule_disconnect()
        await self._wait_disconnect()

    # --------------------------
    # Section: Session machinery
    # --------------------------

    @asynccontextmanager
    async def _session_guard(self, emsg: str) -> AsyncGenerator[None, None]:
        """
        Async guard function used to roll back to `IDLE` on any error.

        On any Exception, the state machine will be reset back to
        `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
        `BaseException` events will be left alone (This includes
        asyncio.CancelledError, even prior to Python 3.8).

        :param emsg:
            Human-readable string describing what connection phase failed.

        :raise BaseException:
            When `BaseException` occurs in the guarded block.
        :raise ConnectError:
            When any other error is encountered in the guarded block.
        """
        try:
            # Caller's code runs here.
            yield
        except BaseException as err:
            self.logger.error("%s: %s", emsg, exception_summary(err))
            self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
            try:
                # Reset the runstate back to IDLE.
                await self.disconnect()
            except BaseException:
                # We don't expect any Exceptions from the disconnect function
                # here, because we failed to connect in the first place.
                # The disconnect() function is intended to perform
                # only cannot-fail cleanup here, but you never know.
                bug_msg = (
                    "Unexpected bottom half exception. "
                    "This is a bug in the QMP library. "
                    "Please report it to <qemu-devel@nongnu.org> and "
                    "CC: John Snow <jsnow@redhat.com>."
                )
                self.logger.critical("%s:\n%s\n", bug_msg, pretty_traceback())
                raise

            # CancelledError is an Exception with special semantic meaning;
            # We do NOT want to wrap it up under ConnectError.
            # NB: CancelledError is not a BaseException before Python 3.8
            if isinstance(err, asyncio.CancelledError):
                raise

            # Any other kind of error can be treated as some kind of connection
            # failure broadly. Inspect the 'exc' field to explore the root
            # cause in greater detail.
            if isinstance(err, Exception):
                raise ConnectError(emsg, err) from err

            # Raise BaseExceptions un-wrapped, they're more important.
            raise

    @property
    def _runstate_event(self) -> asyncio.Event:
        # asyncio.Event() objects should not be created prior to entrance into
        # an event loop, so we can ensure we create it in the correct context.
        # Create it on-demand *only* at the behest of an 'async def' method.
        if not self._runstate_changed:
            self._runstate_changed = asyncio.Event()
        return self._runstate_changed

    @upper_half
    @bottom_half
    def _set_state(self, state: Runstate) -> None:
        """
        Change the `Runstate` of the protocol connection.

        Signals the `runstate_changed` event.
        """
        if state == self._runstate:
            return

        self.logger.debug("Transitioning from '%s' to '%s'.",
                          str(self._runstate), str(state))
        self._runstate = state
        # Pulse the event: wake every current waiter, then immediately
        # re-arm it so the next wait() blocks until the next transition.
        self._runstate_event.set()
        self._runstate_event.clear()

    @bottom_half
    async def _stop_server(self) -> None:
        """
        Stop listening for / accepting new incoming connections.
        """
        if self._server is None:
            return

        try:
            self.logger.debug("Stopping server.")
            # NOTE(review): the server is closed but wait_closed() is not
            # awaited. Presumably deliberate: this is called from
            # _incoming() while the just-accepted connection is still
            # live, and on newer Pythons wait_closed() waits for all
            # connections to close -- confirm before changing.
            self._server.close()
            self.logger.debug("Server stopped.")
        finally:
            self._server = None

    @bottom_half  # However, it does not run from the R/W tasks.
    async def _incoming(self,
                        reader: asyncio.StreamReader,
                        writer: asyncio.StreamWriter) -> None:
        """
        Accept an incoming connection and signal the upper_half.

        This method does the minimum necessary to accept a single
        incoming connection. It signals back to the upper_half ASAP so
        that any errors during session initialization can occur
        naturally in the caller's stack.

        :param reader: Incoming `asyncio.StreamReader`
        :param writer: Incoming `asyncio.StreamWriter`
        """
        peer = writer.get_extra_info('peername', 'Unknown peer')
        self.logger.debug("Incoming connection from %s", peer)

        if self._reader or self._writer:
            # Sadly, we can have more than one pending connection
            # because of https://bugs.python.org/issue46715
            # Close any extra connections we don't actually want.
            self.logger.warning("Extraneous connection inadvertently accepted")
            writer.close()
            return

        # A connection has been accepted; stop listening for new ones.
        assert self._accepted is not None
        await self._stop_server()
        self._reader, self._writer = (reader, writer)
        self._accepted.set()

    @upper_half
    async def _do_start_server(self, address: SocketAddrT,
                               ssl: Optional[SSLContext] = None) -> None:
        """
        Start listening for an incoming connection, but do not wait for a peer.

        This method starts listening for an incoming connection, but does not
        block waiting for a peer. This call will return immediately after
        binding and listening to a socket. A later call to accept() must be
        made in order to finalize the incoming connection.

        :param address:
            Address to listen on; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise OSError: For stream-related errors.
        """
        assert self.runstate == Runstate.IDLE
        self._set_state(Runstate.CONNECTING)

        self.logger.debug("Awaiting connection on %s ...", address)
        self._accepted = asyncio.Event()

        if isinstance(address, tuple):
            coro = asyncio.start_server(
                self._incoming,
                host=address[0],
                port=address[1],
                ssl=ssl,
                backlog=1,
                limit=self._limit,
            )
        else:
            coro = asyncio.start_unix_server(
                self._incoming,
                path=address,
                ssl=ssl,
                backlog=1,
                limit=self._limit,
            )

        # Allow runstate watchers to witness 'CONNECTING' state; some
        # failures in the streaming layer are synchronous and will not
        # otherwise yield.
        await asyncio.sleep(0)

        # This will start the server (bind(2), listen(2)). It will also
        # call accept(2) if we yield, but we don't block on that here.
        self._server = await coro
        self.logger.debug("Server listening on %s", address)

    @upper_half
    async def _do_accept(self) -> None:
        """
        Wait for and accept an incoming connection.

        Requires that we have not yet accepted an incoming connection
        from the upper_half, but it's OK if the server is no longer
        running because the bottom_half has already accepted the
        connection.
        """
        assert self._accepted is not None
        await self._accepted.wait()
        assert self._server is None
        self._accepted = None

        self.logger.debug("Connection accepted.")

    @upper_half
    async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
                          ssl: Optional[SSLContext] = None) -> None:
        """
        Acting as the transport client, initiate a connection to a server.

        :param address:
            Address to connect to; UNIX socket path or TCP address/port.
        :param ssl: SSL context to use, if any.

        :raise OSError: For stream-related errors.
        """
        assert self.runstate == Runstate.IDLE
        self._set_state(Runstate.CONNECTING)

        # Allow runstate watchers to witness 'CONNECTING' state; some
        # failures in the streaming layer are synchronous and will not
        # otherwise yield.
        await asyncio.sleep(0)

        if isinstance(address, socket.socket):
            self.logger.debug("Connecting with existing socket: "
                              "fd=%d, family=%r, type=%r",
                              address.fileno(), address.family, address.type)
            connect = asyncio.open_connection(
                limit=self._limit,
                ssl=ssl,
                sock=address,
            )
        elif isinstance(address, tuple):
            self.logger.debug("Connecting to %s ...", address)
            connect = asyncio.open_connection(
                address[0],
                address[1],
                ssl=ssl,
                limit=self._limit,
            )
        else:
            self.logger.debug("Connecting to file://%s ...", address)
            connect = asyncio.open_unix_connection(
                path=address,
                ssl=ssl,
                limit=self._limit,
            )

        self._reader, self._writer = await connect
        self.logger.debug("Connected.")

    @upper_half
    async def _establish_session(self) -> None:
        """
        Establish a new session.

        Starts the readers/writer tasks; subclasses may perform their
        own negotiations here. The Runstate will be RUNNING upon
        successful conclusion.
        """
        assert self.runstate == Runstate.CONNECTING

        self._outgoing = asyncio.Queue()

        reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
        writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')

        self._reader_task = asyncio.create_task(reader_coro)
        self._writer_task = asyncio.create_task(writer_coro)

        self._bh_tasks = asyncio.gather(
            self._reader_task,
            self._writer_task,
        )

        self._set_state(Runstate.RUNNING)
        await asyncio.sleep(0)  # Allow runstate_event to process

    @upper_half
    @bottom_half
    def _schedule_disconnect(self) -> None:
        """
        Initiate a disconnect; idempotent.

        This method is used both in the upper-half as a direct
        consequence of `disconnect()`, and in the bottom-half in the
        case of unhandled exceptions in the reader/writer tasks.

        It can be invoked no matter what the `runstate` is.
        """
        if not self._dc_task:
            self._set_state(Runstate.DISCONNECTING)
            self.logger.debug("Scheduling disconnect.")
            self._dc_task = asyncio.create_task(self._bh_disconnect())

    @upper_half
    async def _wait_disconnect(self) -> None:
        """
        Waits for a previously scheduled disconnect to finish.

        This method will gather any bottom half exceptions and re-raise
        the one that occurred first; presuming it to be the root cause
        of any subsequent Exceptions. It is intended to be used in the
        upper half of the call chain.

        :raise Exception:
            Arbitrary exception re-raised on behalf of the reader/writer.
        """
        assert self.runstate == Runstate.DISCONNECTING
        assert self._dc_task

        aws: List[Awaitable[object]] = [self._dc_task]
        if self._bh_tasks:
            aws.insert(0, self._bh_tasks)
        all_defined_tasks = asyncio.gather(*aws)

        # Ensure disconnect is done; Exception (if any) is not raised here:
        await asyncio.wait((self._dc_task,))

        try:
            await all_defined_tasks  # Raise Exceptions from the bottom half.
        finally:
            self._cleanup()
            self._set_state(Runstate.IDLE)

    @upper_half
    def _cleanup(self) -> None:
        """
        Fully reset this object to a clean state.

        Must be called while still `DISCONNECTING`, after every task has
        finished; the caller is responsible for the subsequent
        transition to `IDLE`.
        """
        def _paranoid_task_erase(task: Optional['asyncio.Future[_U]']
                                 ) -> Optional['asyncio.Future[_U]']:
            # Help to erase a task, ENSURING it is fully quiesced first.
            assert (task is None) or task.done()
            return None if (task and task.done()) else task

        assert self.runstate == Runstate.DISCONNECTING
        self._dc_task = _paranoid_task_erase(self._dc_task)
        self._reader_task = _paranoid_task_erase(self._reader_task)
        self._writer_task = _paranoid_task_erase(self._writer_task)
        self._bh_tasks = _paranoid_task_erase(self._bh_tasks)

        self._reader = None
        self._writer = None
        self._accepted = None

        # NB: _runstate_changed cannot be cleared because we still need it to
        # send the final runstate changed event ...!

    # ----------------------------
    # Section: Bottom Half methods
    # ----------------------------

    @bottom_half
    async def _bh_disconnect(self) -> None:
        """
        Disconnect and cancel all outstanding tasks.

        It is designed to be called from its task context,
        :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
        it is free to wait on any pending actions that may still need to
        occur in either the reader or writer tasks.
        """
        assert self.runstate == Runstate.DISCONNECTING

        def _done(task: Optional['asyncio.Future[Any]']) -> bool:
            return task is not None and task.done()

        # If the server is running, stop it.
        await self._stop_server()

        # Are we already in an error pathway? If either of the tasks are
        # already done, or if we have no tasks but a reader/writer; we
        # must be.
        #
        # NB: We can't use _bh_tasks to check for premature task
        # completion, because it may not yet have had a chance to run
        # and gather itself.
        tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
        error_pathway = _done(self._reader_task) or _done(self._writer_task)
        if not tasks:
            error_pathway |= bool(self._reader) or bool(self._writer)

        try:
            # Try to flush the writer, if possible.
            # This *may* cause an error and force us over into the error path.
            if not error_pathway:
                await self._bh_flush_writer()
        except BaseException as err:
            error_pathway = True
            emsg = "Failed to flush the writer"
            self.logger.error("%s: %s", emsg, exception_summary(err))
            self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
            raise
        finally:
            # Cancel any still-running tasks (Won't raise):
            if self._writer_task is not None and not self._writer_task.done():
                self.logger.debug("Cancelling writer task.")
                self._writer_task.cancel()
            if self._reader_task is not None and not self._reader_task.done():
                self.logger.debug("Cancelling reader task.")
                self._reader_task.cancel()

            # Close out the tasks entirely (Won't raise):
            if tasks:
                self.logger.debug("Waiting for tasks to complete ...")
                await asyncio.wait(tasks)

            # Lastly, close the stream itself. (*May raise*!):
            await self._bh_close_stream(error_pathway)
            self.logger.debug("Disconnected.")

    @bottom_half
    async def _bh_flush_writer(self) -> None:
        """
        Drain the outbound queue, then flush the underlying StreamWriter.

        Waits until every queued outbound message has been processed by
        the writer task. Does nothing if the writer task was never
        started.
        """
        if not self._writer_task:
            return

        self.logger.debug("Draining the outbound queue ...")
        await self._outgoing.join()
        if self._writer is not None:
            self.logger.debug("Flushing the StreamWriter ...")
            await flush(self._writer)

    @bottom_half
    async def _bh_close_stream(self, error_pathway: bool = False) -> None:
        """
        Close the StreamWriter (if any) and wait for it to finish closing.

        :param error_pathway:
            True when an error is already known to have occurred. In that
            case an Exception raised while waiting for the stream to close
            is logged and discarded instead of being re-raised.

        :raise Exception:
            Arbitrary error from waiting on the stream to close, only
            when ``error_pathway`` is False.
        """
        # NB: Closing the writer also implicitly closes the reader.
        if not self._writer:
            return

        if not self._writer.is_closing():
            self.logger.debug("Closing StreamWriter.")
            self._writer.close()

        self.logger.debug("Waiting for StreamWriter to close ...")
        try:
            await self._writer.wait_closed()
        except Exception:  # pylint: disable=broad-except
            # It's hard to tell if the Stream is already closed or
            # not. Even if one of the tasks has failed, it may have
            # failed for a higher-layered protocol reason. The
            # stream could still be open and perfectly fine.
            # I don't know how to discern its health here.

            if error_pathway:
                # We already know that *something* went wrong. Let's
                # just trust that the Exception we already have is the
                # better one to present to the user, even if we don't
                # genuinely *know* the relationship between the two.
                self.logger.debug(
                    "Discarding Exception from wait_closed:\n%s\n",
                    pretty_traceback(),
                )
            else:
                # Oops, this is a brand-new error!
                raise
        finally:
            self.logger.debug("StreamWriter closed.")

    @bottom_half
    async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
        """
        Run one of the bottom-half methods in a loop forever.

        If the bottom half ever raises any exception, schedule a
        disconnect that will terminate the entire loop.

        :param async_fn: The bottom-half method to run in a loop.
        :param name: The name of this task, used for logging.
        """
        try:
            while True:
                await async_fn()
        except asyncio.CancelledError:
            # We have been cancelled by _bh_disconnect, exit gracefully.
            self.logger.debug("Task.%s: cancelled.", name)
            return
        except BaseException as err:
            self.logger.log(
                logging.INFO if isinstance(err, EOFError) else logging.ERROR,
                "Task.%s: %s",
                name, exception_summary(err)
            )
            self.logger.debug("Task.%s: failure:\n%s\n",
                              name, pretty_traceback())
            self._schedule_disconnect()
            raise
        finally:
            self.logger.debug("Task.%s: exiting.", name)

    @bottom_half
    async def _bh_send_message(self) -> None:
        """
        Wait for an outgoing message, then send it.

        Designed to be run in `_bh_loop_forever()`.
        """
        msg = await self._outgoing.get()
        try:
            await self._send(msg)
        finally:
            # Always mark the item processed, even on failure, so that
            # _bh_flush_writer()'s queue join() cannot hang on it.
            self._outgoing.task_done()

    @bottom_half
    async def _bh_recv_message(self) -> None:
        """
        Wait for an incoming message and call `_on_message` to route it.

        Designed to be run in `_bh_loop_forever()`.
        """
        msg = await self._recv()
        await self._on_message(msg)

    # --------------------
    # Section: Message I/O
    # --------------------

    @upper_half
    @bottom_half
    def _cb_outbound(self, msg: T) -> T:
        """
        Callback: outbound message hook.

        This is intended for subclasses to be able to add arbitrary
        hooks to filter or manipulate outgoing messages. The base
        implementation does nothing but log the message without any
        manipulation of the message.

        :param msg: raw outbound message
        :return: final outbound message
        """
        self.logger.debug("--> %s", str(msg))
        return msg

    @upper_half
    @bottom_half
    def _cb_inbound(self, msg: T) -> T:
        """
        Callback: inbound message hook.

        This is intended for subclasses to be able to add arbitrary
        hooks to filter or manipulate incoming messages. The base
        implementation does nothing but log the message without any
        manipulation of the message.

        This method does not "handle" incoming messages; it is a filter.
        The actual "endpoint" for incoming messages is `_on_message()`.

        :param msg: raw inbound message
        :return: processed inbound message
        """
        self.logger.debug("<-- %s", str(msg))
        return msg

    @upper_half
    @bottom_half
    async def _readline(self) -> bytes:
        """
        Wait for a newline from the incoming reader.

        This method is provided as a convenience for upper-layer
        protocols, as many are line-based.

        This method *may* return a sequence of bytes without a trailing
        newline if EOF occurs, but *some* bytes were received. In this
        case, the next call will raise `EOFError`. It is assumed that
        the layer 5 protocol will decide if there is anything meaningful
        to be done with a partial message.

        :raise OSError: For stream-related errors.
        :raise EOFError:
            If the reader stream is at EOF and there are no bytes to return.
        :return: bytes, including the newline.
        """
        assert self._reader is not None
        msg_bytes = await self._reader.readline()

        if not msg_bytes:
            if self._reader.at_eof():
                raise EOFError

        return msg_bytes

    @upper_half
    @bottom_half
    async def _do_recv(self) -> T:
        """
        Abstract: Read from the stream and return a message.

        Very low-level; intended to only be called by `_recv()`.
        """
        raise NotImplementedError

    @upper_half
    @bottom_half
    async def _recv(self) -> T:
        """
        Read an arbitrary protocol message.

        .. warning::
            This method is intended primarily for `_bh_recv_message()`
            to use in an asynchronous task loop. Using it outside of
            this loop will "steal" messages from the normal routing
            mechanism. It is safe to use prior to `_establish_session()`,
            but should not be used otherwise.

        This method uses `_do_recv()` to retrieve the raw message, and
        then transforms it using `_cb_inbound()`.

        :return: A single (filtered, processed) protocol message.
        """
        message = await self._do_recv()
        return self._cb_inbound(message)

    @upper_half
    @bottom_half
    def _do_send(self, msg: T) -> None:
        """
        Abstract: Write a message to the stream.

        Very low-level; intended to only be called by `_send()`.
        """
        raise NotImplementedError

    @upper_half
    @bottom_half
    async def _send(self, msg: T) -> None:
        """
        Send an arbitrary protocol message.

        This method will transform any outgoing messages according to
        `_cb_outbound()`.

        .. warning::
            Like `_recv()`, this method is intended to be called by
            the writer task loop that processes outgoing
            messages. Calling it directly may circumvent logic
            implemented by the caller meant to correlate outgoing and
            incoming messages.

        :raise OSError: For problems with the underlying stream.
        """
        msg = self._cb_outbound(msg)
        self._do_send(msg)

    @bottom_half
    async def _on_message(self, msg: T) -> None:
        """
        Called to handle the receipt of a new message.

        .. caution::
            This is executed from within the reader loop, so be advised
            that waiting on either the reader or writer task will lead
            to deadlock. Additionally, any unhandled exceptions will
            directly cause the loop to halt, so logic may be best-kept
            to a minimum if at all possible.

        :param msg: The incoming message, already logged/filtered.
        """
        # Nothing to do in the abstract case.