from dataclasses import dataclass, field
from datetime import datetime
from enum import Flag, auto
from functools import lru_cache
from itertools import chain, zip_longest
from operator import attrgetter
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Dict, Optional, Set, FrozenSet, List, Tuple, TypeVar
from ymmsl import (
Reference, Model, Identifier, Implementation, save,
PartialConfiguration)
from libmuscle.manager.topology_store import TopologyStore
from libmuscle.snapshot import SnapshotMetadata
# Upper bound on the number of candidate filenames tried when writing a
# workflow snapshot file whose timestamped name is already taken.
_MAX_FILE_EXISTS_CHECK = 100

# Generic type variable, used by safe_get
_T = TypeVar("_T")

# Type aliases used in this module:
# - snapshot nodes registered per instance
_SnapshotDictType = Dict[Reference, List["SnapshotNode"]]
# - (own port, peer port, connection flags) describing one conduit between
#   an instance and a peer instance
_ConnectionType = Tuple[Identifier, Identifier, "_ConnectionInfo"]
# - item on the registry work queue; None tells the registry thread to stop
_QueueItemType = Optional[Tuple[Reference, SnapshotMetadata]]

# this snapshot is used as a placeholder for restarting from scratch
_NULL_SNAPSHOT = SnapshotMetadata(["Instance start"], 0, 0, None, {}, True, '')
def safe_get(lst: List[_T], index: int, default: _T) -> _T:
    """Return ``lst[index]``, falling back to ``default`` if it is missing.

    Args:
        lst: List to take the item from.
        index: Position of the item; expected to be >= 0.
        default: Value returned when indexing raises an IndexError.

    Returns:
        The item at ``index``, or ``default``.
    """
    try:
        found = lst[index]
    except IndexError:
        return default
    else:
        return found
class _ConnectionInfo(Flag):
SELF_IS_SENDING = auto()
SELF_IS_VECTOR = auto()
PEER_IS_VECTOR = auto()
def calc_consistency(
        num1: int, num2: int, first_is_sent: bool, num2_is_restart: bool
        ) -> bool:
    """Check whether two message counts of one conduit are consistent.

    Counts are *strongly* consistent when they are equal. They are *weakly*
    consistent when the receiving side's count is exactly one higher than
    the sending side's count. When instance 2 is the sender, weak
    consistency is only accepted if instance 2 is not restarting from
    scratch.

    Args:
        num1: Message count of instance 1.
        num2: Message count of instance 2.
        first_is_sent: True iff instance 1 is sending messages over this
            conduit.
        num2_is_restart: True iff the snapshot of num2 is a full restart.

    Returns:
        True iff the two message counts are consistent.
    """
    if num1 == num2:
        # strong consistency
        return True
    if first_is_sent:
        # weak consistency with instance 1 as the sender
        return num1 + 1 == num2
    # weak consistency with instance 2 as the sender, which is not allowed
    # when instance 2 restarts from scratch
    return num2 + 1 == num1 and not num2_is_restart
def calc_consistency_list(
        num1: List[int], num2: List[int], first_is_sent: bool,
        num2_is_restart: bool) -> bool:
    """Check consistency of per-slot message counts of one conduit.

    Slots are compared pairwise; a slot missing from one of the lists counts
    as 0 messages. Every slot must be strongly consistent (equal counts) or
    weakly consistent (receiver count is sender count + 1). Weak
    consistency is not accepted when instance 2 sends and restarts from
    scratch.

    Args:
        num1: Message counts (per slot) of instance 1.
        num2: Message counts (per slot) of instance 2.
        first_is_sent: True iff instance 1 is sending messages over this
            conduit.
        num2_is_restart: True iff the snapshot of num2 is a full restart.

    Returns:
        True iff the two message counts are consistent.
    """
    if first_is_sent:
        sender_counts, receiver_counts = num1, num2
        weak_ok = True
    else:
        sender_counts, receiver_counts = num2, num1
        weak_ok = not num2_is_restart

    for sent, received in zip_longest(
            sender_counts, receiver_counts, fillvalue=0):
        if sent == received:
            continue    # strong
        if not (weak_ok and sent + 1 == received):
            return False
    return True
