Source code for libmuscle.test.test_api_guard

from typing import Callable

import pytest

from libmuscle.api_guard import APIGuard


def test_no_checkpointing_support():
    """Check the reuse-instance calls on a guard built with ``APIGuard(False)``.

    Three verify/done pairs report ``True``, then a final pair reports
    ``False``. The test passes if none of these calls raises.
    """
    guard = APIGuard(False)

    for _cycle in range(3):
        guard.verify_reuse_instance()
        guard.reuse_instance_done(True)

    # Final pair: reuse_instance reports False.
    guard.verify_reuse_instance()
    guard.reuse_instance_done(False)
def test_final_snapshot_only(guard: APIGuard):
    """Run four reuse cycles that only use the final-snapshot calls.

    The first cycle reports resuming and loads a snapshot, the third cycle
    saves a final snapshot, and the remaining cycles skip both. The test
    passes if none of the guard calls raises.
    """
    for cycle in range(4):
        guard.verify_reuse_instance()
        guard.reuse_instance_done(True)

        # Only the first cycle resumes; resuming is followed by a load.
        is_resuming = cycle == 0
        guard.verify_resuming()
        guard.resuming_done(is_resuming)
        if is_resuming:
            guard.verify_load_snapshot()
            guard.load_snapshot_done()

        guard.verify_should_init()
        guard.should_init_done()

        # Only the third cycle actually saves a final snapshot.
        save_final = cycle == 2
        guard.verify_should_save_final_snapshot()
        guard.should_save_final_snapshot_done(save_final)
        if save_final:
            guard.verify_save_final_snapshot()
            guard.save_final_snapshot_done()

    # Closing pair: reuse_instance reports False.
    guard.verify_reuse_instance()
    guard.reuse_instance_done(False)
def test_full_checkpointing(guard: APIGuard):
    """Run four reuse cycles using every checkpointing call of the guard.

    Each cycle asks three times whether to save an intermediate snapshot
    (saving on the first two), then asks about the final snapshot. The first
    cycle resumes from a snapshot and the third saves a final one. The test
    passes if none of the guard calls raises.
    """
    for cycle in range(4):
        guard.verify_reuse_instance()
        guard.reuse_instance_done(True)

        # Only the first cycle resumes; resuming is followed by a load.
        is_resuming = cycle == 0
        guard.verify_resuming()
        guard.resuming_done(is_resuming)
        if is_resuming:
            guard.verify_load_snapshot()
            guard.load_snapshot_done()

        guard.verify_should_init()
        guard.should_init_done()

        for step in range(3):
            # Save on every step except the last one.
            save_now = step != 2
            guard.verify_should_save_snapshot()
            guard.should_save_snapshot_done(save_now)
            if save_now:
                guard.verify_save_snapshot()
                guard.save_snapshot_done()

        # Only the third cycle actually saves a final snapshot.
        save_final = cycle == 2
        guard.verify_should_save_final_snapshot()
        guard.should_save_final_snapshot_done(save_final)
        if save_final:
            guard.verify_save_final_snapshot()
            guard.save_final_snapshot_done()

    # Closing pair: reuse_instance reports False.
    guard.verify_reuse_instance()
    guard.reuse_instance_done(False)
# Ordered (unbound APIGuard method, positional arguments) pairs. The helpers
# below walk this sequence in order, calling each entry as fun(guard, *args).
_api_guard_funs = (
    (APIGuard.verify_reuse_instance, ()),
    (APIGuard.reuse_instance_done, (True,)),
    (APIGuard.verify_resuming, ()),
    (APIGuard.resuming_done, (True,)),
    (APIGuard.verify_load_snapshot, ()),
    (APIGuard.load_snapshot_done, ()),
    (APIGuard.verify_should_init, ()),
    (APIGuard.should_init_done, ()),
    (APIGuard.verify_should_save_snapshot, ()),
    (APIGuard.should_save_snapshot_done, (True,)),
    (APIGuard.verify_save_snapshot, ()),
    (APIGuard.save_snapshot_done, ()),
    (APIGuard.verify_should_save_final_snapshot, ()),
    (APIGuard.should_save_final_snapshot_done, (True,)),
    (APIGuard.verify_save_final_snapshot, ()),
)
def run_until_before(guard: APIGuard, excluded: Callable) -> None:
    """Call the entries of ``_api_guard_funs`` in order, stopping at one.

    Args:
        guard: Object passed as first argument to every called function.
        excluded: Function at which to stop; it is compared by identity and
            is itself not called. If it is not in ``_api_guard_funs``, every
            entry is called.
    """
    for step, step_args in _api_guard_funs:
        if step is excluded:
            return
        step(guard, *step_args)
def check_all_raise_except(guard: APIGuard, excluded: set[Callable]) -> None:
    """Check which ``verify_*`` entries of ``_api_guard_funs`` raise.

    Entries whose name does not start with ``verify_`` are skipped. Each
    remaining entry is called with ``guard``: entries in ``excluded`` are
    called directly, all others must raise ``RuntimeError``.

    Args:
        guard: Object passed as first argument to every called function.
        excluded: Verify functions that are called outside ``pytest.raises``.
    """
    for candidate, call_args in _api_guard_funs:
        if not candidate.__name__.startswith('verify_'):
            continue

        if candidate in excluded:
            candidate(guard, *call_args)
            continue

        with pytest.raises(RuntimeError):
            candidate(guard, *call_args)
@pytest.mark.parametrize(
    'fun',
    [
        APIGuard.verify_load_snapshot,
        APIGuard.verify_should_init,
        APIGuard.verify_save_snapshot,
        APIGuard.verify_save_final_snapshot,
    ])
def test_missing_step(guard, fun):
    """Run the call sequence up to ``fun``, then check only ``fun`` passes.

    Every other ``verify_*`` function is expected to raise ``RuntimeError``
    at that point.
    """
    only_allowed = {fun}
    run_until_before(guard, fun)
    check_all_raise_except(guard, only_allowed)
def test_missing_resuming(guard: APIGuard):
    """Stop just before ``verify_resuming`` and check only it passes.

    Every other ``verify_*`` function is expected to raise ``RuntimeError``
    at that point.
    """
    next_step = APIGuard.verify_resuming
    run_until_before(guard, next_step)
    check_all_raise_except(guard, {next_step})
def test_missing_should_save_final(guard: APIGuard):
    """Stop just before ``verify_should_save_final_snapshot`` and check.

    At that point both ``verify_should_save_snapshot`` and
    ``verify_should_save_final_snapshot`` are called directly; every other
    ``verify_*`` function is expected to raise ``RuntimeError``.
    """
    allowed = {
        APIGuard.verify_should_save_snapshot,
        APIGuard.verify_should_save_final_snapshot,
    }
    run_until_before(guard, APIGuard.verify_should_save_final_snapshot)
    check_all_raise_except(guard, allowed)
def test_double_should_save(guard: APIGuard):
    """Check a second ``verify_should_save_snapshot`` raises after a ``True``.

    After ``should_save_snapshot_done(True)``, calling
    ``verify_should_save_snapshot`` again must raise ``RuntimeError``.
    """
    step = APIGuard.verify_should_save_snapshot
    run_until_before(guard, step)

    # First query is answered with True.
    guard.verify_should_save_snapshot()
    guard.should_save_snapshot_done(True)

    # Asking again right away is expected to be rejected.
    with pytest.raises(RuntimeError):
        guard.verify_should_save_snapshot()