import logging
import socket
import socketserver as ss
import threading
from typing import Any, cast
import psutil
from libmuscle.mcp.session_state import SessionState
from libmuscle.mcp.tcp_util import (
is_disconnect,
recv_frame,
recv_int64,
send_frame,
send_int64,
)
from libmuscle.mcp.transport_server import RequestHandler, TransportServer
from libmuscle.util import Retrier
# Logger for this module, named after the module itself (``__name__``).
_logger = logging.getLogger(__name__)
class TcpTransportServerImpl(ss.ThreadingMixIn, ss.TCPServer):
    """socketserver-based TCP server that backs a TcpTransportServer.

    Every accepted connection is served on its own daemon thread
    (ThreadingMixIn with ``daemon_threads``), and SO_REUSEADDR is set
    before binding (``allow_reuse_address``). Besides the socket, this
    object holds the session table shared by all connection handlers.
    """
    daemon_threads = True
    allow_reuse_address = True

    def __init__(self, host_port_tuple: tuple[str, int],
                 streamhandler: type, transport_server: 'TcpTransportServer'
                 ) -> None:
        """Create the server; binding and listening happen in TCPServer.

        Args:
            host_port_tuple: The (host, port) address to bind to.
            streamhandler: Handler class, instantiated once per connection.
            transport_server: The TcpTransportServer that owns this server.
        """
        super().__init__(host_port_tuple, streamhandler)
        self.transport_server = transport_server

        # Turn off Nagle's algorithm (TCP_NODELAY) and delayed ACKs
        # (TCP_QUICKACK) on platforms that define these options.
        # NOTE(review): these are set on the *listening* socket; whether
        # accepted connections inherit them is platform-dependent.
        for option_name in ('TCP_NODELAY', 'TCP_QUICKACK'):
            option = getattr(socket, option_name, None)
            if option is not None:
                self.socket.setsockopt(socket.SOL_TCP, option, 1)

        # Session id -> SessionState; always accessed under session_lock
        self.session_store: dict[int, SessionState] = {}
        self.session_lock = threading.Lock()
        # Id handed out to the next new session (a client sends 0 to
        # ask for a new one, so numbering starts at 1)
        self.next_session = 1
