From 21735855da552230aab443bf6798d2f3a554cd6a Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 31 Jan 2018 12:00:53 +0100 Subject: [PATCH] Rip out parameter class It makes the code generation workflow more complicated without contributing anything substantial --- python/dune/perftool/loopy/mangler.py | 12 + python/dune/perftool/pdelab/__init__.py | 13 - .../perftool/pdelab/driver/gridoperator.py | 22 +- .../perftool/pdelab/driver/instationary.py | 9 +- .../perftool/pdelab/driver/interpolate.py | 8 +- python/dune/perftool/pdelab/driver/visitor.py | 9 + python/dune/perftool/pdelab/localoperator.py | 23 +- python/dune/perftool/pdelab/parameter.py | 265 ------------------ python/dune/perftool/ufl/visitor.py | 29 +- 9 files changed, 38 insertions(+), 352 deletions(-) delete mode 100644 python/dune/perftool/pdelab/parameter.py diff --git a/python/dune/perftool/loopy/mangler.py b/python/dune/perftool/loopy/mangler.py index 297968c6..29b0d503 100644 --- a/python/dune/perftool/loopy/mangler.py +++ b/python/dune/perftool/loopy/mangler.py @@ -6,6 +6,7 @@ from dune.perftool.generation import (function_mangler, ) from loopy import CallMangleInfo +from loopy.types import to_loopy_type import numpy as np @@ -48,3 +49,14 @@ def dune_math_manglers(kernel, name, arg_dtypes): (dt,), (dt,) * len(arg_dtypes), ) + + +@function_mangler +def get_time_function_mangler(kernel, name, arg_dtypes): + """ The getTime method is defined on local operators once they inherit from + InstationaryLocalOperatorDefaultMethods + """ + if name == "getTime": + assert(len(arg_dtypes) == 0) + from dune.perftool.loopy.target import dtype_floatingpoint + return CallMangleInfo("this->getTime", (to_loopy_type(dtype_floatingpoint()),), ()) diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py index 2450b8a6..929a5879 100644 --- a/python/dune/perftool/pdelab/__init__.py +++ b/python/dune/perftool/pdelab/__init__.py @@ -25,9 +25,6 @@ from dune.perftool.pdelab.geometry import (component_iname, ) from dune.perftool.pdelab.index import (name_index, ) -from dune.perftool.pdelab.parameter import (cell_parameter_function, - intersection_parameter_function, - ) from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight, pymbolic_quadrature_position, quadrature_inames, @@ -101,16 +98,6 @@ class PDELabInterface(object): def pymbolic_apply_function(self, element, restriction, index): return pymbolic_apply_function(self.visitor, element, restriction, index) - # - # Parameter function related generator functions - # - - def intersection_parameter_function(self, name, expr, cellwise_constant): - return intersection_parameter_function(name, expr, cellwise_constant) - - def cell_parameter_function(self, name, expr, restriction, cellwise_constant): - return cell_parameter_function(name, expr, restriction, cellwise_constant) - # # Tensor expression related generator functions # diff --git a/python/dune/perftool/pdelab/driver/gridoperator.py b/python/dune/perftool/pdelab/driver/gridoperator.py index 2d08b849..372f1bca 100644 --- a/python/dune/perftool/pdelab/driver/gridoperator.py +++ b/python/dune/perftool/pdelab/driver/gridoperator.py @@ -21,7 +21,6 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_test_gfs, type_trial_gfs, ) from dune.perftool.pdelab.localoperator import localoperator_basename -from dune.perftool.pdelab.parameter import parameterclass_basename from dune.perftool.options import get_form_option @@ -94,8 +93,7 @@ def define_localoperator(name, form_ident): test_gfs = name_test_gfs() loptype = type_localoperator(form_ident) ini = name_initree() - params = name_parameters(form_ident) - return "{} {}({}, {}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini, params) + return "{} {}({}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini) def name_localoperator(form_ident): @@ -154,21 +152,3 @@ def name_matrixbackend(): name = "mb" define_matrixbackend(name) return name - - -def type_parameters(form_ident): - name = parameterclass_basename(form_ident) - return name - - -@preamble -def define_parameters(name, form_ident): - partype = type_parameters(form_ident) - return "{} {};".format(partype, name) - - -def name_parameters(form_ident, define=True): - name = "params_{}".format(form_ident) - if define: - define_parameters(name, form_ident) - return name diff --git a/python/dune/perftool/pdelab/driver/instationary.py b/python/dune/perftool/pdelab/driver/instationary.py index becb79b0..13a47e4d 100644 --- a/python/dune/perftool/pdelab/driver/instationary.py +++ b/python/dune/perftool/pdelab/driver/instationary.py @@ -11,8 +11,9 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs, type_range, ) from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator, - name_parameters, - type_gridoperator,) + type_gridoperator, + name_localoperator, + ) from dune.perftool.pdelab.driver.constraints import (has_dirichlet_constraints, name_bctype_function, name_constraintscontainer, @@ -51,7 +52,7 @@ def solve_instationary(): @preamble def time_loop(): ini = name_initree() - params = name_parameters(get_form_ident()) + lop = name_localoperator(get_form_ident()) time = name_time() element = get_trial_element() vector_type = type_vector(get_form_ident()) @@ -67,7 +68,7 @@ def time_loop(): assemble_new_constraints = (" // Assemble constraints for new time step\n" " {}.setTime({}+dt);\n" " Dune::PDELab::constraints({}, {}, {});\n" - "\n".format(params, time, bctype, gfs, cc) + "\n".format(lop, time, bctype, gfs, cc) ) # Choose between explicit and implicit time stepping diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py index 4985f8c8..9e910ac3 100644 --- a/python/dune/perftool/pdelab/driver/interpolate.py +++ b/python/dune/perftool/pdelab/driver/interpolate.py @@ -13,8 +13,6 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling, from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs, name_leafview, ) -from dune.perftool.pdelab.driver.gridoperator import (name_parameters,) - from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement @@ -73,11 +71,13 @@ def define_boundary_function(name, dirichlet): lambdaname, ) else: - params = name_parameters(get_form_ident()) + from dune.perftool.pdelab.driver.gridoperator import name_localoperator + lop = name_localoperator(get_form_ident()) return "auto {} = Dune::PDELab::makeInstationaryGridFunctionFromCallable({}, {}, {});".format(name, gv, lambdaname, - params) + lop, + ) @cached diff --git a/python/dune/perftool/pdelab/driver/visitor.py b/python/dune/perftool/pdelab/driver/visitor.py index 2db30433..6549c9a1 100644 --- a/python/dune/perftool/pdelab/driver/visitor.py +++ b/python/dune/perftool/pdelab/driver/visitor.py @@ -28,6 +28,15 @@ class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor): driver_using_statement("std::min") return UFL2LoopyVisitor.min_value(self, o) + def coefficient(self, o): + if o.count() == 2: + from dune.perftool.pdelab.driver import get_form_ident + from dune.perftool.pdelab.driver.gridoperator import name_localoperator + lop = name_localoperator(get_form_ident()) + return prim.Call(prim.Variable("{}.getTime".format(lop)), ()) + else: + return UFL2LoopyVisitor.coefficient(self, o) + def ufl_to_code(expr, boundary=True): # So far, we only considered this code branch on boundaries! diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index f038e484..ea0dad53 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -14,6 +14,7 @@ from dune.perftool.generation import (backend, domain, dump_accumulate_timer, end_of_file, + function_mangler, generator_factory, get_backend, get_global_context_value, @@ -270,16 +271,7 @@ def boundary_predicates(expr, measure, subdomain_id): visitor = get_visitor(measure, subdomain_id) cond = visitor(subdomain_data, do_predicates=True) else: - # Determine the name of the parameter function - cond = get_global_context_value("data").object_names[id(subdomain_data)] - - # Trigger the generation of code for this thing in the parameter class - from ufl.checks import is_cellwise_constant - cellwise_constant = is_cellwise_constant(expr) - from dune.perftool.pdelab.parameter import intersection_parameter_function - intersection_parameter_function(cond, subdomain_data, cellwise_constant, t='int32') - - cond = prim.Variable(cond) + raise NotImplementedError("Only UFL expressions allowed in subdomain_data right now.") predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)]) @@ -709,14 +701,10 @@ def generate_localoperator_kernels(operator): lop_template_ansatz_gfs() lop_template_test_gfs() lop_template_range_field() - from dune.perftool.pdelab.parameter import parameterclass_basename - parameterclass_basename(operator) # Make sure there is always the same constructor arguments (even if parameter class is empty) from dune.perftool.pdelab.localoperator import name_initree_member name_initree_member() - from dune.perftool.pdelab.parameter import name_paramclass - name_paramclass() # Set some options! from dune.perftool.pdelab.driver import isQuadrilateral @@ -734,10 +722,6 @@ def generate_localoperator_kernels(operator): base_class('Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>' .format(rf), classtag="operator") - # Create set time method in parameter class - from dune.perftool.pdelab.parameter import define_set_time_method - define_set_time_method() - # Have a data structure collect the generated kernels operator_kernels = {} @@ -871,7 +855,6 @@ def generate_localoperator_file(kernels, filename): # Write the file! from dune.perftool.file import generate_file - param = cgen_class_from_cache("parameterclass") # TODO take the name of this thing from the UFL file lop = cgen_class_from_cache("operator", members=operator_methods) - generate_file(filename, "operatorfile", [param, lop]) + generate_file(filename, "operatorfile", [lop]) diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py deleted file mode 100644 index c117609f..00000000 --- a/python/dune/perftool/pdelab/parameter.py +++ /dev/null @@ -1,265 +0,0 @@ -""" Generators for parameter functions """ - -from dune.perftool.generation import (class_basename, - class_member, - constructor_parameter, - generator_factory, - get_backend, - get_global_context_value, - initializer_list, - kernel_cached, - preamble, - temporary_variable - ) -from dune.perftool.pdelab.geometry import (name_cell, - name_intersection, - ) -from dune.perftool.pdelab.quadrature import quadrature_preamble -from dune.perftool.tools import get_pymbolic_basename -from dune.perftool.cgen.clazz import AccessModifier -from dune.perftool.pdelab.localoperator import (class_type_from_cache, - localoperator_basename, - ) -from dune.perftool.loopy.target import type_floatingpoint -from dune.perftool.options import get_form_option - -from loopy.match import Writes - - -@class_basename(classtag="parameterclass") -def parameterclass_basename(form_ident): - lopbase = get_form_option("classname", form_ident) - return "{}Params".format(lopbase) - - -@class_member(classtag="operator") -def define_parameterclass(name): - _, t = class_type_from_cache("parameterclass") - constructor_parameter("{}&".format(t), name + "_", classtag="operator") - initializer_list(name, [name + "_"], classtag="operator") - return "{}& {};".format(t, name) - - -def name_paramclass(): - ident = get_global_context_value("form_identifier") - from dune.perftool.pdelab.driver.gridoperator import name_parameters - name = name_parameters(ident, define=False) - define_parameterclass(name) - return name - - -@class_member(classtag="parameterclass") -def define_time(name): - initializer_list(name, ["0.0"], classtag="parameterclass") - ftype = type_floatingpoint() - return "{} {};".format(ftype, name) - - -def name_time(): - define_time("t") - return "t" - - -def define_set_time_method(): - define_set_time_method_parameterclass() - define_set_time_method_operator() - - -@class_member(classtag="operator") -def define_set_time_method_operator(): - time_name = name_time() - param = name_paramclass() - ftype = type_floatingpoint() - - result = ["// Set time in instationary case", - "void setTime ({} t_)".format(ftype), - "{", - " Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>::setTime(t_);".format(ftype), - " {}.setTime(t_);".format(param), - "}" - ] - - return result - - -@class_member(classtag="parameterclass") -def define_set_time_method_parameterclass(): - time_name = name_time() - ftype = type_floatingpoint() - - result = ["// Set time in instationary case", - "void setTime ({} t_)".format(ftype), - "{", - " {} = t_;".format(time_name), - "}" - ] - - return result - - -def combine_tree_path_argnumber(element, tree_path_int): - # Return string combining tree_path and argnumber. - subel = element.extract_subelement_component(tree_path_int) - - def _flatten(x): - if isinstance(x, tuple): - return '_'.join(_flatten(i) for i in x if i != ()) - else: - return str(x) - - return _flatten(subel) - - -@class_member(classtag="parameterclass") -def define_parameter_function_class_member(name, expr, baset, shape, cell): - t = construct_nested_fieldvector(baset, shape) - - geot = "E" if cell else "I" - geo = geot.lower() - result = ["template<typename {}, typename X>".format(geot), - "{} {}(const {}& {}, const X& local) const".format(t, name, geot, geo), - "{", - ] - - # In the case of a non-scalar parameter function, recurse into leafs - if expr.element.value_shape(): - # Check that this is a VectorElement, as I have no idea how a parameter function - # over a non-vector mixed element should be well-defined in PDELab. - from ufl import VectorElement - assert isinstance(expr.element, VectorElement) - - result.append(" {} result(0.0);".format(t)) - - from dune.perftool.ufl.execution import split_expression - for i, subexpr in enumerate(split_expression(expr)): - child_name = "{}_{}".format(name, combine_tree_path_argnumber(expr.element, i)) - result.append(" result[{}] = {}({}, local);".format(i, child_name, geo)) - define_parameter_function_class_member(child_name, subexpr, baset, shape[1:], cell) - - result.append(" return result;") - - else: - # Evaluate a scalar parameter function - if expr.is_global: - result.append(" auto x = {}.geometry().global(local);".format(geo)) - else: - result.append(" auto x = local;") - - result.append(" " + expr.c_expr[0]) - - result.append("}") - - return result - - -@preamble -def evaluate_cellwise_constant_parameter_function(name, restriction): - param = name_paramclass() - entity = name_cell(restriction) - from dune.perftool.pdelab.geometry import name_localcenter - pos = name_localcenter() - - from dune.perftool.generation.loopy import valuearg - import numpy - valuearg(name) - - return 'auto {} = {}.{}({}, {});'.format(name, - name_paramclass(), - name, - entity, - pos, - ) - - -@preamble -def evaluate_intersectionwise_constant_parameter_function(name): - # Check that this is not a volume term, as that would not be well-defined - from dune.perftool.generation import get_global_context_value - it = get_global_context_value("integral_type") - assert it is not 'cell' - - param = name_paramclass() - intersection = name_intersection() - pos = name_localcenter() - - from dune.perftool.generation.loopy import valuearg - import numpy - valuearg(name) - - return 'auto {} = {}.{}({}, {});'.format(name, - name_paramclass(), - name, - intersection, - pos, - ) - - -def evaluate_cell_parameter_function(name, restriction): - param = name_paramclass() - entity = name_cell(restriction) - pos = get_backend(interface="qp_in_cell")(restriction) - return quadrature_preamble('{} = {}.{}({}, {});'.format(name, - name_paramclass(), - name, - entity, - str(pos), - ), - assignees=frozenset({name}), - read_variables=frozenset({get_pymbolic_basename(pos)}), - depends_on=frozenset({Writes(get_pymbolic_basename(pos))}), - ) - - -def evaluate_intersection_parameter_function(name): - # Check that this is not a volume term, as that would not be well-defined - from dune.perftool.generation import get_global_context_value - it = get_global_context_value("integral_type") - assert it is not 'cell' - - param = name_paramclass() - intersection = name_intersection() - pos = get_backend("quad_pos")() - return quadrature_preamble('{} = {}.{}({}, {});'.format(name, - name_paramclass(), - name, - intersection, - str(pos), - ), - assignees=frozenset({name}), - read_variables=frozenset({get_pymbolic_basename(pos)}), - depends_on=frozenset({Writes(get_pymbolic_basename(pos))}), - ) - - -def construct_nested_fieldvector(t, shape): - if len(shape) == 0: - return t - return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0]) - - -@kernel_cached -def cell_parameter_function(name, expr, restriction, cellwise_constant, t='float64'): - shape = expr.ufl_element().value_shape() - shape_impl = ('fv',) * len(shape) - from dune.perftool.loopy.target import numpy_to_cpp_dtype - t = numpy_to_cpp_dtype(t) - define_parameter_function_class_member(name, expr, t, shape, True) - if cellwise_constant: - evaluate_cellwise_constant_parameter_function(name, restriction) - else: - temporary_variable(name, shape=shape, shape_impl=shape_impl) - evaluate_cell_parameter_function(name, restriction) - - -@kernel_cached -def intersection_parameter_function(name, expr, cellwise_constant, t='float64'): - shape = expr.ufl_element().value_shape() - shape_impl = ('fv',) * len(shape) - from dune.perftool.loopy.target import numpy_to_cpp_dtype - t = numpy_to_cpp_dtype(t) - define_parameter_function_class_member(name, expr, t, shape, False) - if cellwise_constant: - evaluate_intersectionwise_constant_parameter_function(name) - else: - temporary_variable(name, shape=shape, shape_impl=shape_impl) - evaluate_intersection_parameter_function(name) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index bdce4c16..6b4db02f 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -14,7 +14,6 @@ from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker, ) from dune.perftool.tools import maybe_wrap_subscript from dune.perftool.options import get_form_option -from dune.perftool.pdelab.parameter import name_paramclass, name_time from loopy import Reduction from pymbolic.primitives import (Call, @@ -161,31 +160,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # In this case it represents the time variable elif o.count() == 2: - param = name_paramclass() - time = name_time() - name = param + "." + time - valuearg(name) - return Variable(name) - - # Check if this is a parameter function + # The base class 'InstationaryLocalOperatorDefaultMethods' stores the time + # and exports it through a getter method 'getTime' + return prim.Call(prim.Variable("getTime"), ()) else: - raise NotImplementedError("Handling non-symbolic parameter functions is currently reevaluated!") - # We expect all coefficients to be of type Expression! - assert isinstance(o, Expression) - - # Determine the name of the parameter function - name = get_global_context_value("data").object_names[id(o)] - - cellwise_constant = is_cellwise_constant(o) - - # Trigger the generation of code for this thing in the parameter class - if o.on_intersection: - self.interface.intersection_parameter_function(name, o, cellwise_constant) - else: - self.interface.cell_parameter_function(name, o, self.restriction, cellwise_constant) - - # And return a symbol - return Variable(name) + raise NotImplementedError("General Coefficients") # # Handlers for all indexing related stuff -- GitLab