@dataclass
class SnapshotNode:
    """A single instance snapshot, seen as a node in the snapshot graph.

    Attributes:
        num: Sequence number of the snapshot, unique for this instance.
            Snapshots registered later always have a higher num.
        instance: The instance that took this snapshot.
        snapshot: The snapshot metadata reported by the instance.
        peers: The set of peers that the instance is connected to.
        consistent_peers: Per peer instance, the peer snapshot nodes that are
            consistent with this one.
    """
    num: int
    instance: Reference
    snapshot: SnapshotMetadata
    peers: FrozenSet[Reference]
    consistent_peers: Dict[Reference, List["SnapshotNode"]] = field(
            default_factory=dict, repr=False)

    def __hash__(self) -> int:
        # Identity-based hash, so nodes can be stored in sets and frozensets
        # even though consistent_peers is mutated after creation.
        return object.__hash__(self)

    @property
    def consistent(self) -> bool:
        """True iff :attr:`consistent_peers` has an entry for every peer.
        """
        return set(self.consistent_peers) == self.peers

    def do_consistency_check(
            self,
            peer_node: "SnapshotNode",
            connections: List[_ConnectionType]) -> bool:
        """Check whether the snapshot of ``peer_node`` is consistent with ours.

        Every connection between the two instances is checked. Only if all
        of them are consistent, ``peer_node`` is recorded in our
        :attr:`consistent_peers` and we are recorded in those of
        ``peer_node``.

        Args:
            peer_node: Snapshot of one of our peers.
            connections: All connections from our instance to the peer
                instance, as (own port, peer port, flags) tuples.

        Returns:
            True iff the peer snapshot is consistent with ours.
        """
        own_counts_by_port = self.snapshot.port_message_counts
        peer_counts_by_port = peer_node.snapshot.port_message_counts
        peer_restarts = peer_node.snapshot is _NULL_SNAPSHOT

        for own_port, peer_port, info in connections:
            sending = bool(info & _ConnectionInfo.SELF_IS_SENDING)
            own_counts = own_counts_by_port.get(str(own_port), [])
            peer_counts = peer_counts_by_port.get(str(peer_port), [])

            # For vector <-> non-vector conduits only one slot of the vector
            # side is relevant: the one given by the last index of the
            # non-vector side's partner instance.
            slots: Optional[Tuple[int, int]]
            if info & _ConnectionInfo.SELF_IS_VECTOR:
                slots = (int(peer_node.instance[-1]), 0)
            elif info & _ConnectionInfo.PEER_IS_VECTOR:
                slots = (0, int(self.instance[-1]))
            else:
                slots = None

            if slots is None:
                ok = calc_consistency_list(
                        own_counts, peer_counts, sending, peer_restarts)
            else:
                own_slot, peer_slot = slots
                ok = calc_consistency(
                        safe_get(own_counts, own_slot, 0),
                        safe_get(peer_counts, peer_slot, 0),
                        sending, peer_restarts)
            if not ok:
                return False

        self.consistent_peers.setdefault(
                peer_node.instance, []).append(peer_node)
        peer_node.consistent_peers.setdefault(
                self.instance, []).append(self)
        return True
