From 21735855da552230aab443bf6798d2f3a554cd6a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 31 Jan 2018 12:00:53 +0100
Subject: [PATCH] Rip out parameter class

It makes the code generation workflow more complicated without
contributing anything substantial
---
 python/dune/perftool/loopy/mangler.py         |  12 +
 python/dune/perftool/pdelab/__init__.py       |  13 -
 .../perftool/pdelab/driver/gridoperator.py    |  22 +-
 .../perftool/pdelab/driver/instationary.py    |   9 +-
 .../perftool/pdelab/driver/interpolate.py     |   8 +-
 python/dune/perftool/pdelab/driver/visitor.py |   9 +
 python/dune/perftool/pdelab/localoperator.py  |  23 +-
 python/dune/perftool/pdelab/parameter.py      | 265 ------------------
 python/dune/perftool/ufl/visitor.py           |  29 +-
 9 files changed, 38 insertions(+), 352 deletions(-)
 delete mode 100644 python/dune/perftool/pdelab/parameter.py

diff --git a/python/dune/perftool/loopy/mangler.py b/python/dune/perftool/loopy/mangler.py
index 297968c6..29b0d503 100644
--- a/python/dune/perftool/loopy/mangler.py
+++ b/python/dune/perftool/loopy/mangler.py
@@ -6,6 +6,7 @@ from dune.perftool.generation import (function_mangler,
                                       )
 
 from loopy import CallMangleInfo
+from loopy.types import to_loopy_type
 
 import numpy as np
 
@@ -48,3 +49,14 @@ def dune_math_manglers(kernel, name, arg_dtypes):
                               (dt,),
                               (dt,) * len(arg_dtypes),
                               )
+
+
+@function_mangler
+def get_time_function_mangler(kernel, name, arg_dtypes):
+    """ The getTime method is defined on local operators once they inherit from
+    InstationaryLocalOperatorDefaultMethods
+    """
+    if name == "getTime":
+        assert(len(arg_dtypes) == 0)
+        from dune.perftool.loopy.target import dtype_floatingpoint
+        return CallMangleInfo("this->getTime", (to_loopy_type(dtype_floatingpoint()),), ())
diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 2450b8a6..929a5879 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -25,9 +25,6 @@ from dune.perftool.pdelab.geometry import (component_iname,
                                            )
 from dune.perftool.pdelab.index import (name_index,
                                         )
-from dune.perftool.pdelab.parameter import (cell_parameter_function,
-                                            intersection_parameter_function,
-                                            )
 from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
                                              pymbolic_quadrature_position,
                                              quadrature_inames,
@@ -101,16 +98,6 @@ class PDELabInterface(object):
     def pymbolic_apply_function(self, element, restriction, index):
         return pymbolic_apply_function(self.visitor, element, restriction, index)
 
-    #
-    # Parameter function related generator functions
-    #
-
-    def intersection_parameter_function(self, name, expr, cellwise_constant):
-        return intersection_parameter_function(name, expr, cellwise_constant)
-
-    def cell_parameter_function(self, name, expr, restriction, cellwise_constant):
-        return cell_parameter_function(name, expr, restriction, cellwise_constant)
-
     #
     # Tensor expression related generator functions
     #
diff --git a/python/dune/perftool/pdelab/driver/gridoperator.py b/python/dune/perftool/pdelab/driver/gridoperator.py
index 2d08b849..372f1bca 100644
--- a/python/dune/perftool/pdelab/driver/gridoperator.py
+++ b/python/dune/perftool/pdelab/driver/gridoperator.py
@@ -21,7 +21,6 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_test_gfs,
                                                            type_trial_gfs,
                                                            )
 from dune.perftool.pdelab.localoperator import localoperator_basename
-from dune.perftool.pdelab.parameter import parameterclass_basename
 from dune.perftool.options import get_form_option
 
 
@@ -94,8 +93,7 @@ def define_localoperator(name, form_ident):
     test_gfs = name_test_gfs()
     loptype = type_localoperator(form_ident)
     ini = name_initree()
-    params = name_parameters(form_ident)
-    return "{} {}({}, {}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini, params)
+    return "{} {}({}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini)
 
 
 def name_localoperator(form_ident):
@@ -154,21 +152,3 @@ def name_matrixbackend():
     name = "mb"
     define_matrixbackend(name)
     return name
