# workflow_reader.py

import json
import logging

from traits.api import HasStrictTraits, Instance

from force_bdss.core.execution_layer import ExecutionLayer
from force_bdss.core.input_slot_info import InputSlotInfo
from force_bdss.core.output_slot_info import OutputSlotInfo
from force_bdss.core.workflow import Workflow
from ..factory_registry_plugin import IFactoryRegistryPlugin

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Values of the top-level "version" key that WorkflowReader.read accepts;
# any other value is rejected with an InvalidVersionException.
SUPPORTED_FILE_VERSIONS = ["1"]
class BaseWorkflowReaderException(Exception):
    """Base exception for the reader errors.

    The other exception classes defined in this module derive from it,
    so catching it handles any of the reader-specific errors.
    """
class InvalidFileException(BaseWorkflowReaderException):
    """Raised when the file is invalid for a generic reason,
    e.g. an incorrect format or missing keys.
    """


class InvalidVersionException(BaseWorkflowReaderException):
    """Raised if the version tag of the file is not in the list of
    currently supported versions (SUPPORTED_FILE_VERSIONS)."""
class MissingPluginException(BaseWorkflowReaderException):
    """Raised if the file requires a plugin we cannot find, e.g. when the
    factory registry has no MCO factory for the id stored in the file."""


class ModelInstantiationFailedException(BaseWorkflowReaderException):
    """Raised if a factory provided by a plugin fails to create its
    model from the data stored in the file."""


