Source code for plumpy.processes

# -*- coding: utf-8 -*-
"""The main Process module"""
import abc
import asyncio
import contextlib
import copy
import enum
import functools
import logging
import re
import sys
import time
from types import TracebackType
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    Hashable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)
import uuid
import warnings

try:
    from aiocontextvars import ContextVar
except ModuleNotFoundError:
    from contextvars import ContextVar

from aio_pika.exceptions import ConnectionClosed
import kiwipy
import yaml

from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected

# pylint: disable=too-many-lines

__all__ = ['Process', 'ProcessSpec', 'BundleKeys', 'TransitionFailed']

_LOGGER = logging.getLogger(__name__)
PROCESS_STACK = ContextVar('process stack', default=[])


class BundleKeys:
    """
    String keys used by the process to save its state in the state bundle.

    See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`.
    """
    # pylint: disable=too-few-public-methods
    INPUTS_RAW = 'INPUTS_RAW'  # bundle entry holding the encoded raw inputs
    INPUTS_PARSED = 'INPUTS_PARSED'  # bundle entry holding the encoded parsed inputs
    OUTPUTS = 'OUTPUTS'  # bundle entry holding the encoded outputs
class ProcessStateMachineMeta(abc.ABCMeta, state_machine.StateMachineMeta):
    # Metaclass that merges ``abc.ABCMeta`` with the state machine metaclass; adds nothing of its own.
    pass


# Make ProcessStateMachineMeta instances (classes) YAML - able
yaml.representer.Representer.add_representer(ProcessStateMachineMeta, yaml.representer.Representer.represent_name)


