import socket
import socketserver as ss
import threading
from typing import cast, List, Optional, Tuple
from typing_extensions import Type
import psutil
from libmuscle.mcp.transport_server import RequestHandler, TransportServer
from libmuscle.mcp.tcp_util import (recv_all, recv_int64, send_int64,
SocketClosed)
[docs]class TcpTransportServerImpl(ss.ThreadingMixIn, ss.TCPServer):
daemon_threads = True
allow_reuse_address = True
def __init__(self, host_port_tuple: Tuple[str, int],
streamhandler: Type, transport_server: 'TcpTransportServer'
) -> None:
super().__init__(host_port_tuple, streamhandler)
self.transport_server = transport_server
if hasattr(socket, "TCP_NODELAY"):
self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if hasattr(socket, "TCP_QUICKACK"):
self.socket.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
[docs]class TcpHandler(ss.BaseRequestHandler):
"""Handler for MCP-over-TCP connections.
This is a Python handler for Python's TCPServer, which forwards
to the RequestHandler attached to the server.
"""
[docs] def handle(self) -> None:
"""Handles requests on a socket
"""
request = self.receive_request()
while request is not None:
server = cast(TcpTransportServerImpl, self.server).transport_server
response = server._handler.handle_request(request)
send_int64(self.request, len(response))
self.request.sendall(response)
request = self.receive_request()
[docs] def receive_request(self) -> Optional[bytes]:
"""Receives a request
Returns:
The received bytes
"""
try:
length = recv_int64(self.request)
reqbuf = recv_all(self.request, length)
return reqbuf
except SocketClosed:
return None
[docs] def finish(self) -> None:
"""Called when shutting down the thread?"""
server = cast(TcpTransportServerImpl, self.server).transport_server
server._handler.close()
[docs]class TcpTransportServer(TransportServer):
"""A TransportServer that uses TCP to communicate."""
def __init__(self, handler: RequestHandler, port: int = 0) -> None:
"""Create a TCPServer.
Args:
handler: A RequestHandler to handle requests
port: The port to use.
Raises:
OSError: With errno set to errno.EADDRINUSE if the port is not
available.
"""
super().__init__(handler)
self._server = TcpTransportServerImpl(('', port), TcpHandler, self)
self._server_thread = threading.Thread(
target=self._server.serve_forever, args=(0.1,), daemon=True)
self._server_thread.start()
[docs] def get_location(self) -> str:
"""Returns the location this server listens on.
Returns:
A string containing the location.
"""
host, port = self._server.server_address
locs: List[str] = []
for address in self._get_if_addresses():
locs.append('{}:{}'.format(address, port))
return 'tcp:{}'.format(','.join(locs))
[docs] def close(self) -> None:
"""Closes this server.
Stops the server listening, waits for existing clients to
disconnect, then frees any other resources.
"""
self._server.shutdown()
self._server_thread.join()
self._server.server_close()
def _get_if_addresses(self) -> List[str]:
"""Returns a list of local addresses.
This returns a list of strings containing all IPv4 and IPv6 network
addresses bound to the available network interfaces. The server
will listen on all interfaces, but not all of them may be reachable
from the client. So we get all of them here, and the client can
then try them all and find one that works.
"""
all_addresses: List[str] = []
ifs = psutil.net_if_addrs()
for _, addresses in ifs.items():
for addr in addresses:
if addr.family == socket.AF_INET:
if not addr.address.startswith('127.'):
all_addresses.append(addr.address)
if addr.family == socket.AF_INET6:
# filter out link-local addresses with a scope id
if '%' not in addr.address and addr.address != '::1':
all_addresses.append('[' + addr.address + ']')
return all_addresses