-
-
-def type_parameters(form_ident):
-    name = parameterclass_basename(form_ident)
-    return name
-
-
-@preamble
-def define_parameters(name, form_ident):
-    partype = type_parameters(form_ident)
-    return "{} {};".format(partype, name)
-
-
-def name_parameters(form_ident, define=True):
-    name = "params_{}".format(form_ident)
-    if define:
-        define_parameters(name, form_ident)
-    return name
diff --git a/python/dune/perftool/pdelab/driver/instationary.py b/python/dune/perftool/pdelab/driver/instationary.py
index becb79b0..13a47e4d 100644
--- a/python/dune/perftool/pdelab/driver/instationary.py
+++ b/python/dune/perftool/pdelab/driver/instationary.py
@@ -11,8 +11,9 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs,
                                                            type_range,
                                                            )
 from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator,
-                                                      name_parameters,
-                                                      type_gridoperator,)
+                                                      type_gridoperator,
+                                                      name_localoperator,
+                                                      )
 from dune.perftool.pdelab.driver.constraints import (has_dirichlet_constraints,
                                                      name_bctype_function,
                                                      name_constraintscontainer,
@@ -51,7 +52,7 @@ def solve_instationary():
 @preamble
 def time_loop():
     ini = name_initree()
-    params = name_parameters(get_form_ident())
+    lop = name_localoperator(get_form_ident())
     time = name_time()
     element = get_trial_element()
     vector_type = type_vector(get_form_ident())
@@ -67,7 +68,7 @@ def time_loop():
         assemble_new_constraints = ("  // Assemble constraints for new time step\n"
                                     "  {}.setTime({}+dt);\n"
                                     "  Dune::PDELab::constraints({}, {}, {});\n"
-                                    "\n".format(params, time, bctype, gfs, cc)
+                                    "\n".format(lop, time, bctype, gfs, cc)
                                     )
 
     # Choose between explicit and implicit time stepping
diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py
index 4985f8c8..9e910ac3 100644
--- a/python/dune/perftool/pdelab/driver/interpolate.py
+++ b/python/dune/perftool/pdelab/driver/interpolate.py
@@ -13,8 +13,6 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling,
 from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs,
                                                            name_leafview,
                                                            )
-from dune.perftool.pdelab.driver.gridoperator import (name_parameters,)
-
 from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement
 
 
@@ -73,11 +71,13 @@ def define_boundary_function(name, dirichlet):
                                                                                       lambdaname,
                                                                                       )
     else:
-        params = name_parameters(get_form_ident())
+        from dune.perftool.pdelab.driver.gridoperator import name_localoperator
+        lop = name_localoperator(get_form_ident())
         return "auto {} = Dune::PDELab::makeInstationaryGridFunctionFromCallable({}, {}, {});".format(name,
                                                                                                       gv,
                                                                                                       lambdaname,
-                                                                                                      params)
+                                                                                                      lop,
+                                                                                                      )
 
 
 @cached
diff --git a/python/dune/perftool/pdelab/driver/visitor.py b/python/dune/perftool/pdelab/driver/visitor.py
index 2db30433..6549c9a1 100644
--- a/python/dune/perftool/pdelab/driver/visitor.py
+++ b/python/dune/perftool/pdelab/driver/visitor.py
@@ -28,6 +28,15 @@ class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor):
         driver_using_statement("std::min")
         return UFL2LoopyVisitor.min_value(self, o)
 