def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a method so that it refuses to run once the owning process is closed.

    :param func: the method to guard; its first argument must expose a ``_closed`` flag
    :return: the wrapped method
    :raises plumpy.exceptions.ClosedError: (at call time) when ``self._closed`` is truthy
    """

    @functools.wraps(func)
    def guarded(self: Any, *args: Any, **kwargs: Any) -> Any:
        # pylint: disable=protected-access
        if not self._closed:
            return func(self, *args, **kwargs)
        raise exceptions.ClosedError('Process is closed')

    return guarded
@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
    """
    The Process class is the base for any unit of work in plumpy.

    A process can be in one of the following states:

    * CREATED
    * RUNNING
    * WAITING
    * FINISHED
    * EXCEPTED
    * KILLED

    as defined in the :class:`~plumpy.process_states.ProcessState` enum.

    ::

                      ___
                     |   v
        CREATED (x) --- RUNNING (x) --- FINISHED (o)
                         |     ^          /
                         v     |         /
                        WAITING (x) --
                         |    ^
                          ---

          * -- EXCEPTED (o)
          * -- KILLED (o)

      * (o): terminal state
      * (x): non terminal state

    When a Process enters a state it always gets a corresponding message, e.g.
    on entering RUNNING it will receive the on_run message. These are
    always called immediately after that state is entered but before being
    executed.

    """

    # pylint: disable=too-many-instance-attributes,too-many-public-methods

    # Static class stuff ######################
    _spec_class = ProcessSpec  # class instantiated by ``spec()`` to build the specification
    # Default placeholders, will be populated in init()
    _stepping = False  # True while ``step`` is executing the current state
    _pausing: Optional[futures.CancellableAction] = None  # pending pause action, if any
    _paused: Optional[persistence.SavableFuture] = None  # not None while the process is paused
    _killing: Optional[futures.CancellableAction] = None  # pending kill action, if any
    _interrupt_action: Optional[futures.CancellableAction] = None  # action to run after an interruption
    _closed = False  # set by ``on_close``
    _cleanups: Optional[List[Callable[[], None]]] = None  # callbacks run by ``on_close``

    __called: bool = False  # set by ``define`` so that ``spec()`` can check the base method was reached
@classmethod
def current(cls) -> Optional['Process']:
    """
    Return the process at the top of the process stack, i.e. the one currently running.

    :return: the last entry of ``PROCESS_STACK``, or None when the stack is empty
    """
    stack = PROCESS_STACK.get()
    return stack[-1] if stack else None
@classmethod
def get_states(cls) -> Sequence[Type[process_states.State]]:
    """Return all allowed state classes, with the CREATED state class first."""
    created_label = process_states.ProcessState.CREATED
    classes_by_label = cls.get_state_classes()
    initial = classes_by_label[created_label]
    remaining = [klass for klass in classes_by_label.values() if klass.LABEL != created_label]
    return (initial, *remaining)
@classmethod
def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]:
    """Return the mapping of each ``ProcessState`` label to the class implementing that state."""
    label = process_states.ProcessState
    # A mapping of the State constants to the corresponding state class
    return {
        label.CREATED: process_states.Created,
        label.RUNNING: process_states.Running,
        label.WAITING: process_states.Waiting,
        label.FINISHED: process_states.Finished,
        label.EXCEPTED: process_states.Excepted,
        label.KILLED: process_states.Killed,
    }
@classmethod
def spec(cls) -> ProcessSpec:
    """Return the specification of this process class, building it on first access.

    The first call instantiates ``_spec_class``, stores it as ``cls._spec`` and passes it to
    :meth:`define`. Later calls return the stored object.

    :return: the process specification
    :raises AssertionError: if ``define`` ran without reaching ``Process.define`` (the ``__called`` flag stayed False)
    """
    try:
        # NOTE(review): goes through ``__getattribute__`` with the class as the "instance", presumably so that a
        # ``_spec`` set on a parent class is not picked up -- confirm before simplifying to ``cls._spec``.
        return cls.__getattribute__(cls, '_spec')
    except AttributeError:
        try:
            cls._spec: ProcessSpec = cls._spec_class()  # type: ignore
            cls.__called: bool = False  # type: ignore
            cls.define(cls._spec)  # type: ignore
            assert cls.__called, (
                f'Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in '
                'your define? Try: super().define(spec)'
            )
            return cls._spec  # type: ignore
        except Exception:
            # Drop the half-built spec so that the next call starts from scratch. When ``_spec_class()`` itself
            # failed nothing was stored on this class, and an unconditional ``del`` would raise AttributeError and
            # hide the real error.
            if '_spec' in vars(cls):
                del cls._spec  # type: ignore
            cls.__called = False
            raise
@classmethod
def get_name(cls) -> str:
    """Return the process class name (``cls.__name__``)."""
    return cls.__name__
@classmethod
def define(cls, _spec: ProcessSpec) -> None:
    """Define the specification of the process.

    Normally should be overridden by subclasses.

    :param _spec: the specification object to populate; unused by this base implementation
    """
    # Records that this base implementation was reached.
    cls.__called = True
@classmethod
def get_description(cls) -> Dict[str, Any]:
    """
    Build a human readable description of what this :class:`Process` does.

    :return: a dictionary with a ``description`` entry (the stripped class docstring) and a ``spec`` entry
        (the description of the process spec); an entry is left out when its value is empty
    """
    summary: Dict[str, Any] = {}

    docstring = cls.__doc__
    if docstring:
        summary['description'] = docstring.strip()

    spec_summary = cls.spec().get_description()
    if spec_summary:
        summary['spec'] = spec_summary

    return summary
@classmethod
def recreate_from(
    cls,
    saved_state: SAVED_STATE_TYPE,
    load_context: Optional[persistence.LoadSaveContext] = None
) -> 'Process':
    """
    Rebuild a process from a previously saved state and run its ``init`` hook.

    :param saved_state: The saved state to load from
    :param load_context: The load context to use
    :return: An instance of the object with its state loaded from the save state.
    """
    restored = super().recreate_from(saved_state, load_context)
    process = cast(Process, restored)
    # ``init`` holds the setup shared by freshly created and reloaded processes.
    call_with_super_check(process.init)
    return process
def __init__(
    self,
    inputs: Optional[dict] = None,
    pid: Optional[PID_TYPE] = None,
    logger: Optional[logging.Logger] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    communicator: Optional[kiwipy.Communicator] = None
) -> None:
    """
    The signature of the constructor should not be changed by subclassing processes.

    :param inputs: A dictionary of the process inputs
    :param pid: The process ID, can be manually set, if not a unique pid will be chosen
    :param logger: An optional logger for the process to use
    :param loop: The event loop; when None, ``asyncio.get_event_loop()`` is used
    :param communicator: The (optional) communicator
    """
    super().__init__()

    # Don't allow the spec to be changed anymore
    self.spec().seal()

    self._loop = loop if loop is not None else asyncio.get_event_loop()

    self._setup_event_hooks()

    self._status: Optional[str] = None  # May hold a current status message
    # Save status when a pause message replaces it, such that it can be restored
    self._pre_paused_status: Optional[str] = None
    self._paused = None

    # Input/output
    self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs)
    self._pid = pid
    self._parsed_inputs: Optional[utils.AttributesFrozendict] = None
    self._outputs: Dict[str, Any] = {}
    self._uuid: Optional[uuid.UUID] = None
    self._creation_time: Optional[float] = None

    # Runtime variables
    self._future = persistence.SavableFuture(loop=self._loop)
    self.__event_helper = utils.EventHelper(ProcessListener)
    self._logger = logger
    self._communicator = communicator


@super_check
def init(self) -> None:
    """Common initialisation logic, after create or load, goes here.

    Registers the process as RPC and broadcast subscriber on the communicator (if there is one) and makes a
    cancellation of the process future kill the process.

    This method is called in :class:`plumpy.base.state_machine.StateMachineMeta`
    """
    self._cleanups = []  # a list of functions to be ran on terminated

    if self._communicator is not None:
        # A timeout while subscribing is logged and otherwise ignored: the process continues without that
        # subscription.
        try:
            identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid))
            self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier))
        except kiwipy.TimeoutError:
            self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid)

        try:
            # filter out state change broadcasts
            subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*'))
            identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid))
            self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier))
        except kiwipy.TimeoutError:
            self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)

    if not self._future.done():

        def try_killing(future: futures.Future) -> None:
            # Only a cancelled future triggers the kill; a normal result or exception is left alone.
            if future.cancelled():
                if not self.kill('Killed by future being cancelled'):
                    self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid)

        self._future.add_done_callback(try_killing)
def _setup_event_hooks(self) -> None:
    """Register the state machine callbacks that forward to ``on_entering``, ``on_entered`` and ``on_exiting``.

    Called both when the process is created and when it is loaded (recreated).
    """
    hooks = state_machine.StateEventHook

    self.add_state_event_callback(
        hooks.ENTERING_STATE,
        lambda _s, _h, state: self.on_entering(cast(process_states.State, state)),
    )
    self.add_state_event_callback(
        hooks.ENTERED_STATE,
        lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)),
    )
    self.add_state_event_callback(
        hooks.EXITING_STATE,
        lambda _s, _h, _state: self.on_exiting(),
    )
@property
def creation_time(self) -> Optional[float]:
    """
    The creation time of this Process, or None if it has not been recorded yet.

    :return: The creation time
    """
    return self._creation_time


@property
def pid(self) -> Optional[PID_TYPE]:
    """Return the pid of the process."""
    return self._pid


@property
def uuid(self) -> Optional[uuid.UUID]:
    """Return the UUID of the process."""
    return self._uuid


@property
def raw_inputs(self) -> Optional[utils.AttributesFrozendict]:
    """Return the inputs as passed to the constructor, wrapped in an `AttributesFrozendict` (or None)."""
    return self._raw_inputs


@property
def inputs(self) -> Optional[utils.AttributesFrozendict]:
    """Return the parsed inputs."""
    return self._parsed_inputs


@property
def outputs(self) -> Dict[str, Any]:
    """
    Get the current outputs emitted by the Process. These may grow over
    time as the process runs.

    :return: A mapping of {output_port: value} outputs
    """
    return self._outputs


@property
def logger(self) -> logging.Logger:
    """Return the logger for this class.

    :return: The logger that was set on the process, or the module level logger when none was set.
    """
    return self._logger if self._logger is not None else _LOGGER


@property
def status(self) -> Optional[str]:
    """Return the status message of the process."""
    return self._status
def set_status(self, status: Optional[str]) -> None:
    """Set the status message of the process.

    :param status: the new status message, or None to clear it
    """
    self._status = status
@property
def paused(self) -> bool:
    """Return whether the process is paused, i.e. whether a pause future (``_paused``) is currently set."""
    return self._paused is not None
def future(self) -> persistence.SavableFuture:
    """Return a savable future representing an eventual result of an asynchronous operation.

    The result is set at the terminal state.

    :return: the future created for this process
    """
    return self._future
@ensure_not_closed
def launch(
    self,
    process_class: Type['Process'],
    inputs: Optional[dict] = None,
    pid: Optional[PID_TYPE] = None,
    logger: Optional[logging.Logger] = None
) -> 'Process':
    """Instantiate a nested process and schedule it on this process' event loop.

    The nested process shares the loop and communicator of this process. It is started asynchronously, without
    blocking other tasks in the event loop.

    :param process_class: the class of the process to start
    :param inputs: the inputs for the nested process
    :param pid: an optional pid for the nested process
    :param logger: an optional logger for the nested process
    :return: the newly created process
    """
    child = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator)
    self.loop.create_task(child.step_until_terminated())
    return child
# region State introspection methods
def has_terminated(self) -> bool:
    """Return whether the process was terminated, i.e. whether its current state is a terminal one."""
    return self._state.is_terminal()
def result(self) -> Any:
    """
    Get the result from the process if it is finished.

    :return: The result of the process
    :raises plumpy.exceptions.KilledError: if the process was killed
    :raises BaseException: the failing exception if the process has excepted (a generic ``Exception`` when the
        state holds none)
    :raises plumpy.exceptions.InvalidStateError: if the process is in any other state
    """
    state = self._state

    if isinstance(state, process_states.Finished):
        return state.result

    if isinstance(state, process_states.Killed):
        raise exceptions.KilledError(state.msg)

    if isinstance(state, process_states.Excepted):
        raise (state.exception or Exception('process excepted'))

    raise exceptions.InvalidStateError
def successful(self) -> bool:
    """
    Returns whether the result of the process is considered successful.

    :return: the ``successful`` attribute of the current state
    :raises plumpy.exceptions.InvalidStateError: if the current state has no ``successful`` attribute, i.e. the
        process is not in the FINISHED state
    """
    try:
        return self._state.successful  # type: ignore
    except AttributeError as exception:
        raise exceptions.InvalidStateError('process is not in the finished state') from exception
@property
def is_successful(self) -> bool:
    """Return whether the result of the process is considered successful.

    Unlike :meth:`successful` this never raises for a process that has not finished.

    :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True`
    """
    try:
        return self._state.successful  # type: ignore
    except AttributeError:
        # The current state has no ``successful`` attribute.
        return False
def killed(self) -> bool:
    """Return whether the process is killed, i.e. whether its state equals ``ProcessState.KILLED``."""
    return self.state == process_states.ProcessState.KILLED
def killed_msg(self) -> Optional[str]:
    """Return the message stored on the killed state.

    :raises plumpy.exceptions.InvalidStateError: if the process is not in the killed state
    """
    state = self._state
    if not isinstance(state, process_states.Killed):
        raise exceptions.InvalidStateError('Has not been killed')
    return state.msg
def exception(self) -> Optional[BaseException]:
    """Return the exception held by the excepted state, or None when the process is in any other state."""
    state = self._state
    return state.exception if isinstance(state, process_states.Excepted) else None
def done(self) -> bool:
    """Return True if the call was successfully killed or finished running.

    Emits a ``DeprecationWarning`` on every call. The warning is attributed to the caller (``stacklevel=2``) so
    that it points at the code that needs updating rather than at this module.

    .. deprecated:: 0.18.6
        Use the `has_terminated` method instead

    :return: whether the current state is a terminal one
    """
    warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning, stacklevel=2)  # pylint: disable=no-member
    return self._state.is_terminal()
# endregion # region loop methods @property def loop(self) -> asyncio.AbstractEventLoop: """Return the event loop of the process.""" return self._loop
def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> events.ProcessCallback:
    """
    Schedule a callback to what is considered an internal process function
    (this needn't be a method).

    The callback is routed through ``_run_task`` and scheduled as a task on the process loop. If it raises an
    exception it will cause the process to fail.

    :param callback: the function or coroutine to schedule
    :param args: positional arguments for the callback
    :param kwargs: keyword arguments for the callback
    :return: the handle wrapping the scheduled callback
    """
    # ``_run_task`` takes the callback as its first positional argument.
    task_args = (callback, *args)
    handle = events.ProcessCallback(self, self._run_task, task_args, kwargs)
    self.loop.create_task(handle.run())
    return handle
def callback_excepted(
    self,
    _callback: Callable[..., Any],
    exception: Optional[BaseException],
    trace: Optional[TracebackType]
) -> None:
    """Fail the process because a scheduled callback raised, unless it is already in the EXCEPTED state.

    :param _callback: the callback that raised (unused)
    :param exception: the exception that was raised
    :param trace: the corresponding traceback
    """
    if self.state == process_states.ProcessState.EXCEPTED:
        return
    self.fail(exception, trace)
@contextlib.contextmanager
def _process_scope(self) -> Generator[None, None, None]:
    """
    This context manager function is used to make sure the process stack is correct
    meaning that globally someone can ask for Process.current() to get the last process
    that is on the call stack.

    On entry the process is pushed onto ``PROCESS_STACK``; on exit it is popped again.
    """
    # Work on a copy so that the list held by the context variable is never mutated in place.
    stack_copy = PROCESS_STACK.get().copy()
    stack_copy.append(self)
    PROCESS_STACK.set(stack_copy)
    try:
        yield None
    finally:
        assert Process.current() is self, (
            'Somehow, the process at the top of the stack is not me, but another process! '
            f'({self} != {Process.current()})'
        )
        stack_copy = PROCESS_STACK.get().copy()
        stack_copy.pop()
        PROCESS_STACK.set(stack_copy)
async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """
    This method should be used to run all Process related functions and coroutines.
    The callback is awaited inside the process scope, so ``Process.current()`` returns this process while it runs.

    NOTE(review): exceptions raised by the callback are not handled here and propagate to the caller; moving
    the process to the EXCEPTED state is presumably done by the caller -- confirm.

    :param callback: A function or coroutine
    :param args: Optional positional arguments passed to callback
    :param kwargs: Optional keyword arguments passed to callback
    :return: The value as returned by callback
    """
    # Make sure execute is a coroutine
    coro = utils.ensure_coroutine(callback)

    with self._process_scope():
        result = await coro(*args, **kwargs)
        return result
# endregion # region Persistence
def save_instance_state(
    self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext]
) -> None:
    """
    Ask the process to save its current instance state.

    On top of what the superclass stores, this saves the current state under ``'_state'`` and, when present,
    the encoded raw inputs, parsed inputs and outputs under the :class:`BundleKeys` keys.

    :param out_state: A bundle to save the state to
    :param save_context: The save context
    """
    super().save_instance_state(out_state, save_context)

    out_state['_state'] = self._state.save()

    # Inputs/outputs
    if self.raw_inputs is not None:
        out_state[BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs)

    if self.inputs is not None:
        out_state[BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs)

    # Outputs go through the same encoder as the inputs; an empty outputs dict is not stored.
    if self.outputs:
        out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs)
@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
    """Load the process from its saved instance state.

    The loop, communicator and logger are taken from ``load_context`` when it provides them; otherwise the
    current asyncio event loop is used and communicator and logger are left as None. Inputs and outputs missing
    from the bundle are restored as None (inputs) or an empty dict (outputs).

    :param saved_state: A bundle to load the state from
    :param load_context: The load context
    """
    # First make sure the state machine constructor is called
    super().__init__()

    self._setup_event_hooks()

    # Runtime variables, set initial states
    self._future = persistence.SavableFuture()
    self.__event_helper = utils.EventHelper(ProcessListener)
    self._logger = None
    self._communicator = None

    if 'loop' in load_context:
        self._loop = load_context.loop
    else:
        self._loop = asyncio.get_event_loop()

    self._state: process_states.State = self.recreate_state(saved_state['_state'])

    if 'communicator' in load_context:
        self._communicator = load_context.communicator

    if 'logger' in load_context:
        self._logger = load_context.logger

    # Need to call this here as things downstream may rely on us having the runtime variable above
    super().load_instance_state(saved_state, load_context)

    # Inputs/outputs
    # Only the bundle lookup is guarded: a KeyError raised while decoding is a real error and must not be
    # mistaken for a missing bundle entry.
    try:
        encoded = saved_state[BundleKeys.INPUTS_RAW]
    except KeyError:
        self._raw_inputs = None
    else:
        self._raw_inputs = utils.AttributesFrozendict(self.decode_input_args(encoded))

    try:
        encoded = saved_state[BundleKeys.INPUTS_PARSED]
    except KeyError:
        self._parsed_inputs = None
    else:
        self._parsed_inputs = utils.AttributesFrozendict(self.decode_input_args(encoded))

    try:
        encoded = saved_state[BundleKeys.OUTPUTS]
    except KeyError:
        self._outputs = {}
    else:
        self._outputs = self.decode_input_args(encoded)
# endregion
def add_process_listener(self, listener: ProcessListener) -> None:
    """Add a process listener to the process.

    The listener defines the actions to take when the process is triggering
    the specific state condition.

    :param listener: the listener to register; must not be the process itself
    """
    assert (listener != self), 'Cannot listen to yourself!'
    self.__event_helper.add_listener(listener)
def remove_process_listener(self, listener: ProcessListener) -> None:
    """Remove a process listener from the process.

    :param listener: the listener to remove
    """
    self.__event_helper.remove_listener(listener)
@protected
def set_logger(self, logger: logging.Logger) -> None:
    """Set the logger of the process.

    :param logger: the logger that the ``logger`` property will return from now on
    """
    self._logger = logger
@protected
def log_with_pid(self, level: int, msg: str) -> None:
    """Log the message with the process pid.

    :param level: the logging level
    :param msg: the message, which is logged prefixed with the pid
    """
    self.logger.log(level, '%s: %s', self.pid, msg)
# region Events
def on_entering(self, state: process_states.State) -> None:
    """Dispatch to the ``on_<verb>`` hook that corresponds to the state that is about to be entered.

    :param state: the state being entered
    """
    # Map these onto direct functions that the subclass can implement
    state_label = state.LABEL
    if state_label == process_states.ProcessState.CREATED:
        call_with_super_check(self.on_create)
    elif state_label == process_states.ProcessState.RUNNING:
        call_with_super_check(self.on_run)
    elif state_label == process_states.ProcessState.WAITING:
        call_with_super_check(self.on_wait, state.data)  # type: ignore
    elif state_label == process_states.ProcessState.FINISHED:
        call_with_super_check(self.on_finish, state.result, state.successful)  # type: ignore
    elif state_label == process_states.ProcessState.KILLED:
        call_with_super_check(self.on_kill, state.msg)  # type: ignore
    elif state_label == process_states.ProcessState.EXCEPTED:
        call_with_super_check(self.on_except, state.get_exc_info())  # type: ignore
def on_entered(self, from_state: Optional[process_states.State]) -> None:
    """Dispatch to the hook of the state that was just entered and broadcast the state change.

    The broadcast is only sent when there is a communicator. A closed connection or a timeout while sending
    is logged as a warning and otherwise ignored.

    :param from_state: the state that was left, or None
    """
    # Map these onto direct functions that the subclass can implement
    state_label = self._state.LABEL
    if state_label == process_states.ProcessState.RUNNING:
        call_with_super_check(self.on_running)
    elif state_label == process_states.ProcessState.WAITING:
        call_with_super_check(self.on_waiting)
    elif state_label == process_states.ProcessState.FINISHED:
        call_with_super_check(self.on_finished)
    elif state_label == process_states.ProcessState.EXCEPTED:
        call_with_super_check(self.on_excepted)
    elif state_label == process_states.ProcessState.KILLED:
        call_with_super_check(self.on_killed)

    if self._communicator and isinstance(self.state, enum.Enum):
        from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None
        # The subject has the form ``state_changed.<from>.<to>``.
        subject = f'state_changed.{from_label}.{self.state.value}'
        self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject)
        try:
            self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject)
        except ConnectionClosed:
            message = 'Process<%s>: no connection available to broadcast state change from %s to %s'
            self.logger.warning(message, self.pid, from_label, self.state.value)
        except kiwipy.TimeoutError:
            message = 'Process<%s>: sending broadcast of state change from %s to %s timed out'
            self.logger.warning(message, self.pid, from_label, self.state.value)
def on_exiting(self) -> None:
    """Dispatch to ``on_exit_waiting`` or ``on_exit_running`` when leaving the WAITING or RUNNING state."""
    state = self.state
    if state == process_states.ProcessState.WAITING:
        call_with_super_check(self.on_exit_waiting)
    elif state == process_states.ProcessState.RUNNING:
        call_with_super_check(self.on_exit_running)
@super_check
def on_create(self) -> None:
    """Entering the CREATED state.

    Records the creation time, parses and validates the raw inputs against the input namespace of the spec and
    assigns the UUID (which doubles as pid when none was given).

    :raises ValueError: if the validation of the inputs reports a problem
    """
    self._creation_time = time.time()

    # This will parse the inputs with respect to the input portnamespace of the spec and validate them
    raw_inputs = dict(self._raw_inputs) if self._raw_inputs else {}
    self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs)
    result = self.spec().inputs.validate(self._parsed_inputs)
    if result is not None:
        raise ValueError(result)

    # Set up a process ID
    self._uuid = uuid.uuid4()
    if self._pid is None:
        self._pid = self._uuid


@super_check
def on_exit_running(self) -> None:
    """Exiting the RUNNING state."""


@super_check
def on_exit_waiting(self) -> None:
    """Exiting the WAITING state."""


@super_check
def on_run(self) -> None:
    """Entering the RUNNING state."""


@super_check
def on_running(self) -> None:
    """Entered the RUNNING state; notifies the listeners."""
    self._fire_event(ProcessListener.on_process_running)
def on_output_emitting(self, output_port: str, value: Any) -> None:
    """Output is about to be emitted.

    The base implementation does nothing.

    :param output_port: the name of the output port
    :param value: the value about to be emitted
    """
def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
    """Notify the listeners that an output was emitted.

    :param output_port: the name of the output port
    :param value: the emitted value
    :param dynamic: passed on to the listeners as is
    """
    self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
    """Entering the WAITING state."""


@super_check
def on_waiting(self) -> None:
    """Entered the WAITING state; notifies the listeners."""
    self._fire_event(ProcessListener.on_process_waiting)


@super_check
def on_pausing(self, msg: Optional[str] = None) -> None:
    """The process is being paused."""


@super_check
def on_paused(self, msg: Optional[str] = None) -> None:
    """The process was paused.

    :param msg: an optional message that replaces the status for as long as the process is paused
    """
    self._pausing = None

    # Create a future to represent the duration of the paused state
    self._paused = persistence.SavableFuture()

    # Save the current status and potentially overwrite it with the passed message
    self._pre_paused_status = self.status
    if msg is not None:
        self.set_status(msg)

    self._fire_event(ProcessListener.on_process_paused)


@super_check
def on_playing(self) -> None:
    """The process was played; resolves the pause future and restores the status saved when pausing."""
    # Done being paused
    if self._paused is not None:
        self._paused.set_result(True)
        self._paused = None

    self.set_status(self._pre_paused_status)
    self._pre_paused_status = None

    self._fire_event(ProcessListener.on_process_played)


@super_check
def on_finish(self, result: Any, successful: bool) -> None:
    """Entering the FINISHED state.

    :param result: the result of the process
    :param successful: whether the process finished successfully; only then are the outputs validated
    :raises StateEntryFailed: if the outputs do not validate against the output namespace of the spec
    """
    if successful:
        validation_error = self.spec().outputs.validate(self.outputs)
        if validation_error:
            raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False)

    self.future().set_result(self.outputs)


@super_check
def on_finished(self) -> None:
    """Entered the FINISHED state; notifies the listeners with the result of the future."""
    self._fire_event(ProcessListener.on_process_finished, self.future().result())


@super_check
def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None:
    """Entering the EXCEPTED state.

    :param exc_info: the exception information; the exception (with the traceback attached) is set on the future
    """
    exception = exc_info[1]
    exception.__traceback__ = exc_info[2]
    self.future().set_exception(exception)


@super_check
def on_excepted(self) -> None:
    """Entered the EXCEPTED state; notifies the listeners with the exception as a string."""
    self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception()))


@super_check
def on_kill(self, msg: Optional[str]) -> None:
    """Entering the KILLED state.

    :param msg: the kill message; becomes the status and the message of the ``KilledError`` set on the future
    """
    self.set_status(msg)
    self.future().set_exception(exceptions.KilledError(msg))


@super_check
def on_killed(self) -> None:
    """Entered the KILLED state; notifies the listeners with the kill message."""
    self._killing = None
    self.future().exception()  # exception must be retrieved
    self._fire_event(ProcessListener.on_process_killed, self.killed_msg())
def on_terminated(self) -> None:
    """Call when a terminal state is reached; closes the process after the superclass handling."""
    super().on_terminated()
    self.close()
@super_check
def on_close(self) -> None:
    """
    Called when the Process is being closed and will not be ran anymore. This is an opportunity
    to free any runtime resources.

    All registered cleanup callbacks are run; a failing one is logged and does not stop the others.
    """
    try:
        for cleanup in self._cleanups or []:
            try:
                cleanup()
            except Exception:  # pylint: disable=broad-except
                self.logger.exception('Process<%s>: Exception calling cleanup method %s', self.pid, cleanup)
        self._cleanups = None
    finally:
        # Executed regardless of what happened above, so the process is always marked as closed.
        self._event_callbacks = {}
        self._closed = True
def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
    """Fire an event on the listeners, passing this process as the first argument.

    :param evt: the listener method to fire
    :param args: additional positional arguments for the listeners
    :param kwargs: additional keyword arguments for the listeners
    """
    self.__event_helper.fire_event(evt, self, *args, **kwargs)
# endregion # region Communication
def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
    """
    Callback invoked when the process receives an RPC message from the communicator.

    Play, pause and kill intents are scheduled on the process loop and answered with a future; a status
    intent is answered directly with the status information.

    :param _comm: the communicator that sent the message
    :param msg: the message
    :return: the outcome of processing the message, the return value will be sent back as a response to the sender
    :raises RuntimeError: if the intent of the message is not recognised
    """
    self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg)

    intent = msg[process_comms.INTENT_KEY]

    if intent == process_comms.Intent.PLAY:
        return self._schedule_rpc(self.play)
    if intent == process_comms.Intent.PAUSE:
        return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
    if intent == process_comms.Intent.KILL:
        return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None))
    if intent == process_comms.Intent.STATUS:
        status_info: Dict[str, Any] = {}
        self.get_status_info(status_info)
        return status_info

    # Didn't match any known intents
    raise RuntimeError('Unknown intent')
def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any,
                      correlation_id: Any) -> Optional[kiwipy.Future]:
    """
    Callback invoked when the process receives a broadcast message from the communicator.

    :param _comm: the communicator that sent the message
    :param body: the body of the message; passed as the message to ``pause`` and ``kill``
    :param sender: the sender of the message (unused)
    :param subject: the subject of the message; compared against the play, pause and kill intents
    :param correlation_id: the correlation id of the message (unused)
    :return: a future for the scheduled action, or None if the subject is not recognised
    """
    # pylint: disable=unused-argument
    self.logger.debug(
        "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body
    )

    # If we get a message we recognise then action it, otherwise ignore
    if subject == process_comms.Intent.PLAY:
        return self._schedule_rpc(self.play)
    if subject == process_comms.Intent.PAUSE:
        return self._schedule_rpc(self.pause, msg=body)
    if subject == process_comms.Intent.KILL:
        return self._schedule_rpc(self.kill, msg=body)
    return None
def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
    """
    Schedule a call to a callback as a result of an RPC communication call, this will return
    a future that resolves to the final result (even after one or more layer of futures being
    returned) of the callback.

    The callback will be scheduled at the working thread where the process event loop runs.

    :param callback: the callback function or coroutine
    :param args: the positional arguments to the callback
    :param kwargs: the keyword arguments to the callback
    :return: a kiwi future that resolves to the outcome of the callback
    """
    kiwi_future = kiwipy.Future()

    async def run_callback() -> None:
        # The context manager receives ``kiwi_future`` so that it can record an exception raised in the body.
        with kiwipy.capture_exceptions(kiwi_future):
            result = callback(*args, **kwargs)
            # Unwrap nested futures until an actual value is obtained.
            while asyncio.isfuture(result):
                result = await result

            kiwi_future.set_result(result)

    # Schedule the task and give back a kiwi future
    asyncio.run_coroutine_threadsafe(run_callback(), self.loop)
    return kiwi_future
# endregion
@ensure_not_closed
def add_cleanup(self, cleanup: Callable[[], None]) -> None:
    """Add callback, which will be run when the process is being closed.

    :param cleanup: a callable taking no arguments
    """
    assert self._cleanups is not None
    self._cleanups.append(cleanup)
def close(self) -> None:
    """
    Calling this method indicates that this process should not ran anymore and will trigger
    any runtime resources (such as the communicator connection) to be cleaned up. The state
    of the process will still be accessible.

    It is safe to call this method multiple times.
    """
    if self._closed:
        # Already closed, nothing left to do.
        return

    call_with_super_check(self.on_close)
# region State related methods
def transition_excepted(
    self, _initial_state: Any, final_state: process_states.ProcessState, exception: Exception,
    trace: TracebackType
) -> None:
    """Handle an exception raised during a state transition.

    :param _initial_state: the state the transition started from (unused)
    :param final_state: the label of the state the transition was heading to
    :param exception: the exception that was raised
    :param trace: the corresponding traceback
    """
    # If we are creating, then reraise instead of failing.
    if final_state == process_states.ProcessState.CREATED:
        raise exception.with_traceback(trace)

    self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace)
def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
    """Pause the process.

    :param msg: an optional message to set as the status. The current status will be saved in the private
        `_pre_paused_status attribute`, such that it can be restored when the process is played again.

    :return: False if process is already terminated,
        True if already paused or paused by this call,
        the pending `CancellableAction` if a pause is already under way or the process is running a step
    """
    if self.has_terminated():
        return False

    if self.paused:
        # Already paused
        return True

    if self._pausing is not None:
        # Already pausing
        return self._pausing

    if self._stepping:
        # Ask the step function to pause by setting this flag and giving the
        # caller back a future
        interrupt_exception = process_states.PauseInterruption(msg)
        self._set_interrupt_action_from_exception(interrupt_exception)
        self._pausing = self._interrupt_action
        # Try to interrupt the state
        self._state.interrupt(interrupt_exception)
        return cast(futures.CancellableAction, self._interrupt_action)

    return self._do_pause(msg)
def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
    """Carry out the pause procedure, optionally transitioning to the next state first.

    :param state_msg: the message passed on to ``on_pausing`` and ``on_paused``
    :param next_state: a state to transition to before pausing, if any
    :return: True
    """
    try:
        if next_state is not None:
            self.transition_to(next_state)
        call_with_super_check(self.on_pausing, state_msg)
        call_with_super_check(self.on_paused, state_msg)
    finally:
        # The pause is no longer pending, whether it succeeded or raised.
        self._pausing = None

    return True
def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction:
    """
    Create an interrupt action from the corresponding interrupt exception

    :param exception: The interrupt exception
    :return: The interrupt action; the exception is stored on it as its cookie
    :raises ValueError: if the exception is neither a pause nor a kill interruption
    """
    if isinstance(exception, process_states.PauseInterruption):
        # The string form of the exception is used as the pause message.
        do_pause = functools.partial(self._do_pause, str(exception))
        return futures.CancellableAction(do_pause, cookie=exception)

    if isinstance(exception, process_states.KillInterruption):

        def do_kill(_next_state: process_states.State) -> Any:
            try:
                # Ignore the next state
                self.transition_to(process_states.ProcessState.KILLED, str(exception))
                return True
            finally:
                self._killing = None

        return futures.CancellableAction(do_kill, cookie=exception)

    raise ValueError(f"Got unknown interruption type '{type(exception)}'")
def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None:
    """
    Set the interrupt action cancelling the current one if it exists

    :param new_action: The new interrupt action to set, or None to only clear the current one
    """
    if self._interrupt_action is not None:
        self._interrupt_action.cancel()
    self._interrupt_action = new_action
def _set_interrupt_action_from_exception(self, interrupt_exception: process_states.Interruption) -> None:
    """Build the interrupt action that matches ``interrupt_exception`` and install it as the current one.

    :param interrupt_exception: the interruption to create the action for
    """
    self._set_interrupt_action(self._create_interrupt_action(interrupt_exception))
def play(self) -> bool:
    """
    Play a process. A pending pause request is cancelled; a paused process is resumed through ``on_playing``.

    :return: True in all cases handled here
    """
    if not self.paused:
        if self._pausing is not None:
            # Not going to pause after all
            self._pausing.cancel()
            self._pausing = None
            self._set_interrupt_action(None)
        return True

    call_with_super_check(self.on_playing)
    return True
@event(from_states=(process_states.Waiting))
def resume(self, *args: Any) -> None:
    """Start running the process again.

    :param args: passed on to the ``resume`` method of the current state
    """
    return self._state.resume(*args)  # type: ignore
@event(to_states=process_states.Excepted)
def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None:
    """
    Fail the process in response to an exception by transitioning to the EXCEPTED state.

    :param exception: The exception that caused the failure
    :param trace_back: Optional exception traceback
    """
    self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)
def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]:
    """
    Kill the process

    :param msg: An optional kill message
    :return: True if the process is already killed or was killed by this call,
        False if it already terminated in another state,
        the pending kill action if a kill is already under way or the process is running a step
    """
    if self.state == process_states.ProcessState.KILLED:
        # Already killed
        return True

    if self.has_terminated():
        # Can't kill
        return False

    if self._killing:
        # Already killing
        return self._killing

    if self._stepping:
        # Ask the step function to kill by installing the interrupt action and giving the
        # caller back a future
        interrupt_exception = process_states.KillInterruption(msg)
        self._set_interrupt_action_from_exception(interrupt_exception)
        self._killing = self._interrupt_action
        self._state.interrupt(interrupt_exception)
        return cast(futures.CancellableAction, self._interrupt_action)

    self.transition_to(process_states.ProcessState.KILLED, msg)
    return True
@property
def is_killing(self) -> bool:
    """Return whether a kill is pending, i.e. whether ``_killing`` is set."""
    pending_kill = self._killing
    return pending_kill is not None

# endregion
def create_initial_state(self) -> process_states.State:
    """Build the state a newly created process starts in (overrides the superclass).

    The class registered for ``ProcessState.CREATED`` is instantiated with this process
    and its ``run`` method.

    :return: A Created state
    """
    created_class = self.get_state_class(process_states.ProcessState.CREATED)
    initial_state = created_class(self, self.run)
    return cast(process_states.State, initial_state)
def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State:
    """
    Create a state object from a saved state.

    The state is loaded through ``persistence.Savable.load`` with a load context that
    refers back to this process.

    :param saved_state: The saved state
    :return: An instance of the object with its state loaded from the saved state
    """
    restored = persistence.Savable.load(saved_state, persistence.LoadSaveContext(process=self))
    return cast(process_states.State, restored)
# endregion

# region Execution related methods
def run(self) -> Any:
    """Hook that is run when the process is triggered.

    The base implementation does nothing and returns None; it should be overridden by a
    subclass.
    """
    return None
@ensure_not_closed
def execute(self) -> Optional[Dict[str, Any]]:
    """
    Execute the process.

    Unless the process has already terminated, its event loop is run until
    ``step_until_terminated`` completes. The result of the process future is then returned.

    :return: the result held by ``self.future()``
    """
    already_terminated = self.has_terminated()
    if not already_terminated:
        loop = self.loop
        loop.run_until_complete(self.step_until_terminated())

    process_future = self.future()
    return process_future.result()
@ensure_not_closed
async def step(self) -> None:
    """Run a step.

    The step is run synchronously with steps in its own process, and asynchronously with
    steps in other processes. The execute function run here depends on the current state
    of the process.

    If the process is paused, the step first waits for the pause to be lifted. The outcome
    of the state's ``execute`` then decides what happens next: a pending interrupt action
    is run with the next state, otherwise the process transitions to the next state.
    """
    assert not self.has_terminated(), 'Cannot step, already terminated'

    if self.paused and self._paused is not None:
        await self._paused

    try:
        self._stepping = True
        next_state = None
        try:
            next_state = await self._run_task(self._state.execute)
        except process_states.Interruption as interruption:
            # An interruption triggered through a Process method has already registered an
            # interrupt action whose cookie is this very exception. If there is no action,
            # or it belongs to a different interruption, build the action from this one.
            pending_action = self._interrupt_action
            if pending_action is None or pending_action.cookie is not interruption:
                self._set_interrupt_action_from_exception(interruption)
        except (KeyboardInterrupt, asyncio.CancelledError):  # pylint: disable=try-except-raise
            # The explicit re-raise of CancelledError is only required in python<=3.7, where
            # asyncio.CancelledError == concurrent.futures.CancelledError; it is encountered
            # when the run_task is cancelled. For python>=3.8 asyncio.CancelledError does not
            # inherit from Exception, so it would not be caught by the handler below anyway.
            raise
        except Exception:  # pylint: disable=broad-except
            # Overwrite the next state to go to excepted directly
            next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:])
            self._set_interrupt_action(None)

        if self._interrupt_action:
            self._interrupt_action.run(next_state)
        else:
            # Everything nominal so transition to the next state
            self.transition_to(next_state)
    finally:
        # Whatever happened, the step is over and no interrupt action may linger
        self._stepping = False
        self._set_interrupt_action(None)
async def step_until_terminated(self) -> None:
    """Keep stepping the process until it has terminated.

    Each iteration awaits a single ``step``. This is the coroutine run by the event loop
    (not ``step``).
    """
    while True:
        if self.has_terminated():
            return
        await self.step()
# endregion
@ensure_not_closed
@protected
def out(self, output_port: str, value: Any) -> None:
    """
    Record an output value for a specific output port.

    If the output port matches an explicitly defined Port it will be validated against
    that. If not it will be validated against the PortNamespace, which means it will be
    checked for dynamicity and whether the type of the value is valid.

    ``on_output_emitting`` is called before validation and ``on_output_emitted`` after the
    value has been stored.

    :param output_port: the name of the output port, can be namespaced
    :param value: the value for the output port
    :raises: ValueError if the output value is not validated against the port
    """
    self.on_output_emitting(output_port, value)

    separator = self.spec().namespace_separator
    # Everything before the last separator names the enclosing namespaces
    *parent_names, leaf_name = output_port.split(separator)

    if parent_names:
        parent = cast(ports.PortNamespace, self.spec().outputs.get_port(separator.join(parent_names)))
    else:
        parent = self.spec().outputs

    try:
        port = parent[leaf_name]
        dynamic = False
        validation_error = port.validate(value)
    except KeyError:
        # No explicit port with this name: validate as a dynamic port of the namespace
        port = parent
        dynamic = True
        validation_error = port.validate_dynamic_ports({leaf_name: value})

    if validation_error:
        msg = f"Error validating output '{value}' for port '{validation_error.port}': {validation_error.message}"
        raise ValueError(msg)

    # Walk down (creating as needed) the nested output dictionaries and store the value
    target = self._outputs
    for name in parent_names:
        target = target.setdefault(name, {})
    target[leaf_name] = value

    self.on_output_emitted(output_port, value, dynamic)
@protected
def encode_input_args(self, inputs: Any) -> Any:
    """
    Encode input arguments such that they may be saved in a :class:`plumpy.persistence.Bundle`.

    The encoded inputs should contain no reference to the inputs that were passed in; this
    implementation achieves that by returning a deep copy.

    :param inputs: A mapping of the inputs as passed to the process
    :return: The encoded inputs
    """
    # pylint: disable=no-self-use
    encoded_inputs = copy.deepcopy(inputs)
    return encoded_inputs
@protected
def decode_input_args(self, encoded: Any) -> Any:
    """
    Decode saved input arguments as they came from the saved instance state
    :class:`plumpy.persistence.Bundle`.

    The decoded inputs should contain no reference to the encoded inputs that were passed
    in; this implementation achieves that by returning a deep copy.

    :param encoded: the encoded input arguments
    :return: The decoded input args
    """
    # pylint: disable=no-self-use
    decoded_inputs = copy.deepcopy(encoded)
    return decoded_inputs
def get_status_info(self, out_status_info: dict) -> None:
    """Write the current status of the process into ``out_status_info`` in place.

    The keys ``ctime``, ``paused``, ``process_string`` and ``state`` are set (overwriting
    existing values); nothing is returned.

    :param out_status_info: the dictionary to update with the status information
    """
    snapshot = {
        'ctime': self.creation_time,
        'paused': self.paused,
        'process_string': str(self),
        'state': str(self.state),
    }
    out_status_info.update(snapshot)