import socket
import socketserver as ss
import threading
from typing import cast, List, Optional, Tuple
from typing_extensions import Type
import netifaces
from libmuscle.mcp.transport_server import RequestHandler, TransportServer
from libmuscle.mcp.tcp_util import (recv_all, recv_int64, send_int64,
SocketClosed)
class TcpTransportServerImpl(ss.ThreadingMixIn, ss.TCPServer):
    """Threaded TCP server that keeps a reference to its owner.

    Every accepted connection is served in its own thread (through
    ``ThreadingMixIn``). The owning ``TcpTransportServer`` is stored on
    the instance as ``transport_server`` so that request handlers can
    reach it through ``self.server.transport_server``.
    """
    # Connection threads are daemon threads, so they do not keep the
    # interpreter alive and are not joined by server_close().
    daemon_threads = True
    # Makes TCPServer set SO_REUSEADDR on the listening socket.
    allow_reuse_address = True

    def __init__(self, host_port_tuple: Tuple[str, int],
                 streamhandler: Type, transport_server: 'TcpTransportServer'
                 ) -> None:
        """Create the server, bind it and start listening.

        Args:
            host_port_tuple: The (host, port) address to bind to.
            streamhandler: Request handler class, instantiated once per
                accepted connection.
            transport_server: The TcpTransportServer owning this
                server; made available to the handlers.

        Raises:
            OSError: If binding or configuring the socket fails.
        """
        super().__init__(host_port_tuple, streamhandler)
        self.transport_server = transport_server

        # Both options are only applied where the platform's socket
        # module defines them.
        if hasattr(socket, "TCP_NODELAY"):
            # Send small messages immediately (no Nagle coalescing).
            self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        if hasattr(socket, "TCP_QUICKACK"):
            # Request immediate rather than delayed ACKs.
            self.socket.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
class TcpHandler(ss.BaseRequestHandler):
    """Handler for MCP-over-TCP connections.

    This is a Python handler for Python's TCPServer, which forwards
    to the RequestHandler attached to the server.

    Messages in both directions are framed as a 64-bit length (see
    ``send_int64``/``recv_int64``) followed by that many payload bytes.
    """

    def handle(self) -> None:
        """Handles requests on a socket.

        Serves requests on this connection one after the other until
        the peer closes the socket. Each request is passed to the
        transport server's RequestHandler and the response is sent
        back, prefixed with its length.
        """
        # The owning transport server does not change while the
        # connection is open, so look it up once.
        server = cast(TcpTransportServerImpl, self.server).transport_server

        request = self.receive_request()
        while request is not None:
            response = server._handler.handle_request(request)
            send_int64(self.request, len(response))
            self.request.sendall(response)
            request = self.receive_request()

    def receive_request(self) -> Optional[bytes]:
        """Receives a request.

        Reads the length prefix and then the request body from the
        connection.

        Returns:
            The received bytes, or None if the socket was closed by
            the peer (SocketClosed was raised while reading).
        """
        try:
            length = recv_int64(self.request)
            return recv_all(self.request, length)
        except SocketClosed:
            return None

    def finish(self) -> None:
        """Cleans up after a connection has been handled.

        Called by socketserver once handle() is done, also when it
        raised. As the server is threaded, this runs in the thread
        that served the connection. Closes the transport server's
        RequestHandler.
        """
        server = cast(TcpTransportServerImpl, self.server).transport_server
        server._handler.close()
class TcpTransportServer(TransportServer):
    """A TransportServer that uses TCP to communicate."""

    def __init__(self, handler: RequestHandler, port: int = 0) -> None:
        """Create a TCPServer.

        Binds to all interfaces on the given port and starts serving
        in a background daemon thread.

        Args:
            handler: A RequestHandler to handle requests
            port: The port to use. With the default of 0, the
                operating system picks a free port.

        Raises:
            OSError: With errno set to errno.EADDRINUSE if the port is not
                available.
        """
        super().__init__(handler)

        self._server = TcpTransportServerImpl(('', port), TcpHandler, self)
        # 0.1 is serve_forever()'s poll interval in seconds, i.e. how
        # often it checks whether shutdown() was requested.
        self._server_thread = threading.Thread(
                target=self._server.serve_forever, args=(0.1,), daemon=True)
        self._server_thread.start()

    def get_location(self) -> str:
        """Returns the location this server listens on.

        The location has the form ``tcp:<addr>:<port>,<addr>:<port>``,
        with one entry for each address returned by
        ``_get_if_addresses()``; IPv6 addresses are in square brackets.

        Returns:
            A string containing the location.
        """
        # Only the port of the bound address is needed; the host part
        # is replaced by the addresses of the network interfaces.
        port = self._server.server_address[1]
        locs = [
                f'{address}:{port}'
                for address in self._get_if_addresses()]
        return 'tcp:{}'.format(','.join(locs))

    def close(self) -> None:
        """Closes this server.

        Stops the serving loop, waits for the server thread to end,
        then closes the listening socket. Connection handler threads
        are daemon threads (see TcpTransportServerImpl.daemon_threads)
        and are not waited for here.
        """
        self._server.shutdown()
        self._server_thread.join()
        self._server.server_close()

    def _get_if_addresses(self) -> List[str]:
        """Collects the addresses of this machine's interfaces.

        Returns:
            All IPv4 addresses not starting with ``127.``, and all
            IPv6 addresses except ``::1`` and those carrying a scope
            id, the latter wrapped in square brackets.
        """
        all_addresses: List[str] = []
        for interface in netifaces.interfaces():
            addrs = netifaces.ifaddresses(interface)

            for props in addrs.get(netifaces.AF_INET, []):
                # skip IPv4 loopback addresses
                if not props['addr'].startswith('127.'):
                    all_addresses.append(props['addr'])

            for props in addrs.get(netifaces.AF_INET6, []):
                # filter out link-local addresses with a scope id
                # and the IPv6 loopback address
                if '%' not in props['addr'] and props['addr'] != '::1':
                    all_addresses.append('[' + props['addr'] + ']')

        return all_addresses