+    def coefficient(self, o):
+        if o.count() == 2:
+            from dune.perftool.pdelab.driver import get_form_ident
+            from dune.perftool.pdelab.driver.gridoperator import name_localoperator
+            lop = name_localoperator(get_form_ident())
+            return prim.Call(prim.Variable("{}.getTime".format(lop)), ())
+        else:
+            return UFL2LoopyVisitor.coefficient(self, o)
+
 
 def ufl_to_code(expr, boundary=True):
     # So far, we only considered this code branch on boundaries!
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index f038e484..ea0dad53 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -14,6 +14,7 @@ from dune.perftool.generation import (backend,
                                       domain,
                                       dump_accumulate_timer,
                                       end_of_file,
+                                      function_mangler,
                                       generator_factory,
                                       get_backend,
                                       get_global_context_value,
@@ -270,16 +271,7 @@ def boundary_predicates(expr, measure, subdomain_id):
             visitor = get_visitor(measure, subdomain_id)
             cond = visitor(subdomain_data, do_predicates=True)
         else:
-            # Determine the name of the parameter function
-            cond = get_global_context_value("data").object_names[id(subdomain_data)]
-
-            # Trigger the generation of code for this thing in the parameter class
-            from ufl.checks import is_cellwise_constant
-            cellwise_constant = is_cellwise_constant(expr)
-            from dune.perftool.pdelab.parameter import intersection_parameter_function
-            intersection_parameter_function(cond, subdomain_data, cellwise_constant, t='int32')
-
-            cond = prim.Variable(cond)
+            raise NotImplementedError("Only UFL expressions allowed in subdomain_data right now.")
 
         predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)])
 
@@ -709,14 +701,10 @@ def generate_localoperator_kernels(operator):
     lop_template_ansatz_gfs()
     lop_template_test_gfs()
     lop_template_range_field()
-    from dune.perftool.pdelab.parameter import parameterclass_basename
-    parameterclass_basename(operator)
 
     # Make sure there is always the same constructor arguments (even if parameter class is empty)
     from dune.perftool.pdelab.localoperator import name_initree_member
     name_initree_member()
-    from dune.perftool.pdelab.parameter import name_paramclass
-    name_paramclass()
 
     # Set some options!
     from dune.perftool.pdelab.driver import isQuadrilateral
@@ -734,10 +722,6 @@ def generate_localoperator_kernels(operator):
         base_class('Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>'
                    .format(rf), classtag="operator")
 
-        # Create set time method in parameter class
-        from dune.perftool.pdelab.parameter import define_set_time_method
-        define_set_time_method()
-
     # Have a data structure collect the generated kernels
     operator_kernels = {}
 
@@ -871,7 +855,6 @@ def generate_localoperator_file(kernels, filename):
 
     # Write the file!
     from dune.perftool.file import generate_file
-    param = cgen_class_from_cache("parameterclass")
     # TODO take the name of this thing from the UFL file
     lop = cgen_class_from_cache("operator", members=operator_methods)
