# -*- coding: utf-8 -*-
"""The state machine for processes"""
import enum
import functools
import inspect
import logging
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from plumpy.futures import Future
from .utils import call_with_super_check, super_check
# Names exported by a star-import of this module
__all__ = ['StateMachine', 'StateMachineMeta', 'event', 'TransitionFailed']
# Module-level logger named after this module
_LOGGER = logging.getLogger(__name__)
# Type of a state label: ``None``, an enum member or a string
LABEL_TYPE = Union[None, enum.Enum, str] # pylint: disable=invalid-name
# Signature of a state event callback; it is invoked as fn(state_machine, hook, state)
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] # pylint: disable=invalid-name
class StateMachineError(Exception):
    """Common base class for errors raised by the state machine."""
class StateEntryFailed(Exception):
    """
    Raised when a state could not be entered.

    The exception may name the state to go to instead, together with the positional and
    keyword arguments to use when creating that state.

    :param state: the state (or its label) to go to next
    :param args: positional arguments for creating that state
    :param kwargs: keyword arguments for creating that state
    """

    def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:  # pylint: disable=keyword-arg-before-vararg
        super().__init__('failed to enter state')
        # Rebinding ``args`` replaces the message tuple stored by ``Exception.__init__`` above,
        # so this assignment has to come after the ``super().__init__`` call.
        self.args, self.kwargs = args, kwargs
        self.state = state
class InvalidStateError(Exception):
    """Raised when an operation is attempted that the current state does not permit."""
class EventError(StateMachineError):
    """
    Raised when an event is invalid in the current state or produces a disallowed transition.

    :param evt: name of the event that failed
    :param msg: human readable description of the failure
    """

    def __init__(self, evt: str, msg: str):
        self.event = evt
        super().__init__(msg)
class TransitionFailed(Exception):
    """
    A state transition failed.

    :param initial_state: the state the machine was in when the transition started
    :param final_state: the state the machine was trying to reach, if known
    :param traceback_str: formatted traceback of the underlying failure, if any
    """

    def __init__(
        self,
        initial_state: 'State',
        final_state: Optional['State'] = None,
        traceback_str: Optional[str] = None
    ) -> None:
        # The attributes must be set before the message is built, as ``_format_msg`` reads them
        self.initial_state = initial_state
        self.final_state = final_state
        self.traceback_str = traceback_str
        super().__init__(self._format_msg())

    def _format_msg(self) -> str:
        """Build the exception message: ``initial -> final``, followed by the traceback if one was given."""
        lines = [f'{self.initial_state} -> {self.final_state}']
        if self.traceback_str is not None:
            lines.append(self.traceback_str)
        return '\n'.join(lines)
