Source code for libmuscle.mcp.tcp_transport_client

import logging
import select
import socket
import time
from typing import Optional

from typing_extensions import Buffer

from libmuscle.mcp.tcp_util import (
    is_disconnect,
    recv_frame,
    recv_int64,
    send_frame,
    send_int64,
)
from libmuscle.mcp.transport_client import ProfileData, TimeoutHandler, TransportClient
from libmuscle.profiling import ProfileTimestamp
from libmuscle.util import Retrier

_logger = logging.getLogger(__name__)


# Socket timeout applied while connecting to a single resolved address in
# TcpTransportClient._connect(); it is cleared again once connected.
_CONNECT_TIMEOUT = 3.0                      # seconds
# Passed to Retrier in TcpTransportClient.call(); presumably the total time
# spent retrying a lost connection before giving up — confirm in Retrier.
RECONNECT_TIMEOUT = 60.0                   # seconds


class NoPendingResponse(RuntimeError):
    """Error signalling that there is no pending response.

    NOTE(review): this exception is not raised anywhere in this module;
    presumably it is used elsewhere in libmuscle.mcp — confirm.
    """


class TcpTransportClient(TransportClient):
    """A client that connects to a TCPTransport server."""

    @staticmethod
    def can_connect_to(location: str) -> bool:
        """Whether this client class can connect to the given location.

        Args:
            location: The location to potentially connect to.

        Returns:
            True iff this class can connect to this location.
        """
        return location.startswith('tcp:')

    def __init__(self, location: str) -> None:
        """Create a TcpClient for a given location.

        The client will connect to this location and be able to send
        requests to it and return the response.

        Args:
            location: A location string for the peer, of the form
                ``tcp:<host>:<port>[,<host>:<port>...]``. IPv6 hosts may be
                enclosed in square brackets.
        """
        self._addresses = location[4:].split(',')
        self._socket: Optional[socket.SocketType] = None
        # Set by _make_connection(); None means select() is used instead.
        # Initialized here so the attribute exists even if the initial
        # connection attempt fails.
        self._poll_obj: Optional[select.poll] = None
        self._session = 0
        self._cur_request = 0
        self._reconnect(False)

    def call(self, request: Buffer, timeout_handler: Optional[TimeoutHandler] = None
             ) -> tuple[Buffer, ProfileData]:
        """Send a request to the server and receive the response.

        This is a blocking call. If the connection is lost, the client
        reconnects and resends the request, until the retrier gives up.

        Args:
            request: The request to send
            timeout_handler: Optional timeout handler. This is used for
                communication deadlock detection.

        Returns:
            A tuple of the received response and profiling timestamps
            (start of waiting, start of transfer, end of transfer).
        """
        self._cur_request += 1
        retrier = Retrier(RECONNECT_TIMEOUT)
        deadline: Optional[float] = None
        did_timeout = False

        def handle_timeout() -> None:
            # Notify the handler and push the deadline one timeout further.
            nonlocal deadline
            nonlocal did_timeout
            assert timeout_handler is not None  # mypy
            assert deadline is not None  # mypy
            timeout_handler.on_timeout()
            deadline += timeout_handler.timeout
            did_timeout = True

        while True:
            try:
                if deadline is not None and deadline < time.monotonic():
                    handle_timeout()
                if self._socket is None:
                    raise ConnectionError('No connection could be established')

                start_wait = ProfileTimestamp()
                send_int64(self._socket, self._cur_request)
                send_frame(self._socket, request)

                if timeout_handler is not None:
                    if deadline is None:
                        deadline = time.monotonic() + timeout_handler.timeout
                    while not self._poll(deadline - time.monotonic()):
                        handle_timeout()
                    if did_timeout:
                        timeout_handler.on_receive()
                        did_timeout = False

                start_transfer = ProfileTimestamp()
                response = recv_frame(self._socket)
                stop_transfer = ProfileTimestamp()
                return response, (start_wait, start_transfer, stop_transfer)

            except Exception as e:
                if is_disconnect(e):
                    self._handle_disconnect(retrier)
                else:
                    raise

    def close(self) -> None:
        """Closes this client.

        This closes any connections this client has and performs other
        shutdown activities as needed.
        """
        self._end_session()
        self._close_connection()

    def _poll(self, timeout: float) -> bool:
        """Poll the socket and return whether it's ready for receiving.

        This method blocks until the socket is ready for receiving, or
        ``timeout`` seconds have passed (whichever is earlier).

        A negative timeout (an already-expired deadline) is treated as zero.
        Without this, poll() would block indefinitely and select() would
        raise ValueError.

        Args:
            timeout: timeout in seconds

        Returns:
            True if the socket is ready for receiving data, False otherwise.
        """
        timeout = max(timeout, 0.0)
        if self._poll_obj is not None:
            # poll timeout is in ms
            ready_events = self._poll_obj.poll(timeout * 1000)
            return bool(ready_events)

        # Fallback to select() on platforms without select.poll
        ready_sockets, _, _ = select.select([self._socket], [], [], timeout)
        return bool(ready_sockets)

    def _handle_disconnect(self, retrier: Retrier) -> None:
        """Handles a broken network connection.

        Must be called from inside an ``except`` block: if the retrier gives
        up, the exception currently being handled is re-raised.

        Args:
            retrier: A Retrier that keeps track of timing any retries
        """
        _logger.warning(
                f'The TCP network connection with {self._addresses} was lost'
                ' unexpectedly.')
        try:
            self._close_connection()
        except Exception as e:
            if not is_disconnect(e):
                raise

        if retrier.should_give_up():
            _logger.warning(
                    f'I am unable to reconnect to {self._addresses} despite repeated'
                    ' attempts, and I am giving up. Please check your network.')
            # Bare raise: re-raises the caller's (call()'s) handled exception.
            raise

        retrier.sleep()
        _logger.warning(f'Trying to reconnect to {self._addresses}')
        self._reconnect()

    def _reconnect(self, re: bool = True) -> None:
        """(Re)connect to the server and resume the current session.

        Sends the current session id (0 before the first connection) and
        stores the session id that the server sends back. A disconnect during
        this process is logged and leaves the client without a socket.

        Args:
            re: True if this is a reconnect rather than an initial connect.
        """
        try:
            self._make_connection()
            assert self._socket is not None
            send_int64(self._socket, self._session)
            self._session = recv_int64(self._socket)
            if re:
                _logger.warning(
                        f'Reconnected to {self._addresses}, continuing the'
                        ' simulation')
        except Exception as e:
            if is_disconnect(e):
                self._close_connection()
                _logger.warning(
                        f'Failed to reconnect to {self._addresses}, will retry'
                        ' later')
            else:
                raise

    def _make_connection(self) -> None:
        """Connect to the server and set up polling.

        Tries each of self._addresses in order and uses the first that
        connects, creating a (new) self._socket and self._poll_obj.

        Raises:
            ConnectionRefusedError: If none of the addresses could be
                connected to.
        """
        sock: Optional[socket.SocketType] = None
        for address in self._addresses:
            try:
                sock = self._connect(address)
                break
            except RuntimeError:
                pass

        if sock is None:
            raise ConnectionRefusedError('Failed to connect')

        if hasattr(socket, 'TCP_NODELAY'):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        if hasattr(socket, 'TCP_QUICKACK'):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
        self._socket = sock

        if hasattr(select, 'poll'):
            self._poll_obj = select.poll()
            self._poll_obj.register(self._socket, select.POLLIN)
        else:
            self._poll_obj = None   # On platforms that don't support select.poll

    def _connect(self, address: str) -> socket.SocketType:
        """Connect to a single ``host:port`` address.

        The host may be an IPv6 address enclosed in square brackets. Every
        result of name resolution is tried in turn until one connects.

        Args:
            address: The address to connect to.

        Returns:
            A connected socket in blocking mode.

        Raises:
            RuntimeError: If the address is malformed, cannot be resolved,
                or none of its resolved addresses accepts a connection.
        """
        host, sep, port_str = address.rpartition(':')
        if not sep:
            raise RuntimeError('Invalid address')

        if host.startswith('['):
            if not host.endswith(']'):
                raise RuntimeError('Invalid address')
            host = host[1:-1]

        try:
            port = int(port_str)
        except ValueError:
            raise RuntimeError('Invalid address') from None

        try:
            addrinfo = socket.getaddrinfo(
                    host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        except OSError as e:
            # Includes socket.gaierror; let the caller try the next address
            _logger.debug(f'Failed to resolve {host}: {e}')
            raise RuntimeError('Could not connect') from e

        for family, socktype, proto, _, sockaddr in addrinfo:
            try:
                sock = socket.socket(family, socktype, proto)
            except OSError as e:
                # e.g. address family not supported on this host
                _logger.debug(f'Failed to create socket: {e}')
                continue

            try:
                sock.settimeout(_CONNECT_TIMEOUT)
                sock.connect(sockaddr)
                sock.settimeout(None)
                return sock
            except (ConnectionRefusedError, ConnectionAbortedError):
                _logger.info(f'Failed to connect to {sockaddr}')
                sock.close()
            except Exception as e:
                _logger.debug(f'Failed to connect socket: {e}')
                sock.close()

        raise RuntimeError('Could not connect')

    def _end_session(self) -> None:
        """Signal the end of the session to the server, if connected.

        Sends a 0 where a request id would go (request ids sent by call()
        start at 1), presumably recognized by the server as end-of-session.
        """
        try:
            if self._socket is not None:
                send_int64(self._socket, 0)
        except Exception as e:
            # This can raise if the peer has shut down already when we close our
            # connection to it, which is fine and can be ignored. Otherwise, we reraise.
            if not is_disconnect(e):
                raise
            _logger.warning(
                    'Disconnected while trying to end session, shutdown will take'
                    ' longer than usual because of this.')

    def _close_connection(self) -> None:
        """Shut down and close the current socket, if any.

        Afterwards self._socket is None, so call() reports a missing
        connection instead of writing to a closed socket. The socket is
        closed even if shutdown() fails.

        Raises:
            Exception: Any error from shutdown() or close() that is not a
                disconnect according to is_disconnect().
        """
        sock = self._socket
        if sock is None:
            return
        self._socket = None

        try:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                sock.close()
        except Exception as e:
            # This can raise if the peer has shut down already when we close our
            # connection to it, which is fine and can be ignored. Otherwise, we
            # reraise.
            if not is_disconnect(e):
                raise