-    generate_file(filename, "operatorfile", [param, lop])
+    generate_file(filename, "operatorfile", [lop])
diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py
deleted file mode 100644
index c117609f..00000000
--- a/python/dune/perftool/pdelab/parameter.py
+++ /dev/null
@@ -1,265 +0,0 @@
-""" Generators for parameter functions """
-
-from dune.perftool.generation import (class_basename,
-                                      class_member,
-                                      constructor_parameter,
-                                      generator_factory,
-                                      get_backend,
-                                      get_global_context_value,
-                                      initializer_list,
-                                      kernel_cached,
-                                      preamble,
-                                      temporary_variable
-                                      )
-from dune.perftool.pdelab.geometry import (name_cell,
-                                           name_intersection,
-                                           )
-from dune.perftool.pdelab.quadrature import quadrature_preamble
-from dune.perftool.tools import get_pymbolic_basename
-from dune.perftool.cgen.clazz import AccessModifier
-from dune.perftool.pdelab.localoperator import (class_type_from_cache,
-                                                localoperator_basename,
-                                                )
-from dune.perftool.loopy.target import type_floatingpoint
-from dune.perftool.options import get_form_option
-
-from loopy.match import Writes
-
-
-@class_basename(classtag="parameterclass")
-def parameterclass_basename(form_ident):
-    lopbase = get_form_option("classname", form_ident)
-    return "{}Params".format(lopbase)
-
-
-@class_member(classtag="operator")
-def define_parameterclass(name):
-    _, t = class_type_from_cache("parameterclass")
-    constructor_parameter("{}&".format(t), name + "_", classtag="operator")
-    initializer_list(name, [name + "_"], classtag="operator")
-    return "{}& {};".format(t, name)
-
-
-def name_paramclass():
-    ident = get_global_context_value("form_identifier")
-    from dune.perftool.pdelab.driver.gridoperator import name_parameters
-    name = name_parameters(ident, define=False)
-    define_parameterclass(name)
-    return name
-
-
-@class_member(classtag="parameterclass")
-def define_time(name):
-    initializer_list(name, ["0.0"], classtag="parameterclass")
-    ftype = type_floatingpoint()
-    return "{} {};".format(ftype, name)
-
-
-def name_time():
-    define_time("t")
-    return "t"
-
-
-def define_set_time_method():
-    define_set_time_method_parameterclass()
-    define_set_time_method_operator()
-
-
-@class_member(classtag="operator")
-def define_set_time_method_operator():
-    time_name = name_time()
-    param = name_paramclass()
-    ftype = type_floatingpoint()
-
-    result = ["// Set time in instationary case",
-              "void setTime ({} t_)".format(ftype),
-              "{",
-              "  Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>::setTime(t_);".format(ftype),
-              "  {}.setTime(t_);".format(param),
-              "}"
-              ]
-
-    return result
-
-
-@class_member(classtag="parameterclass")
-def define_set_time_method_parameterclass():
-    time_name = name_time()
-    ftype = type_floatingpoint()
-
-    result = ["// Set time in instationary case",
-              "void setTime ({} t_)".format(ftype),
-              "{",
-              "  {} = t_;".format(time_name),
-              "}"
-              ]
-
-    return result
-
-
-def combine_tree_path_argnumber(element, tree_path_int):
-    # Return string combining tree_path and argnumber.
-    subel = element.extract_subelement_component(tree_path_int)
-
-    def _flatten(x):
-        if isinstance(x, tuple):
-            return '_'.join(_flatten(i) for i in x if i != ())
-        else:
-            return str(x)
-
-    return _flatten(subel)
-
-
-@class_member(classtag="parameterclass")
-def define_parameter_function_class_member(name, expr, baset, shape, cell):
-    t = construct_nested_fieldvector(baset, shape)
-
-    geot = "E" if cell else "I"
-    geo = geot.lower()
-    result = ["template<typename {}, typename X>".format(geot),
-              "{} {}(const {}& {}, const X& local) const".format(t, name, geot, geo),
-              "{",
-              ]
-
-    # In the case of a non-scalar parameter function, recurse into leafs
-    if expr.element.value_shape():
-        # Check that this is a VectorElement, as I have no idea how a parameter function
-        # over a non-vector mixed element should be well-defined in PDELab.
-        from ufl import VectorElement
-        assert isinstance(expr.element, VectorElement)
-
-        result.append("  {} result(0.0);".format(t))
-
-        from dune.perftool.ufl.execution import split_expression
-        for i, subexpr in enumerate(split_expression(expr)):
-            child_name = "{}_{}".format(name, combine_tree_path_argnumber(expr.element, i))
-            result.append("  result[{}] = {}({}, local);".format(i, child_name, geo))
-            define_parameter_function_class_member(child_name, subexpr, baset, shape[1:], cell)
-
-        result.append("  return result;")
-
-    else:
-        # Evaluate a scalar parameter function
-        if expr.is_global:
-            result.append("  auto x = {}.geometry().global(local);".format(geo))
-        else:
-            result.append("  auto x = local;")
-
-        result.append("  " + expr.c_expr[0])
-
-    result.append("}")
-
-    return result
-
-
-@preamble
-def evaluate_cellwise_constant_parameter_function(name, restriction):
-    param = name_paramclass()
-    entity = name_cell(restriction)
-    from dune.perftool.pdelab.geometry import name_localcenter
-    pos = name_localcenter()
-
-    from dune.perftool.generation.loopy import valuearg
-    import numpy
-    valuearg(name)
-
-    return 'auto {} = {}.{}({}, {});'.format(name,
-                                             name_paramclass(),
-                                             name,
-                                             entity,
-                                             pos,
-                                             )
-
-
-@preamble
-def evaluate_intersectionwise_constant_parameter_function(name):
-    # Check that this is not a volume term, as that would not be well-defined
-    from dune.perftool.generation import get_global_context_value
-    it = get_global_context_value("integral_type")
-    assert it is not 'cell'
-
-    param = name_paramclass()
-    intersection = name_intersection()
-    pos = name_localcenter()
-
-    from dune.perftool.generation.loopy import valuearg
-    import numpy
-    valuearg(name)
-
-    return 'auto {} = {}.{}({}, {});'.format(name,
-                                             name_paramclass(),
-                                             name,
-                                             intersection,
-                                             pos,
-                                             )
-
-
-def evaluate_cell_parameter_function(name, restriction):
-    param = name_paramclass()
-    entity = name_cell(restriction)
-    pos = get_backend(interface="qp_in_cell")(restriction)
-    return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
-                                                            name_paramclass(),
-                                                            name,
-                                                            entity,
-                                                            str(pos),
-                                                            ),
-                               assignees=frozenset({name}),
-                               read_variables=frozenset({get_pymbolic_basename(pos)}),
-                               depends_on=frozenset({Writes(get_pymbolic_basename(pos))}),
-                               )
-
-
-def evaluate_intersection_parameter_function(name):
-    # Check that this is not a volume term, as that would not be well-defined
-    from dune.perftool.generation import get_global_context_value
-    it = get_global_context_value("integral_type")
-    assert it is not 'cell'
-
-    param = name_paramclass()
-    intersection = name_intersection()
-    pos = get_backend("quad_pos")()
-    return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
-                                                            name_paramclass(),
-                                                            name,
-                                                            intersection,
-                                                            str(pos),
-                                                            ),
-                               assignees=frozenset({name}),
-                               read_variables=frozenset({get_pymbolic_basename(pos)}),
-                               depends_on=frozenset({Writes(get_pymbolic_basename(pos))}),
-                               )
-
-
-def construct_nested_fieldvector(t, shape):
-    if len(shape) == 0:
-        return t
-    return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0])
-
-
-@kernel_cached
-def cell_parameter_function(name, expr, restriction, cellwise_constant, t='float64'):
-    shape = expr.ufl_element().value_shape()
-    shape_impl = ('fv',) * len(shape)
-    from dune.perftool.loopy.target import numpy_to_cpp_dtype
-    t = numpy_to_cpp_dtype(t)
-    define_parameter_function_class_member(name, expr, t, shape, True)
-    if cellwise_constant:
-        evaluate_cellwise_constant_parameter_function(name, restriction)
-    else:
-        temporary_variable(name, shape=shape, shape_impl=shape_impl)
-        evaluate_cell_parameter_function(name, restriction)
-
-
-@kernel_cached
-def intersection_parameter_function(name, expr, cellwise_constant, t='float64'):
-    shape = expr.ufl_element().value_shape()
-    shape_impl = ('fv',) * len(shape)
-    from dune.perftool.loopy.target import numpy_to_cpp_dtype
-    t = numpy_to_cpp_dtype(t)
-    define_parameter_function_class_member(name, expr, t, shape, False)
-    if cellwise_constant:
-        evaluate_intersectionwise_constant_parameter_function(name)
-    else:
-        temporary_variable(name, shape=shape, shape_impl=shape_impl)
-        evaluate_intersection_parameter_function(name)
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index bdce4c16..6b4db02f 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -14,7 +14,6 @@ from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
                                                   )
 from dune.perftool.tools import maybe_wrap_subscript
 from dune.perftool.options import get_form_option
