# mco.py
from traits.api import Str, Type, Bool, Int, Function

from force_bdss.ids import mco_parameter_id, factory_id
from force_bdss.core.data_value import DataValue
from force_bdss.api import (
    BaseMCOModel, BaseMCO, BaseMCOFactory,
    BaseMCOParameter, BaseMCOParameterFactory,
    BaseMCOCommunicator
)


class ProbeMCOModel(BaseMCOModel):
    """Model for the probe MCO; adds nothing on top of BaseMCOModel."""


def run_func(*args, **kwargs):
    """Placeholder run function: accept anything, return a new empty list."""
    return list()


class ProbeMCO(BaseMCO):
    """Probe optimizer that records whether ``run`` was invoked.

    Parameters
    ----------
    factory
        The factory that created this optimizer; forwarded unchanged to
        ``BaseMCO.__init__``.
    run_function : callable, optional
        Stored in the ``run_function`` trait. When omitted (``None``), the
        trait falls back to the module-level ``run_func``, which ignores its
        arguments and returns an empty list.
    *args, **kwargs
        Forwarded to ``BaseMCO.__init__``.
    """

    # Callable kept for callers/tests to inspect; ``run`` does not invoke it.
    run_function = Function()

    # Flipped to True by ``run``.
    run_called = Bool(False)

    def __init__(self, factory, run_function=None, *args, **kwargs):
        # ``super()`` already binds ``self``; pass ``factory`` exactly once so
        # BaseMCO receives the real factory rather than this instance.
        super(ProbeMCO, self).__init__(factory, *args, **kwargs)
        # Honour an explicitly supplied callable; when omitted, the trait
        # default defined below provides the fallback.
        if run_function is not None:
            self.run_function = run_function

    def run(self, model):
        """Record that the optimizer was run; no optimization is performed."""
        self.run_called = True

    def _run_function_default(self):
        # Traits dynamic default for ``run_function``. Use the same
        # module-level ``run_func`` (returns []) as the constructor fallback
        # so both paths yield an identical default callable.
        return run_func


class ProbeParameter(BaseMCOParameter):
    """Parameter model for the probe MCO; no extra traits beyond the base."""


class RangedParameterFactory(BaseMCOParameterFactory):
    """Parameter factory whose model class is ``ProbeParameter``."""

    # Parameter model class produced by this factory.
    model_class = Type(ProbeParameter)

    # Identifier built from the ("enthought", "test_mco", "test") triple.
    id = Str(mco_parameter_id("enthought", "test_mco", "test"))


class ProbeMCOCommunicator(BaseMCOCommunicator):
    """Communicator stub that records which of its hooks were called."""

    # Set to True by send_to_mco and receive_from_mco respectively.
    send_called = Bool(False)
    receive_called = Bool(False)

    # Number of DataValue objects returned by receive_from_mco.
    nb_output_data_values = Int(0)

    def send_to_mco(self, model, kpi_results):
        """Note that a send happened; ``model`` and ``kpi_results`` are ignored."""
        self.send_called = True

    def receive_from_mco(self, model):
        """Note the call and return ``nb_output_data_values`` new DataValues."""
        self.receive_called = True
        how_many = self.nb_output_data_values
        return [DataValue() for _unused in range(how_many)]


class ProbeMCOFactory(BaseMCOFactory):
    """Factory that builds the probe MCO model, communicator and optimizer."""

    id = Str(factory_id("enthought", "test_mco"))

    # Classes instantiated by the create_* methods below.
    model_class = Type(ProbeMCOModel)
    communicator_class = Type(ProbeMCOCommunicator)
    mco_class = Type(ProbeMCO)

    def create_model(self, model_data=None):
        """Instantiate ``model_class`` for this factory.

        ``model_data``, when given, is expanded as keyword arguments to the
        model constructor; ``None`` means no extra arguments.
        """
        keyword_data = {} if model_data is None else model_data
        return self.model_class(self, **keyword_data)

    def create_communicator(self):
        """Instantiate ``communicator_class`` for this factory."""
        return self.communicator_class(self)

    def create_optimizer(self):
        """Instantiate ``mco_class`` for this factory."""
        return self.mco_class(self)

    def parameter_factories(self):
        """Return the parameter factories for this MCO (always empty here)."""
        return []