class TcpHandler(ss.BaseRequestHandler):
    """Handler for MCP-over-TCP connections.

    Plugs into Python's socketserver TCPServer and hands every RPC it
    receives to the RequestHandler of the owning TcpTransportServer.

    A note on terminology: socketserver calls a whole connection a
    "request", so ``self.request`` is the socket of the connection being
    served. On top of that connection we do Remote Procedure Calls, and
    each RPC the client sends is also called a request here.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pure pass-through; BaseRequestHandler.__init__ itself runs
        # setup(), handle() and finish() for this connection.
        super().__init__(*args, **kwargs)

    def handle(self) -> None:
        """Serves a single connection until the client ends it.

        After (re)starting a session, reads a request number and a request
        frame per RPC. The session state decides whether each request is
        processed and whether a response is sent back. Request number 0
        means the client is ending the session, which is then removed.
        Disconnects are swallowed and leave the session in the store for
        later resumption; any other exception propagates.
        """
        impl = cast(TcpTransportServerImpl, self.server)
        conn = self.request
        try:
            sid = self._start_session()
            while True:
                req_nr = recv_int64(conn)
                if req_nr == 0:
                    # Orderly end of session from the client
                    break
                payload = recv_frame(conn)
                do_process, do_send = self._session_state.triage_request(req_nr)
                if do_process:
                    reply = impl.transport_server._handler.handle_request(payload)
                    self._session_state.set_response(reply)
                if do_send:
                    to_send = self._session_state.wait_get_response(req_nr)
                    if to_send is not None:
                        send_frame(conn, to_send)
            self._end_session(sid)
        except Exception as err:
            # A lost connection is not an error here; the client may
            # reconnect and resume the session.
            if not is_disconnect(err):
                raise

    def _start_session(self) -> int:
        """Starts a new session, or resumes one after a lost connection.

        Sessions are identified by a number that both we and the client
        store. The client first sends a session id. 0 asks for a new
        session, which gets a fresh SessionState under the next free id.
        Any other id asks to resume that session, so that whatever we were
        sending when the connection dropped can be sent again. In both
        cases the resulting id is sent back to the client.

        Returns:
            The id of the started or resumed session.

        Raises:
            RuntimeError: If the requested session id is not known.
        """
        requested = recv_int64(self.request)
        impl = cast(TcpTransportServerImpl, self.server)

        if requested != 0:
            _logger.warning(
                f'The TCP network connection for session {requested} was lost')
            with impl.session_lock:
                if requested not in impl.session_store:
                    raise RuntimeError(f'Unknown session {requested} requested')
                self._session_state = impl.session_store[requested]
            send_int64(self.request, requested)
            _logger.warning(f'Resuming session {requested}')
            return requested

        with impl.session_lock:
            new_id = impl.next_session
            impl.session_store[new_id] = SessionState()
            self._session_state = impl.session_store[new_id]
            impl.next_session += 1
        send_int64(self.request, new_id)
        return new_id

    def _end_session(self, session_id: int) -> None:
        """Drops a session that its client has closed cleanly."""
        impl = cast(TcpTransportServerImpl, self.server)
        with impl.session_lock:
            impl.session_store.pop(session_id)

    def finish(self) -> None:
        """Tells the server's RequestHandler that this connection is done.

        socketserver calls this after handle() has returned or raised, so
        it runs once per connection. It calls close() on the RequestHandler
        of the owning TcpTransportServer.
        """
        owner = cast(TcpTransportServerImpl, self.server).transport_server
        owner._handler.close()
class TcpTransportServer(TransportServer):
    """A TransportServer that uses TCP to communicate."""

    def __init__(self, handler: RequestHandler, port: int = 0) -> None:
        """Create a TCPServer.

        Binds to the given port on all interfaces (host ``''``) and starts
        serving connections on a background daemon thread.

        Args:
            handler: A RequestHandler to handle requests
            port: The port to use. 0 (the default) lets the operating
                system pick a free port.

        Raises:
            OSError: With errno set to errno.EADDRINUSE if the port is not
                available.
        """
        super().__init__(handler)
        self._server = TcpTransportServerImpl(('', port), TcpHandler, self)
        # serve_forever's argument is its poll interval: the serving loop
        # checks for a shutdown() request every 0.1 s.
        self._server_thread = threading.Thread(
            target=self._server.serve_forever, args=(0.1,), daemon=True)
        self._server_thread.start()

    def get_location(self) -> str:
        """Returns the location this server listens on.

        The location has the form ``tcp:<addr>:<port>[,<addr>:<port>...]``,
        with one entry per address from _get_if_addresses(). If that finds
        no address at all, the IPv4 loopback address is used instead, so
        the location is never just ``tcp:``.

        Returns:
            A string containing the location.
        """
        # IPv6 may give two more (unneeded) items, so can't unpack directly
        port = self._server.server_address[1]
        addresses = self._get_if_addresses()
        if not addresses:
            # Without this, the location would list no address for clients
            # to connect to. We bind to all interfaces, so loopback at least
            # works for clients on this machine.
            addresses = ['127.0.0.1']
        locs = [f'{address}:{port}' for address in addresses]
        return 'tcp:' + ','.join(locs)

    def close(self, graceful: bool = True) -> None:
        """Closes this server.

        If graceful, first waits (for a limited time) until the clients
        have ended all their sessions. Then stops the serving loop, waits
        for the server thread to exit and closes the listening socket.
        Connection handler threads are daemon threads, so they are not
        joined here.

        Args:
            graceful: Wait for clients to finish their sessions, where applicable.
        """
        if graceful:
            # Poll until every session has been ended by its client, or the
            # Retrier gives up (presumably after 60 s, checking every 0.1 s).
            retrier = Retrier(60.0, 0.1)
            while not retrier.should_give_up():
                with self._server.session_lock:
                    if not self._server.session_store:
                        break
                retrier.sleep()
        self._server.shutdown()
        self._server_thread.join()
        self._server.server_close()

    def _get_if_addresses(self) -> list[str]:
        """Returns a list of local addresses.

        This returns the IPv4 and IPv6 addresses bound to the available
        network interfaces. The exceptions are IPv4 loopback addresses
        (127.*), the IPv6 loopback address (::1), and IPv6 addresses that
        carry a ``%`` scope id, as link-local ones do. IPv6 addresses are
        wrapped in square brackets so that a port can be appended.

        The server will listen on all interfaces, but not all of them may be
        reachable from the client. So we get all of them here, and the client
        can then try them all and find one that works.
        """
        all_addresses: list[str] = []
        ifs = psutil.net_if_addrs()
        for addresses in ifs.values():
            for addr in addresses:
                if addr.family == socket.AF_INET:
                    if not addr.address.startswith('127.'):
                        all_addresses.append(addr.address)
                elif addr.family == socket.AF_INET6:
                    # filter out link-local addresses with a scope id
                    if '%' not in addr.address and addr.address != '::1':
                        all_addresses.append('[' + addr.address + ']')
        return all_addresses