# mco.py: probe classes for the MCO API (model, optimizer, parameter,
# communicator and factory).
from traits.api import Bool, Int, Function

from force_bdss.core.data_value import DataValue
from force_bdss.api import (
    BaseMCOModel, BaseMCO, BaseMCOFactory,
    BaseMCOParameter, BaseMCOParameterFactory,
    BaseMCOCommunicator
)


class ProbeMCOModel(BaseMCOModel):
    """MCO model probe that records how often ``edit_traits`` is invoked."""

    #: Number of ``edit_traits`` calls received so far.
    edit_traits_call_count = Int(0)

    def edit_traits(self, *args, **kwargs):
        """Record one more call.

        Positional and keyword arguments are accepted and ignored; nothing
        is returned.
        """
        self.edit_traits_call_count = self.edit_traits_call_count + 1


def run_func(*args, **kwargs):
    """Return a new empty list, ignoring every argument passed in."""
    return list()


class ProbeMCO(BaseMCO):
    """MCO probe that records whether ``run`` was called.

    The work done by ``run`` is delegated to ``run_function``, so the
    returned value can be customised per instance.
    """

    #: Callable invoked by ``run`` with the model as its only argument.
    #: Defaults to the module-level ``run_func``.
    run_function = Function(default_value=run_func)

    #: Set to True when ``run`` is called.
    run_called = Bool(False)

    def run(self, model):
        """Flag the call and return ``self.run_function(model)``."""
        self.run_called = True
        return self.run_function(model)

class ProbeParameter(BaseMCOParameter):
    """MCO parameter probe that adds nothing to ``BaseMCOParameter``."""


class ProbeParameterFactory(BaseMCOParameterFactory):
    """Parameter factory probe whose model class is ``ProbeParameter``."""

    def get_identifier(self):
        """Return the fixed identifier string ``"test"``."""
        return "test"

    def get_model_class(self):
        """Return the ``ProbeParameter`` class."""
        return ProbeParameter


class ProbeMCOCommunicator(BaseMCOCommunicator):
    """Communicator probe that records which of its methods were called."""

    #: True once ``send_to_mco`` has been called.
    send_called = Bool(False)

    #: True once ``receive_from_mco`` has been called.
    receive_called = Bool(False)

    #: Number of ``DataValue`` objects returned by ``receive_from_mco``.
    nb_output_data_values = Int(0)

    def send_to_mco(self, model, kpi_results):
        """Flag the call; ``model`` and ``kpi_results`` are not used."""
        self.send_called = True

    def receive_from_mco(self, model):
        """Flag the call and return a list of default ``DataValue`` objects.

        The list holds ``nb_output_data_values`` separately constructed
        instances; ``model`` is not used.
        """
        self.receive_called = True
        data_values = []
        for _ in range(self.nb_output_data_values):
            data_values.append(DataValue())
        return data_values


class ProbeMCOFactory(BaseMCOFactory):
    """MCO factory probe wiring together the probe classes of this module."""

    #: Forwarded to every communicator built by ``create_communicator``.
    nb_output_data_values = Int(0)

    def get_identifier(self):
        """Return the fixed identifier string ``"test_mco"``."""
        return "test_mco"

    def get_model_class(self):
        """Return the ``ProbeMCOModel`` class."""
        return ProbeMCOModel

    def get_communicator_class(self):
        """Return the ``ProbeMCOCommunicator`` class."""
        return ProbeMCOCommunicator

    def get_optimizer_class(self):
        """Return the ``ProbeMCO`` class."""
        return ProbeMCO

    def get_name(self):
        """Return the fixed name string ``"testmco"``."""
        return "testmco"

    def create_communicator(self):
        """Instantiate ``self.communicator_class`` for this factory.

        The factory's ``nb_output_data_values`` is passed on to the new
        communicator.
        """
        return self.communicator_class(
            self,
            nb_output_data_values=self.nb_output_data_values)

    def parameter_factories(self):
        """Return a one-element list with a ``ProbeParameterFactory``.

        The parameter factory is created with ``mco_factory=self``.
        """
        return [ProbeParameterFactory(mco_factory=self)]