class WorkflowReader(HasStrictTraits):
    """
    Reads the workflow from a file.

    The file content is expected to be JSON with a top-level "version"
    key and a "workflow" key. The model objects are created through the
    factories looked up in the factory registry.
    """
    #: The Factory registry. The reader needs it to create the
    #: specific model objects.
    factory_registry = Instance(IFactoryRegistryPlugin)

    def __init__(self,
                 factory_registry,
                 *args,
                 **kwargs):
        """Initializes the reader.

        Parameters
        ----------
        factory_registry: FactoryRegistryPlugin
            The factory registry that provides lookup services
            for a factory identified by a given id.
        """
        self.factory_registry = factory_registry

        super(WorkflowReader, self).__init__(*args, **kwargs)

    def read(self, file):
        """Reads the file and returns a Workflow object.
        If any problem is found, raises an exception derived from
        BaseWorkflowReaderException.

        Parameters
        ----------
        file: File
            A file object containing the data of the workflow in the
            appropriate json format.

        Returns
        -------
        Workflow
            An instance of the model tree, rooted at Workflow.

        Raises
        ------
        InvalidFileException
            Raised if the version key, the workflow key or any other
            required key is missing from the file.
        InvalidVersionException
            Raised if the version of the file is not among
            SUPPORTED_FILE_VERSIONS.
        MissingPluginException
            Raised if no MCO factory is registered for the MCO id
            found in the file.
        ModelInstantiationFailedException
            Raised if the MCO factory fails to create its model.
        """
        json_data = json.load(file)

        try:
            version = json_data["version"]
        except KeyError:
            logger.error("File missing version information")
            raise InvalidFileException("Corrupted input file, no version"
                                       " specified")

        if version not in SUPPORTED_FILE_VERSIONS:
            logger.error(
                "File contains version {} that is not in the "
                "list of supported versions {}".format(
                    version, SUPPORTED_FILE_VERSIONS)
            )
            raise InvalidVersionException(
                "File version {} not supported".format(version))

        wf = Workflow()

        try:
            wf_data = json_data["workflow"]
            wf.mco = self._extract_mco(wf_data)
            # Slice assignment replaces the content of the existing
            # lists instead of rebinding the attributes.
            wf.execution_layers[:] = self._extract_execution_layers(wf_data)
            wf.notification_listeners[:] = \
                self._extract_notification_listeners(wf_data)
        except KeyError as e:
            # Any KeyError raised while walking the data (including the
            # ones coming from the helper methods) is reported as an
            # invalid file.
            logger.exception("Could not read file {}".format(file))
            raise InvalidFileException(
                "Could not read file. "
                "Unable to find key {}. "
                "It might be corrupted or unsupported.".format(e)
            )

        return wf

    def _extract_mco(self, wf_data):
        """Extracts the MCO from the workflow dictionary data.

        Parameters
        ----------
        wf_data: dict
            the content of the workflow key in the top level dictionary data.

        Returns
        -------
        a BaseMCOModel instance of the specific MCO driver, or None
        if no MCO is specified in the file (as in the case of premature
        saving).

        Raises
        ------
        MissingPluginException
            Raised if the registry has no MCO factory for the id in the
            file.
        ModelInstantiationFailedException
            Raised if the MCO factory fails to create the model.
        """
        registry = self.factory_registry

        mco_data = wf_data.get("mco")
        if mco_data is None:
            # The file was saved without setting an MCO.
            # The file is valid, we simply can't run any optimization yet.
            return None

        mco_id = mco_data["id"]
        try:
            mco_factory = registry.mco_factory_by_id(mco_id)
        except KeyError:
            raise MissingPluginException(
                "Could not read file. "
                "The plugin responsible for the missing "
                "key '{}' may be missing or broken.".format(mco_id)
            )

        model_data = mco_data["model_data"]
        # Replace the raw parameter dictionaries with the model objects
        # before handing the data to the factory.
        model_data["parameters"] = self._extract_mco_parameters(
            mco_id,
            model_data["parameters"])

        try:
            model = mco_factory.create_model(model_data)
        except Exception as e:
            logger.exception("Unable to create model for MCO {}".format(
                mco_id))
            raise ModelInstantiationFailedException(
                "Unable to create model for MCO {}: {}".format(mco_id, e))

        return model

    def _extract_execution_layers(self, wf_data):
        """Extracts the execution layers from the workflow dictionary data.

        Parameters
        ----------
        wf_data: dict
            the content of the workflow key in the top level dictionary data.

        Returns
        -------
        list of ExecutionLayer instances, one for each entry under the
        "execution_layers" key. Each layer holds the data source models
        created by the factory of the specific data source. The list can
        be empty.
        """
        registry = self.factory_registry

        layers = []
        for el_entry in wf_data["execution_layers"]:
            layer = ExecutionLayer()

            for ds_entry in el_entry:
                ds_id = ds_entry["id"]
                ds_factory = registry.data_source_factory_by_id(ds_id)
                model_data = ds_entry["model_data"]
                # Replace the raw slot dictionaries with the slot info
                # objects before handing the data to the factory.
                model_data["input_slot_info"] = self._extract_input_slot_info(
                    model_data["input_slot_info"]
                )
                model_data["output_slot_info"] = \
                    self._extract_output_slot_info(
                        model_data["output_slot_info"]
                    )
                # NOTE(review): assumes ExecutionLayer stores its models
                # in a `data_sources` list -- confirm against
                # force_bdss.core.execution_layer.
                layer.data_sources.append(
                    ds_factory.create_model(model_data))
            layers.append(layer)

        return layers

    def _extract_mco_parameters(self, mco_id, parameters_data):
        """Extracts the MCO parameters from the data as dictionary.

        Parameters
        ----------
        mco_id: str
            The id of the MCO the parameters belong to. It is used,
            together with the parameter id, to look up the parameter
            factory.
        parameters_data: list of dict
            The content of the parameter data key in the MCO model data.
            Each entry must provide the "id" and "model_data" keys.

        Returns
        -------
        List of instances of a subclass of BaseMCOParameter
        """
        registry = self.factory_registry

        parameters = []

        for p in parameters_data:
            parameter_id = p["id"]
            factory = registry.mco_parameter_factory_by_id(
                mco_id, parameter_id)
            model = factory.create_model(p["model_data"])
            parameters.append(model)

        return parameters

    def _extract_input_slot_info(self, info):
        """Converts a list of dictionaries into a list of InputSlotInfo.

        Each dictionary is passed as keyword arguments to InputSlotInfo.
        """
        return [InputSlotInfo(**d) for d in info]

    def _extract_output_slot_info(self, info):
        """Converts a list of dictionaries into a list of OutputSlotInfo.

        Each dictionary is passed as keyword arguments to OutputSlotInfo.
        """
        return [OutputSlotInfo(**d) for d in info]

    def _extract_notification_listeners(self, wf_data):
        """Extracts the notification listeners from the workflow
        dictionary data.

        Parameters
        ----------
        wf_data: dict
            the content of the workflow key in the top level dictionary data.

        Returns
        -------
        list of the models created by the notification listener
        factories, one for each entry under the "notification_listeners"
        key. The list can be empty.
        """
        registry = self.factory_registry
        listeners = []
        for nl_entry in wf_data["notification_listeners"]:
            nl_id = nl_entry["id"]
            nl_factory = registry.notification_listener_factory_by_id(nl_id)
            model_data = nl_entry["model_data"]
            listeners.append(nl_factory.create_model(model_data))

        return listeners