Skip to content
Snippets Groups Projects
mco.py 2.11 KiB
Newer Older
Stefano Borini's avatar
Stefano Borini committed
from traits.api import Str, Type, Bool, Int, Function, List
martinRenou's avatar
martinRenou committed

martinRenou's avatar
martinRenou committed
from force_bdss.ids import mco_parameter_id, factory_id
martinRenou's avatar
martinRenou committed
from force_bdss.core.data_value import DataValue
from force_bdss.api import (
    BaseMCOModel, BaseMCO, BaseMCOFactory,
    BaseMCOParameter, BaseMCOParameterFactory,
    BaseMCOCommunicator
)


class ProbeMCOModel(BaseMCOModel):
    """MCO model stand-in that counts ``edit_traits`` calls.

    ``edit_traits`` is overridden so that no UI is opened; each call is
    only tallied in ``edit_traits_call_count``.
    """

    #: Counts how many times the edit_traits method has been called
    edit_traits_call_count = Int(0)

    def edit_traits(self, *args, **kwargs):
        """Record one invocation; all arguments are ignored."""
        self.edit_traits_call_count = self.edit_traits_call_count + 1
martinRenou's avatar
martinRenou committed
def run_func(*args, **kwargs):
    """Default run function for the probe MCO.

    Accepts any positional and keyword arguments, ignores them, and
    returns a new empty list on every call.
    """
    del args, kwargs  # intentionally unused
    return list()


martinRenou's avatar
martinRenou committed
class ProbeMCO(BaseMCO):
    """MCO stand-in that records whether :meth:`run` was invoked.

    The actual work is delegated to ``run_function``, which defaults to
    the module-level ``run_func``.
    """

    #: Callable invoked by run() with the model; defaults to run_func
    run_function = Function(default_value=run_func)

    #: Becomes True the first time run() is called
    run_called = Bool(False)

    def run(self, model):
        """Flag the call, then return whatever ``run_function(model)`` returns."""
        self.run_called = True
        outcome = self.run_function(model)
        return outcome
martinRenou's avatar
martinRenou committed

martinRenou's avatar
martinRenou committed

class ProbeParameter(BaseMCOParameter):
    """MCO parameter stand-in that adds nothing to BaseMCOParameter."""


Stefano Borini's avatar
Stefano Borini committed
class ProbeParameterFactory(BaseMCOParameterFactory):
    """Parameter factory that produces :class:`ProbeParameter` models."""

    #: Identifier built from the "enthought" plugin / "test_mco" MCO /
    #: "test" parameter triple
    id = Str(mco_parameter_id("enthought", "test_mco", "test"))

    #: Model class instantiated for parameters of this factory
    model_class = Type(ProbeParameter)


class ProbeMCOCommunicator(BaseMCOCommunicator):
    """MCO communicator stand-in that records send/receive calls."""

    #: Set to True once send_to_mco() has been called
    send_called = Bool(False)

    #: Set to True once receive_from_mco() has been called
    receive_called = Bool(False)

    #: How many default-constructed DataValue objects receive_from_mco()
    #: returns
    nb_output_data_values = Int(0)

    def send_to_mco(self, model, kpi_results):
        """Record the call; model and kpi_results are ignored."""
        self.send_called = True

    def receive_from_mco(self, model):
        """Record the call and return fresh DataValue instances.

        The list length equals ``nb_output_data_values``; ``model`` is
        ignored.
        """
        self.receive_called = True
        count = self.nb_output_data_values
        return [DataValue() for _unused in range(count)]


class ProbeMCOFactory(BaseMCOFactory):
    """MCO factory wiring together the probe model, communicator and
    optimizer classes defined in this module."""

    id = Str(factory_id("enthought", "test_mco"))

    model_class = Type(ProbeMCOModel)

    communicator_class = Type(ProbeMCOCommunicator)

    mco_class = Type(ProbeMCO)

    #: Number of DataValue objects that communicators created by
    #: create_communicator() will return from receive_from_mco().
    #: create_communicator() reads this value, so it must be declared here.
    #: Previously it was never defined on this class.
    nb_output_data_values = Int(0)

    def create_model(self, model_data=None):
        """Instantiate ``model_class`` bound to this factory.

        Parameters
        ----------
        model_data : dict or None
            Keyword arguments forwarded to the model constructor.
            ``None`` means no extra arguments.
        """
        if model_data is None:
            model_data = {}
        return self.model_class(self, **model_data)

    def create_communicator(self):
        """Instantiate ``communicator_class`` bound to this factory.

        The communicator receives this factory's
        ``nb_output_data_values``.
        """
        return self.communicator_class(
            self,
            nb_output_data_values=self.nb_output_data_values)

    def create_optimizer(self):
        """Instantiate ``mco_class`` bound to this factory."""
        return self.mco_class(self)

    def parameter_factories(self):
        """Return a one-element list containing a ProbeParameterFactory."""
        return [ProbeParameterFactory(mco_factory=self)]