-from dune.perftool.pdelab.parameter import name_paramclass, name_time
 from loopy import Reduction
 
 from pymbolic.primitives import (Call,
@@ -161,31 +160,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
 
         # In this case it represents the time variable
         elif o.count() == 2:
-            param = name_paramclass()
-            time = name_time()
-            name = param + "." + time
-            valuearg(name)
-            return Variable(name)
-
-        # Check if this is a parameter function
+            # The base class 'InstationaryLocalOperatorDefaultMethods' stores the time
+            # and exports it through a getter method 'getTime'
+            return prim.Call(prim.Variable("getTime"), ())
         else:
-            raise NotImplementedError("Handling non-symbolic parameter functions is currently reevaluated!")
-            # We expect all coefficients to be of type Expression!
-            assert isinstance(o, Expression)
-
-            # Determine the name of the parameter function
-            name = get_global_context_value("data").object_names[id(o)]
-
-            cellwise_constant = is_cellwise_constant(o)
-
-            # Trigger the generation of code for this thing in the parameter class
-            if o.on_intersection:
-                self.interface.intersection_parameter_function(name, o, cellwise_constant)
-            else:
-                self.interface.cell_parameter_function(name, o, self.restriction, cellwise_constant)
-
-            # And return a symbol
-            return Variable(name)
+            raise NotImplementedError("General Coefficients")
 
     #
     # Handlers for all indexing related stuff
-- 
GitLab