import select
import logging
import socket
import time
from typing import Optional, Tuple
from typing_extensions import Buffer
from libmuscle.mcp.transport_client import ProfileData, TransportClient, TimeoutHandler
from libmuscle.mcp.tcp_util import (
is_disconnect, recv_frame, recv_int64, send_frame, send_int64)
from libmuscle.profiling import ProfileTimestamp
from libmuscle.util import Retrier
_logger = logging.getLogger(__name__)

# Socket timeout (seconds) applied to each individual connect() attempt.
_CONNECT_TIMEOUT: float = 3.0
# Time budget (seconds) given to the Retrier that governs reconnect attempts
# after a connection is lost during a call.
RECONNECT_TIMEOUT: float = 60.0
class NoPendingResponse(RuntimeError):
    """Error indicating that there is no pending response.

    NOTE(review): nothing in this module raises it; it appears to be kept for
    use elsewhere in the package — confirm before removing.
    """
class TcpTransportClient(TransportClient):
    """A client that connects to a TCPTransport server."""

    @staticmethod
    def can_connect_to(location: str) -> bool:
        """Whether this client class can connect to the given location.

        Args:
            location: The location to potentially connect to.

        Returns:
            True iff this class can connect to this location.
        """
        return location.startswith('tcp:')

    def __init__(self, location: str) -> None:
        """Create a TcpClient for a given location.

        The client will connect to this location and be able to send requests
        to it and return the response.

        Args:
            location: A location string for the peer: ``tcp:`` followed by a
                comma-separated list of ``host:port`` addresses, with IPv6
                hosts enclosed in square brackets.
        """
        self._addresses = location[4:].split(',')
        self._socket: Optional[socket.SocketType] = None
        # Poll object watching self._socket. None when there is no connection
        # or when select.poll is unavailable (then _poll() uses select()).
        self._poll_obj: Optional[select.poll] = None
        self._session = 0
        self._cur_request = 0
        self._reconnect(False)

    def call(self, request: Buffer, timeout_handler: Optional[TimeoutHandler] = None
             ) -> Tuple[Buffer, ProfileData]:
        """Send a request to the server and receive the response.

        This is a blocking call. If the connection turns out to be broken, the
        client reconnects and sends the request again, for as long as the
        Retrier created with RECONNECT_TIMEOUT allows.

        Args:
            request: The request to send
            timeout_handler: Optional timeout handler. This is used for
                communication deadlock detection. Its on_timeout() is called
                each time timeout_handler.timeout seconds pass without a
                response, and on_receive() once a response arrives after such
                a timeout.

        Returns:
            The received response, and profiling timestamps taken at the start
            of the wait, the start of the transfer and the end of the transfer.
        """
        self._cur_request += 1
        retrier = Retrier(RECONNECT_TIMEOUT)
        deadline: Optional[float] = None
        did_timeout = False

        def handle_timeout() -> None:
            # Report the timeout and move the deadline one period further.
            nonlocal deadline
            nonlocal did_timeout
            assert timeout_handler is not None  # mypy
            assert deadline is not None  # mypy
            timeout_handler.on_timeout()
            deadline += timeout_handler.timeout
            did_timeout = True

        while True:
            try:
                if deadline is not None and deadline < time.monotonic():
                    handle_timeout()

                if self._socket is None:
                    raise ConnectionError('No connection could be established')

                start_wait = ProfileTimestamp()
                send_int64(self._socket, self._cur_request)
                send_frame(self._socket, request)

                if timeout_handler is not None:
                    if deadline is None:
                        deadline = time.monotonic() + timeout_handler.timeout
                    # The remaining time may be negative here; _poll() clamps it.
                    while not self._poll(deadline - time.monotonic()):
                        handle_timeout()
                    if did_timeout:
                        timeout_handler.on_receive()
                        did_timeout = False

                start_transfer = ProfileTimestamp()
                response = recv_frame(self._socket)
                stop_transfer = ProfileTimestamp()
                return response, (start_wait, start_transfer, stop_transfer)

            except Exception as e:
                if is_disconnect(e):
                    self._handle_disconnect(retrier)
                else:
                    raise

    def close(self) -> None:
        """Closes this client.

        This closes any connections this client has and performs other
        shutdown activities as needed.
        """
        self._end_session()
        self._close_connection()

    def _poll(self, timeout: float) -> bool:
        """Poll the socket and return whether it is ready for receiving.

        This method blocks until the socket is ready for receiving, or
        ``timeout`` seconds have passed (whichever is earlier). A zero or
        negative timeout means: check once and return immediately.

        Args:
            timeout: timeout in seconds

        Returns:
            True if the socket is ready for receiving data, False otherwise.
        """
        # A negative timeout would make poll() block indefinitely (defeating
        # deadlock detection) and makes select() raise ValueError.
        timeout = max(timeout, 0.0)
        if self._poll_obj is not None:
            ready_events = self._poll_obj.poll(timeout * 1000)  # poll timeout is in ms
            return bool(ready_events)

        # Fallback to select() on platforms without select.poll
        ready_sockets, _, _ = select.select([self._socket], (), (), timeout)
        return bool(ready_sockets)

    def _handle_disconnect(self, retrier: Retrier) -> None:
        """Handles a broken network connection.

        Closes the broken connection, waits as directed by the retrier and
        then tries to reconnect. Must be called from within an ``except``
        block: if the retrier says to give up, the exception being handled
        there is re-raised.

        Args:
            retrier: A Retrier that keeps track of timing any retries
        """
        _logger.warning(
                f'The TCP network connection with {self._addresses} was lost'
                ' unexpectedly.')
        try:
            self._close_connection()
        except Exception as e:
            if not is_disconnect(e):
                raise

        if retrier.should_give_up():
            _logger.warning(
                    f'I am unable to reconnect to {self._addresses} despite repeated'
                    ' attempts, and I am giving up. Please check your network.')
            # Bare raise: re-raises the exception the caller is handling.
            raise

        retrier.sleep()
        _logger.warning(f'Trying to reconnect to {self._addresses}')
        self._reconnect()

    def _reconnect(self, re: bool = True) -> None:
        """(Re)connect to the server and resume the current session.

        Sends the current session id (0 before the first successful
        connection) and stores the session id the server sends back.
        Disconnect errors are logged and swallowed, leaving the client without
        a connection; other errors are re-raised.

        Args:
            re: True if this is a reconnect rather than an initial connect.
        """
        try:
            self._make_connection()
            assert self._socket is not None
            send_int64(self._socket, self._session)
            self._session = recv_int64(self._socket)
            if re:
                _logger.warning(
                        f'Reconnected to {self._addresses}, continuing the'
                        ' simulation')
        except Exception as e:
            if is_disconnect(e):
                self._close_connection()
                _logger.warning(
                        f'Failed to reconnect to {self._addresses}, will retry'
                        ' later')
            else:
                raise

    def _make_connection(self) -> None:
        """Connect to the server and set up polling.

        Tries each address in self._addresses in order and uses the first
        that accepts a connection. It then sets self._socket and
        self._poll_obj.

        Raises:
            ConnectionRefusedError: If none of the addresses could be
                connected to.
        """
        sock: Optional[socket.SocketType] = None
        for address in self._addresses:
            try:
                sock = self._connect(address)
                break
            except RuntimeError:
                pass

        if sock is None:
            raise ConnectionRefusedError('Failed to connect')

        if hasattr(socket, 'TCP_NODELAY'):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        if hasattr(socket, 'TCP_QUICKACK'):
            sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
        self._socket = sock

        if hasattr(select, 'poll'):
            poller = select.poll()
            poller.register(sock, select.POLLIN)
            self._poll_obj = poller
        else:
            self._poll_obj = None  # On platforms that don't support select.poll

    def _connect(self, address: str) -> socket.SocketType:
        """Open a TCP connection to a single address.

        Every socket address that the host resolves to is tried in turn, and
        the first successful connection is returned.

        Args:
            address: A ``host:port`` string; an IPv6 host must be enclosed in
                square brackets.

        Returns:
            A connected socket with its timeout reset to None (blocking).

        Raises:
            RuntimeError: If a bracketed host is not closed, or if none of the
                resolved socket addresses accepted a connection.
        """
        loc_parts = address.rsplit(':', 1)
        host = loc_parts[0]
        if host.startswith('['):
            if host.endswith(']'):
                host = host[1:-1]
            else:
                raise RuntimeError('Invalid address')
        port = int(loc_parts[1])

        addrinfo = socket.getaddrinfo(
                host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)

        for family, socktype, proto, _, sockaddr in addrinfo:
            try:
                sock = socket.socket(family, socktype, proto)
            except OSError as e:
                # e.g. an address family this host does not support; there is
                # no socket to close, so go on to the next candidate.
                _logger.debug(f'Failed to connect socket: {e}')
                continue

            try:
                sock.settimeout(_CONNECT_TIMEOUT)
                sock.connect(sockaddr)
                sock.settimeout(None)
                return sock
            except (ConnectionRefusedError, ConnectionAbortedError):
                _logger.info(f'Failed to connect to {sockaddr}')
            except Exception as e:
                _logger.debug(f'Failed to connect socket: {e}')

            # Connecting failed: release this socket and try the next result.
            sock.close()

        raise RuntimeError('Could not connect')

    def _end_session(self) -> None:
        """Send the end-of-session marker to the server, if connected.

        Sends 0 where call() would send a request number (presumably the
        server treats this as the end of the session — confirm). A disconnect
        while doing so is logged and otherwise ignored.
        """
        try:
            if self._socket is not None:
                send_int64(self._socket, 0)
        except Exception as e:
            # This can raise if the peer has shut down already when we close our
            # connection to it, which is fine and can be ignored. Otherwise, we reraise.
            if not is_disconnect(e):
                raise
            _logger.warning(
                    'Disconnected while trying to end session, shutdown will take'
                    ' longer than usual because of this.')

    def _close_connection(self) -> None:
        """Shut down and close the current socket, if there is one.

        Afterwards the client has no socket, so a later call() raises
        ConnectionError('No connection could be established') internally
        instead of using a closed socket. The socket is closed even if
        shutdown() fails.
        """
        sock = self._socket
        self._socket = None
        self._poll_obj = None
        if sock is None:
            return

        try:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                # Always release the file descriptor, even when shutdown()
                # fails because the peer is already gone.
                sock.close()
        except Exception as e:
            # This can raise if the peer has shut down already when we close our
            # connection to it, which is fine and can be ignored. Otherwise, we
            # reraise.
            if not is_disconnect(e):
                raise