class SnapshotRegistry(Thread):
    """Registry of all snapshots taken by instances.

    Current snapshots are stored in a graph. Every node represents a snapshot
    taken by an instance (see :class:`SnapshotNode`). When snapshots from peer
    instances are consistent, the nodes are connected to each other.

    This class manages the snapshot nodes. New snapshots are registered through
    :meth:`register_snapshot`; they are queued and processed on this thread
    (see :meth:`run`) until :meth:`shutdown` is called.
    """

    def __init__(
            self, config: PartialConfiguration, snapshot_folder: Path,
            topology_store: TopologyStore) -> None:
        """Create a snapshot graph using provided configuration.

        Args:
            config: ymmsl configuration describing the workflow.
            snapshot_folder: Folder in which workflow snapshot ymmsl files
                are written.
            topology_store: Used to look up the peer instances of an
                instance.

        Raises:
            ValueError: If ``config`` has no (complete) model section.
        """
        super().__init__(name='SnapshotRegistry')

        if config.model is None or not isinstance(config.model, Model):
            raise ValueError('The yMMSL experiment description does not'
                             ' contain a (complete) model section, so there'
                             ' is nothing to run!')

        self._configuration = config
        self._model = config.model
        self._snapshot_folder = snapshot_folder
        self._topology_store = topology_store

        self._queue: Queue[_QueueItemType] = Queue()
        self._snapshots: _SnapshotDictType = {}

        self._instances: Set[Reference] = set()
        for component in config.model.components:
            self._instances.update(component.instances())

        # Create snapshot nodes for starting from scratch
        for instance in self._instances:
            self.register_snapshot(instance, _NULL_SNAPSHOT)

    def register_snapshot(
            self, instance: Reference, snapshot: SnapshotMetadata) -> None:
        """Register a new snapshot.

        The snapshot is only queued here; it is processed by :meth:`run`.

        Args:
            instance: The instance that created the snapshot
            snapshot: Metadata describing the snapshot
        """
        self._queue.put((instance, snapshot))

    def run(self) -> None:
        """Process queued snapshots until :meth:`shutdown` is called.

        Executed in a separate thread.
        """
        while True:
            item = self._queue.get()
            if item is None:
                return
            self._add_snapshot(*item)

    def shutdown(self) -> None:
        """Stop the snapshot registry thread.

        Snapshots queued before this call are still processed.
        """
        self._queue.put(None)

    def _add_snapshot(
            self, instance: Reference, snapshot: SnapshotMetadata) -> None:
        """Add a snapshot to the graph and write any new workflow snapshots.

        Args:
            instance: The instance that created the snapshot
            snapshot: Metadata describing the snapshot
        """
        stateful_peers = self._get_peers(instance)
        i_snapshots = self._snapshots.setdefault(instance, [])

        # get next number of the snapshot
        num = 1 if not i_snapshots else i_snapshots[-1].num + 1
        snapshotnode = SnapshotNode(num, instance, snapshot, stateful_peers)
        i_snapshots.append(snapshotnode)

        # check consistency with all peers
        for peer in stateful_peers:
            connections = self._get_connections(instance, peer)
            for peer_snapshot in self._snapshots.get(peer, []):
                snapshotnode.do_consistency_check(peer_snapshot, connections)

        # finally, check if this snapshotnode is now part of a workflow snapshot
        if snapshot is not _NULL_SNAPSHOT:
            self._save_workflow_snapshot(snapshotnode)

    def _save_workflow_snapshot(self, snapshotnode: SnapshotNode) -> None:
        """Save snapshot if a workflow snapshot exists with the provided node.

        Every workflow snapshot containing the node is written to disk, after
        which snapshots that are no longer needed are cleaned up.

        Args:
            snapshotnode: The snapshot node that must be part of the workflow
                snapshot.
        """
        workflow_snapshots = self._get_workflow_snapshots(snapshotnode)
        for workflow_snapshot in workflow_snapshots:
            self._write_snapshot_ymmsl(workflow_snapshot)
        self._cleanup_snapshots(workflow_snapshots)

    def _get_workflow_snapshots(
            self, snapshot: SnapshotNode) -> List[List[SnapshotNode]]:
        """Return all workflow snapshots which contain the provided node.

        Performs a depth-first search over the snapshot graph, always
        expanding the instance with the fewest allowed snapshots first.

        Args:
            snapshot: The snapshot node that must be part of the workflow
                snapshot.

        Returns:
            List of workflow snapshots. Each workflow snapshot is a list of
            instance snapshot nodes.
        """
        if not snapshot.consistent:
            return []

        # Instances that don't have a snapshot node chosen yet:
        instances_to_cover = list(self._instances - {snapshot.instance})

        # Allowed snapshots per instance. This is updated during the heuristic
        # to further restrict the sets of snapshots as peer snapshots are
        # selected.
        # First restriction is that the snapshots have to be locally consistent.
        allowed_snapshots: Dict[Reference, FrozenSet[SnapshotNode]] = {}
        for instance in instances_to_cover:
            allowed_snapshots[instance] = frozenset(
                    i_snapshot
                    for i_snapshot in self._snapshots.get(instance, [])
                    if i_snapshot.consistent)
            if not allowed_snapshots[instance]:
                # there cannot be a workflow snapshot if this instance has no
                # consistent snapshot nodes
                return []
        instance = snapshot.instance
        allowed_snapshots[instance] = frozenset({snapshot})

        def num_allowed_snapshots(instance: Reference) -> int:
            """Get number of allowed snapshots at this point for this instance.

            The allowed snapshots are those that are consistent with all
            selected snapshots at this point in the heuristic.
            """
            # Closes over the variable allowed_snapshots, so it also sees the
            # dicts restored during roll back.
            return len(allowed_snapshots[instance])

        # Do a full, depth-first search for all workflow snapshots
        # ========================================================
        workflow_snapshots: List[List[SnapshotNode]] = []
        selected_snapshots = [snapshot]
        # This stack stores history of allowed_snapshots and enables roll back
        stack: List[Dict[Reference, FrozenSet[SnapshotNode]]] = []

        # Update allowed_snapshots for peers of the selected snapshot
        for peer, snapshots in snapshot.consistent_peers.items():
            intersection = allowed_snapshots[peer].intersection(snapshots)
            if not intersection:
                return []
            allowed_snapshots[peer] = intersection

        if not instances_to_cover:
            # Single-instance workflow: the (consistent) snapshot forms a
            # workflow snapshot by itself. The search below would otherwise
            # pop from an empty list.
            return [selected_snapshots.copy()]

        while True:
            # 1. Select most constrained instance
            #
            # Note: we're only interested in the instance with the least allowed
            # snapshots. Better performance may be possible by not doing a full
            # sort, but it should be tested. Expectation is that
            # instances_to_cover remains mostly sorted (as the only counts that
            # are changing are for peers of the previous selected instance).
            # Python's sort algorithm is O(N) when the list is already sorted
            # (which is the same as max()).
            #
            # We cannot use a priority queue (heapq) because
            # num_allowed_snapshots is changing every iteration.
            instances_to_cover.sort(key=num_allowed_snapshots, reverse=True)
            instance = instances_to_cover.pop()

            # 2. Select the oldest snapshot of this instance
            snapshot = min(allowed_snapshots[instance], key=attrgetter('num'))
            selected_snapshots.append(snapshot)
            # A shallow copy is ok: the values are immutable frozensets
            stack.append(allowed_snapshots.copy())

            # 3. Update allowed snapshots based on the newly selected
            allowed_snapshots[instance] = frozenset({snapshot})
            for peer, snapshots in snapshot.consistent_peers.items():
                intersection = allowed_snapshots[peer].intersection(snapshots)
                if not intersection:
                    break  # roll back
                allowed_snapshots[peer] = intersection
            else:
                # 4. Selected snapshot is okay to explore further
                if instances_to_cover:
                    # 4a. There are still instances to cover, return to the
                    # start of the while loop.
                    continue
                # 4b. We have found a complete workflow snapshot
                workflow_snapshots.append(selected_snapshots.copy())
                # Next: perform a roll-back to continue the search

            # 5. Roll back
            # stop when selected_snapshots only contains the one we forced to be
            # part of the workflow snapshot
            while len(selected_snapshots) > 1:
                snapshot = selected_snapshots.pop()
                instance = snapshot.instance
                instances_to_cover.append(instance)
                allowed_snapshots = stack.pop()
                intersection = allowed_snapshots[instance] - {snapshot}
                allowed_snapshots[instance] = intersection
                if intersection:
                    # We have a valid next snapshot to try for this instance
                    break
                # No allowed_snapshots, try another roll back
            else:
                # Exhausted all roll back possibilities, so we are done now
                return workflow_snapshots

    def _write_snapshot_ymmsl(
            self, selected_snapshots: List[SnapshotNode]) -> None:
        """Write the snapshot ymmsl file to the snapshot folder.

        The filename contains the current time. If that file already exists,
        a numeric suffix is added, trying at most ``_MAX_FILE_EXISTS_CHECK``
        names in total.

        Args:
            selected_snapshots: All snapshot nodes of the workflow snapshot.

        Raises:
            RuntimeError: If no available filename could be found.
        """
        now = datetime.now()
        config = self._generate_snapshot_config(selected_snapshots, now)
        time = now.strftime('%Y%m%d_%H%M%S')
        for i in range(_MAX_FILE_EXISTS_CHECK):
            if i:
                snapshot_filename = f'snapshot_{time}_{i}.ymmsl'
            else:
                snapshot_filename = f'snapshot_{time}.ymmsl'
            savepath = self._snapshot_folder / snapshot_filename
            if not savepath.exists():
                save(config, savepath)
                return
        raise RuntimeError('Could not find an available filename for storing'
                           f' the next workflow snapshot: {savepath} already'
                           ' exists.')

    def _generate_snapshot_config(
            self, selected_snapshots: List[SnapshotNode], now: datetime
            ) -> PartialConfiguration:
        """Generate ymmsl configuration for snapshot file.

        Args:
            selected_snapshots: All snapshot nodes of the workflow snapshot.
                The list itself is not modified.
            now: Time used in the description.

        Returns:
            Configuration with resume information and a description.
        """
        # Work on a sorted copy rather than reordering the caller's list
        selected_snapshots = sorted(
                selected_snapshots, key=attrgetter('instance'))
        resume = {}
        for node in selected_snapshots:
            if node.snapshot is not _NULL_SNAPSHOT:
                # Only store resume information when it is an actual snapshot
                # created by the instance. Otherwise the instance can just be
                # restarted from the beginning.
                resume[node.instance] = Path(node.snapshot.snapshot_filename)
        description = self._generate_description(selected_snapshots, now)
        return PartialConfiguration(resume=resume, description=description)

    def _generate_description(
            self, selected_snapshots: List[SnapshotNode], now: datetime) -> str:
        """Generate a human-readable description of the workflow snapshot.

        Lists the snapshot triggers (with the instances that reported them)
        and a table with timestamp, wallclock time and type per instance.
        """
        triggers: Dict[str, List[str]] = {}
        component_info = []
        max_instance_len = len('Instance ')
        for node in selected_snapshots:
            for trigger in node.snapshot.triggers:
                triggers.setdefault(trigger, []).append(str(node.instance))
            component_info.append((
                    str(node.instance),
                    f'{node.snapshot.timestamp:<11.6g}',
                    f'{node.snapshot.wallclock_time:<11.6g}',
                    ("Intermediate", "Final")[node.snapshot.is_final_snapshot]))
            max_instance_len = max(max_instance_len, len(str(node.instance)))
        instance_with_padding = 'Instance'.ljust(max_instance_len)
        # NOTE(review): the header columns below do not line up with the row
        # widths (11-character fields); confirm the intended spacing.
        component_table = [
                f'{instance_with_padding} t Wallclock time Type',
                f'{"-" * (max_instance_len + 41)}']
        component_table += [
                f'{name.ljust(max_instance_len)} {timestamp} {walltime}'
                f' {typ}'
                for name, timestamp, walltime, typ in component_info]

        return (f'Workflow snapshot for {self._model.name}'
                f' taken on {now.strftime("%Y-%m-%d %H:%M:%S")}.\n'
                'Snapshot triggers:\n' +
                '\n'.join(f'- {trigger} ({", ".join(triggers[trigger])})'
                          for trigger in sorted(triggers)) +
                '\n\n' +
                '\n'.join(component_table) + '\n')

    def _cleanup_snapshots(
            self, workflow_snapshots: List[List[SnapshotNode]]) -> None:
        """Remove all snapshots that are older than the selected snapshots.

        For every instance, the newest snapshot that is part of any of the
        given workflow snapshots is kept, together with all later ones.

        Args:
            workflow_snapshots: Workflow snapshots, each a list of the
                snapshot nodes it consists of. Nothing happens when empty.
        """
        if not workflow_snapshots:
            return

        # Find the newest snapshots per instance
        newest_snapshots = {snapshot.instance: snapshot
                            for snapshot in workflow_snapshots[0]}
        for workflow_snapshot in workflow_snapshots[1:]:
            for snapshot in workflow_snapshot:
                if newest_snapshots[snapshot.instance].num < snapshot.num:
                    newest_snapshots[snapshot.instance] = snapshot

        # Remove all snapshots that are older than the newest snapshots
        removed_snapshots: Set[SnapshotNode] = set()
        for snapshot in newest_snapshots.values():
            all_snapshots = self._snapshots[snapshot.instance]
            idx = all_snapshots.index(snapshot)
            self._snapshots[snapshot.instance] = all_snapshots[idx:]
            removed_snapshots.update(all_snapshots[:idx])

        # Remove all references in SnapshotNode.consistent_peers to the
        # snapshots that are cleaned up
        for snapshot in removed_snapshots:
            for peer_snapshot in chain.from_iterable(
                    snapshot.consistent_peers.values()):
                if peer_snapshot in removed_snapshots:
                    # snapshot is removed anyway, no need to update references
                    continue
                # peer_snapshot is still there, remove reference to us
                peer_list = peer_snapshot.consistent_peers[snapshot.instance]
                peer_list.remove(snapshot)
                if not peer_list:
                    # No consistent snapshot of this instance remains; drop
                    # the key so that SnapshotNode.consistent stays accurate.
                    del peer_snapshot.consistent_peers[snapshot.instance]

    @lru_cache(maxsize=None)
    def _get_peers(self, instance: Reference) -> FrozenSet[Reference]:
        """Return the set of peers for the given instance.

        Note: instance is assumed to contain the full index, not just the
        component name.

        Args:
            instance: Instance to get peers of.

        Returns:
            Frozen set with all peer instances (including their index).
        """
        return frozenset(self._topology_store.get_peer_instances(instance))

    @lru_cache(maxsize=None)
    def _get_connections(self, instance: Reference, peer: Reference
                         ) -> List[_ConnectionType]:
        """Get the list of connections between instance and peer.

        Args:
            instance: Instance reference (including index)
            peer: Peer reference (including index)

        Returns:
            A list of tuples describing all conduits between instance and peer:

            instance_port (Reference): the port of instance that is
                connected to
            peer_port (Reference): the port on the peer instance
            info (_ConnectionInfo): flag describing the connection. The
                instance is sending when
                ``info & _ConnectionInfo.SELF_IS_SENDING`` and receiving
                otherwise. When the instance port is a vector port and the
                peer port is a non-vector port, the flag
                ``_ConnectionInfo.SELF_IS_VECTOR`` is set. In the reverse
                situation the flag ``_ConnectionInfo.PEER_IS_VECTOR`` is
                set. When both ports are vector or non-vector, neither flag
                is set.
        """
        instance_kernel = instance.without_trailing_ints()
        peer_kernel = peer.without_trailing_ints()

        # The vector/non-vector relation only depends on the number of
        # indices of both instances, not on the conduit.
        instance_ndim = len(instance) - len(instance_kernel)
        peer_ndim = len(peer) - len(peer_kernel)
        vector_flags = _ConnectionInfo(0)
        if instance_ndim < peer_ndim:
            vector_flags = _ConnectionInfo.SELF_IS_VECTOR
        elif instance_ndim > peer_ndim:
            vector_flags = _ConnectionInfo.PEER_IS_VECTOR
        # we cannot distinguish scalar-scalar vs. vector-vector
        # but it does not matter for this logic :)

        connected_ports: List[_ConnectionType] = []
        for conduit in self._model.conduits:
            if (conduit.sending_component() == instance_kernel and
                    conduit.receiving_component() == peer_kernel):
                conn_type = _ConnectionInfo.SELF_IS_SENDING
            elif (conduit.receiving_component() == instance_kernel and
                    conduit.sending_component() == peer_kernel):
                conn_type = _ConnectionInfo(0)
            else:
                continue
            conn_type |= vector_flags

            if conn_type & _ConnectionInfo.SELF_IS_SENDING:
                connected_ports.append((
                        conduit.sending_port(),
                        conduit.receiving_port(),
                        conn_type))
            else:
                connected_ports.append((
                        conduit.receiving_port(),
                        conduit.sending_port(),
                        conn_type))
        return connected_ports

    @lru_cache(maxsize=None)
    def _implementation(self, kernel: Reference) -> Optional[Implementation]:
        """Return the implementation of a kernel.

        Args:
            kernel: The kernel to get the implementation for.

        Returns:
            Implementation for the kernel, or None if not provided in the
            configuration.
        """
        implementation = None
        for component in self._model.components:
            if component.name == kernel:
                implementation = component.implementation
        if implementation in self._configuration.implementations:
            return self._configuration.implementations[implementation]
        return None