def event(
    from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
    to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*'
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.

    Can be used with arguments (``@event(from_states=..., to_states=...)``) or bare (``@event``),
    in which case any initial and final state is accepted.

    :param from_states: ``'*'`` for any state, or a state class / iterable of state classes the
        machine must be in when the event is called
    :param to_states: ``'*'`` for any state, or a state class / iterable of state classes the
        machine must be in after the event has run
    :returns: the decorator, or the decorated method when used bare
    :raises TypeError: if ``from_states`` or ``to_states`` contain something that is not a ``State`` subclass
    """
    # Bare usage: the decorated function arrives as ``from_states``. Detect it before validating,
    # otherwise the validation below would try to iterate over the function and fail.
    bare_target = None
    if inspect.isfunction(from_states):
        bare_target = from_states
        from_states = '*'

    def normalise(states: Any, name: str) -> Optional[tuple]:
        """Return ``None`` for the wildcard, otherwise a validated tuple of state classes."""
        if states == '*':
            return None
        # Materialise into a tuple so that one-shot iterables survive the validation and
        # can be checked again on every call of the event
        as_tuple = (states,) if inspect.isclass(states) else tuple(states)
        if not all(issubclass(state, State) for state in as_tuple):
            raise TypeError(f'{name}: {states}')
        return as_tuple

    allowed_from = normalise(from_states, 'from_states')
    allowed_to = normalise(to_states, 'to_states')

    def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
        evt_label = wrapped.__name__

        @functools.wraps(wrapped)
        def transition(self: Any, *a: Any, **kw: Any) -> Any:
            initial = self._state

            if allowed_from is not None and not isinstance(initial, allowed_from):
                raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}')

            result = wrapped(self, *a, **kw)
            # The final state is only checked when the event neither returned ``False``
            # nor returned a future
            if not (result is False or isinstance(result, Future)):
                if allowed_to is not None and not isinstance(self._state, allowed_to):
                    if self._state == initial:
                        raise EventError(evt_label, 'Machine did not transition')

                    raise EventError(
                        evt_label, 'Event produced invalid state transition from '
                        f'{initial.LABEL} to {self._state.LABEL}'
                    )

            return result

        return transition

    if bare_target is not None:
        return wrapper(bare_target)

    return wrapper
class State:
    """
    Base class for the states of a :class:`StateMachine`.

    Subclasses set ``LABEL`` to identify the state and ``ALLOWED`` to the labels of the
    states that can be entered from it.
    """
    LABEL: LABEL_TYPE = None
    # Labels of the states that can be entered from this one
    ALLOWED: Set[LABEL_TYPE] = set()

    def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):  # pylint: disable=unused-argument
        """
        :param state_machine: The process this state belongs to
        """
        self.in_state: bool = False
        self.state_machine = state_machine

    def __str__(self) -> str:
        return str(self.LABEL)

    @classmethod
    def is_terminal(cls) -> bool:
        """Return ``True`` when no state can be entered from this one, i.e. ``ALLOWED`` is empty."""
        return not cls.ALLOWED

    @property
    def label(self) -> LABEL_TYPE:
        """ Convenience property to get the state label """
        return self.LABEL

    def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
        """Create a state with the given label by delegating to the owning state machine."""
        machine = self.state_machine
        return machine.create_state(state_label, *args, **kwargs)

    @super_check
    def enter(self) -> None:
        """ Entering the state """

    def execute(self) -> Optional['State']:
        """
        Execute the state, performing the actions that this state is responsible for.

        :returns: a state to transition to or None if finished.
        """
        return None

    @super_check
    def exit(self) -> None:
        """
        Exiting the state

        :raises InvalidStateError: if this state is terminal
        """
        terminal = self.is_terminal()
        if terminal:
            raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

    def do_enter(self) -> None:
        """Run :meth:`enter` through ``call_with_super_check`` and then mark the state as entered."""
        call_with_super_check(self.enter)
        self.in_state = True

    def do_exit(self) -> None:
        """Run :meth:`exit` through ``call_with_super_check`` and then mark the state as left."""
        call_with_super_check(self.exit)
        self.in_state = False
class StateEventHook(enum.Enum):
    """
    Points in the state transition procedure at which callbacks can be registered.

    Each callback is passed a state instance; which state that is depends on the hook,
    as described on each member below.
    """
    # The callback receives the state that is being entered
    ENTERING_STATE: int = 0
    # The callback receives the last state, i.e. the one that was entered from
    ENTERED_STATE: int = 1
    # The callback receives the next state that will be entered (or None for terminal)
    EXITING_STATE: int = 2
class StateMachine(metaclass=StateMachineMeta):
    """
    A machine that holds one current :class:`State` and moves between states via :meth:`transition_to`.

    Subclasses list their state classes in ``STATES``; the label of the first entry is used as
    the initial state label.
    """
    STATES: Optional[Sequence[Type[State]]] = None
    _STATES_MAP: Optional[Dict[Hashable, Type[State]]] = None

    _transitioning = False
    _transition_failing = False

    @classmethod
    def get_states_map(cls) -> Dict[Hashable, Type[State]]:
        """Return the mapping from state label to state class, building it first if needed."""
        cls.__ensure_built()
        assert cls._STATES_MAP is not None  # required for type checking
        return cls._STATES_MAP

    @classmethod
    def get_states(cls) -> Sequence[Type[State]]:
        """
        Return the state classes of this machine.

        :raises RuntimeError: if ``STATES`` has not been defined
        """
        if cls.STATES is not None:
            return cls.STATES

        raise RuntimeError('States not defined')

    @classmethod
    def initial_state_label(cls) -> LABEL_TYPE:
        """Return the label of the first state class in ``STATES``."""
        cls.__ensure_built()
        assert cls.STATES is not None
        return cls.STATES[0].LABEL  # pylint: disable=unsubscriptable-object

    @classmethod
    def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
        """
        Return the state class registered under ``label``.

        :raises KeyError: if no state with that label exists
        """
        cls.__ensure_built()
        assert cls._STATES_MAP is not None
        return cls._STATES_MAP[label]  # pylint: disable=unsubscriptable-object

    @classmethod
    def __ensure_built(cls) -> None:
        """Build ``_STATES_MAP`` from ``STATES`` once, then mark the class as sealed."""
        try:
            # Check if it's already been built (and therefore sealed)
            # NOTE(review): the lookup deliberately goes through ``__getattribute__(cls, ...)``;
            # presumably so that only a flag set on this very class counts - confirm before changing.
            if cls.__getattribute__(cls, 'sealed'):
                return
        except AttributeError:
            pass

        cls.STATES = cls.get_states()
        assert isinstance(cls.STATES, Iterable)  # pylint: disable=isinstance-second-argument-not-valid-type

        # Build the states map
        cls._STATES_MAP = {}
        for state_cls in cls.STATES:  # pylint: disable=not-an-iterable
            assert issubclass(state_cls, State)
            label = state_cls.LABEL
            assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"  # pylint: disable=unsupported-membership-test
            cls._STATES_MAP[label] = state_cls  # pylint: disable=unsupported-assignment-operation

        # should class initialise sealed = False?
        cls.sealed = True  # type: ignore

    def __init__(self) -> None:
        super().__init__()
        self.__ensure_built()
        self._state: Optional[State] = None
        self._exception_handler = None  # Note this appears to never be used
        # Debugging is switched on through the PYTHONSMDEBUG environment variable, unless the
        # interpreter was told to ignore the environment
        self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
        self._transitioning = False
        self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {}

    @super_check
    def init(self) -> None:
        """Called after entering initial state in `__call__` method of `StateMachineMeta`"""

    def __str__(self) -> str:
        return f'<{self.__class__.__name__}> ({self.state})'

    def create_initial_state(self) -> State:
        """Instantiate the state class registered under the initial state label."""
        return self.get_state_class(self.initial_state_label())(self)

    @property
    def state(self) -> Optional[LABEL_TYPE]:
        """The label of the current state, or ``None`` if no state has been entered yet."""
        if self._state is None:
            return None
        return self._state.LABEL

    def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
        """
        Add a callback to be called on a particular state event hook.
        The callback should have form fn(state_machine, hook, state)

        :param hook: The state event hook
        :param callback: The callback function
        """
        self._event_callbacks.setdefault(hook, []).append(callback)

    def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
        """
        Remove a callback previously added for a state event hook.

        :param hook: The state event hook
        :param callback: The callback function
        :raises ValueError: if the callback is not registered for the hook
        """
        if getattr(self, '_closed', False):
            # if the process is closed, then all callbacks have already been removed
            return

        try:
            self._event_callbacks[hook].remove(callback)
        except (KeyError, ValueError) as exception:
            raise ValueError(f"Callback not set for hook '{hook}'") from exception

    def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
        """
        Call every callback registered for ``hook`` as ``callback(self, hook, state)``.

        :param hook: The state event hook being fired
        :param state: The state instance to hand to the callbacks
        """
        # Iterate over a snapshot: a callback may add or remove callbacks while being called,
        # and mutating the list during iteration would make other callbacks get skipped
        for callback in list(self._event_callbacks.get(hook, ())):
            callback(self, hook, state)

    @super_check
    def on_terminated(self) -> None:
        """ Called when a terminal state is entered """

    def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
        """
        Transition the machine to a new state.

        :param new_state: a state instance, a state class or a state label
        :param args: positional arguments used to create the state if it is not already an instance
        :param kwargs: keyword arguments used to create the state if it is not already an instance
        """
        assert not self._transitioning, \
               'Cannot call transition_to when already transitioning state'

        initial_state_label = self._state.LABEL if self._state is not None else None
        label = None
        try:
            self._transitioning = True

            # Make sure we have a state instance
            new_state = self._create_state_instance(new_state, *args, **kwargs)
            label = new_state.LABEL
            self._exit_current_state(new_state)

            try:
                self._enter_next_state(new_state)
            except StateEntryFailed as exception:
                # The exception carries the state to go to instead
                # Make sure we have a state instance
                new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
                label = new_state.LABEL
                self._exit_current_state(new_state)
                self._enter_next_state(new_state)

            if self._state is not None and self._state.is_terminal():
                call_with_super_check(self.on_terminated)
        except Exception:  # pylint: disable=broad-except
            # Cleared before ``transition_failed`` is called, so that a new transition can be started from there
            self._transitioning = False
            if self._transition_failing:
                # Already handling a failed transition: do not call ``transition_failed`` again
                raise
            self._transition_failing = True
            # ``sys.exc_info()[1:]`` is the (exception, traceback) pair of the current failure
            self.transition_failed(initial_state_label, label, *sys.exc_info()[1:])
        finally:
            self._transition_failing = False
            self._transitioning = False

    @staticmethod
    def transition_failed(
        initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
    ) -> None:
        """
        Called when a state transitions fails.  This method can be overwritten
        to change the default behaviour which is to raise the exception.

        :param initial_state: label of the state the machine was in before the transition
        :param final_state: label of the state that was being transitioned to, if known
        :param exception: The transition failed exception
        :param trace: the traceback to attach to the exception
        """
        raise exception.with_traceback(trace)

    def get_debug(self) -> bool:
        """Return whether debugging is enabled."""
        return self._debug

    def set_debug(self, enabled: bool) -> None:
        """Enable or disable debugging."""
        self._debug: bool = enabled

    def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
        """
        Create a state instance from its label.

        :param state_label: label of the state to create
        :param args: positional arguments passed to the state constructor after the machine
        :param kwargs: keyword arguments passed to the state constructor
        :raises ValueError: if no state with that label exists
        """
        # Only the lookup is guarded, so a KeyError raised inside the state's own constructor
        # is not misreported as an unknown label
        try:
            state_cls = self.get_states_map()[state_label]  # pylint: disable=unsubscriptable-object
        except KeyError as exception:
            raise ValueError(f'{state_label} is not a valid state') from exception
        return state_cls(self, *args, **kwargs)

    def _exit_current_state(self, next_state: State) -> None:
        """
        Exit the current state in preparation for entering ``next_state``.

        :raises RuntimeError: if there is no current state and ``next_state`` is not the initial
            state, or if the current state does not allow moving to ``next_state``
        """

        # If we're just being constructed we may not have a state yet to exit,
        # in which case check the new state is the initial state
        if self._state is None:
            if next_state.label != self.initial_state_label():
                raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
            return  # Nothing to exit

        if next_state.LABEL not in self._state.ALLOWED:
            raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
        self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
        self._state.do_exit()

    def _enter_next_state(self, next_state: State) -> None:
        """Enter ``next_state`` and make it the current state, firing the entering/entered hooks around it."""
        last_state = self._state
        self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
        # Enter the new state
        next_state.do_enter()
        self._state = next_state
        self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

    def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
        """
        Return ``state`` if it already is a state instance, otherwise create one.

        :param state: a state instance, a state class or a state label
        :param args: positional arguments for the state constructor (ignored for an instance)
        :param kwargs: keyword arguments for the state constructor (ignored for an instance)
        """
        if isinstance(state, State):
            # It's already a state instance
            return state

        # OK, have to create it
        state_cls = self._ensure_state_class(state)
        return state_cls(self, *args, **kwargs)

    def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
        """
        Return the state class for ``state``, which is either a state class already or a label.

        :raises ValueError: if ``state`` is a label that no state is registered under
        """
        if inspect.isclass(state) and issubclass(state, State):  # type: ignore
            return cast(Type[State], state)

        try:
            return self.get_states_map()[cast(Hashable, state)]  # pylint: disable=unsubscriptable-object
        except KeyError as exception:
            raise ValueError(f'{state} is not a valid state') from exception