Skip to content
Snippets Groups Projects
Commit 21735855 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Rip out parameter class

It makes the code generation workflow more complicated without
contributing anything substantial
parent 060560df
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from dune.perftool.generation import (function_mangler, ...@@ -6,6 +6,7 @@ from dune.perftool.generation import (function_mangler,
) )
from loopy import CallMangleInfo from loopy import CallMangleInfo
from loopy.types import to_loopy_type
import numpy as np import numpy as np
...@@ -48,3 +49,14 @@ def dune_math_manglers(kernel, name, arg_dtypes): ...@@ -48,3 +49,14 @@ def dune_math_manglers(kernel, name, arg_dtypes):
(dt,), (dt,),
(dt,) * len(arg_dtypes), (dt,) * len(arg_dtypes),
) )
@function_mangler
def get_time_function_mangler(kernel, name, arg_dtypes):
""" The getTime method is defined on local operators once they inherit from
InstationaryLocalOperatorDefaultMethods
"""
if name == "getTime":
assert(len(arg_dtypes) == 0)
from dune.perftool.loopy.target import dtype_floatingpoint
return CallMangleInfo("this->getTime", (to_loopy_type(dtype_floatingpoint()),), ())
...@@ -25,9 +25,6 @@ from dune.perftool.pdelab.geometry import (component_iname, ...@@ -25,9 +25,6 @@ from dune.perftool.pdelab.geometry import (component_iname,
) )
from dune.perftool.pdelab.index import (name_index, from dune.perftool.pdelab.index import (name_index,
) )
from dune.perftool.pdelab.parameter import (cell_parameter_function,
intersection_parameter_function,
)
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight, from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
pymbolic_quadrature_position, pymbolic_quadrature_position,
quadrature_inames, quadrature_inames,
...@@ -101,16 +98,6 @@ class PDELabInterface(object): ...@@ -101,16 +98,6 @@ class PDELabInterface(object):
def pymbolic_apply_function(self, element, restriction, index): def pymbolic_apply_function(self, element, restriction, index):
return pymbolic_apply_function(self.visitor, element, restriction, index) return pymbolic_apply_function(self.visitor, element, restriction, index)
#
# Parameter function related generator functions
#
def intersection_parameter_function(self, name, expr, cellwise_constant):
return intersection_parameter_function(name, expr, cellwise_constant)
def cell_parameter_function(self, name, expr, restriction, cellwise_constant):
return cell_parameter_function(name, expr, restriction, cellwise_constant)
# #
# Tensor expression related generator functions # Tensor expression related generator functions
# #
......
...@@ -21,7 +21,6 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_test_gfs, ...@@ -21,7 +21,6 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_test_gfs,
type_trial_gfs, type_trial_gfs,
) )
from dune.perftool.pdelab.localoperator import localoperator_basename from dune.perftool.pdelab.localoperator import localoperator_basename
from dune.perftool.pdelab.parameter import parameterclass_basename
from dune.perftool.options import get_form_option from dune.perftool.options import get_form_option
...@@ -94,8 +93,7 @@ def define_localoperator(name, form_ident): ...@@ -94,8 +93,7 @@ def define_localoperator(name, form_ident):
test_gfs = name_test_gfs() test_gfs = name_test_gfs()
loptype = type_localoperator(form_ident) loptype = type_localoperator(form_ident)
ini = name_initree() ini = name_initree()
params = name_parameters(form_ident) return "{} {}({}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini)
return "{} {}({}, {}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini, params)
def name_localoperator(form_ident): def name_localoperator(form_ident):
...@@ -154,21 +152,3 @@ def name_matrixbackend(): ...@@ -154,21 +152,3 @@ def name_matrixbackend():
name = "mb" name = "mb"
define_matrixbackend(name) define_matrixbackend(name)
return name return name
def type_parameters(form_ident):
name = parameterclass_basename(form_ident)
return name
@preamble
def define_parameters(name, form_ident):
partype = type_parameters(form_ident)
return "{} {};".format(partype, name)
def name_parameters(form_ident, define=True):
name = "params_{}".format(form_ident)
if define:
define_parameters(name, form_ident)
return name
...@@ -11,8 +11,9 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs, ...@@ -11,8 +11,9 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs,
type_range, type_range,
) )
from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator, from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator,
name_parameters, type_gridoperator,
type_gridoperator,) name_localoperator,
)
from dune.perftool.pdelab.driver.constraints import (has_dirichlet_constraints, from dune.perftool.pdelab.driver.constraints import (has_dirichlet_constraints,
name_bctype_function, name_bctype_function,
name_constraintscontainer, name_constraintscontainer,
...@@ -51,7 +52,7 @@ def solve_instationary(): ...@@ -51,7 +52,7 @@ def solve_instationary():
@preamble @preamble
def time_loop(): def time_loop():
ini = name_initree() ini = name_initree()
params = name_parameters(get_form_ident()) lop = name_localoperator(get_form_ident())
time = name_time() time = name_time()
element = get_trial_element() element = get_trial_element()
vector_type = type_vector(get_form_ident()) vector_type = type_vector(get_form_ident())
...@@ -67,7 +68,7 @@ def time_loop(): ...@@ -67,7 +68,7 @@ def time_loop():
assemble_new_constraints = (" // Assemble constraints for new time step\n" assemble_new_constraints = (" // Assemble constraints for new time step\n"
" {}.setTime({}+dt);\n" " {}.setTime({}+dt);\n"
" Dune::PDELab::constraints({}, {}, {});\n" " Dune::PDELab::constraints({}, {}, {});\n"
"\n".format(params, time, bctype, gfs, cc) "\n".format(lop, time, bctype, gfs, cc)
) )
# Choose between explicit and implicit time stepping # Choose between explicit and implicit time stepping
......
...@@ -13,8 +13,6 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling, ...@@ -13,8 +13,6 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling,
from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs, from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs,
name_leafview, name_leafview,
) )
from dune.perftool.pdelab.driver.gridoperator import (name_parameters,)
from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement
...@@ -73,11 +71,13 @@ def define_boundary_function(name, dirichlet): ...@@ -73,11 +71,13 @@ def define_boundary_function(name, dirichlet):
lambdaname, lambdaname,
) )
else: else:
params = name_parameters(get_form_ident()) from dune.perftool.pdelab.driver.gridoperator import name_localoperator
lop = name_localoperator(get_form_ident())
return "auto {} = Dune::PDELab::makeInstationaryGridFunctionFromCallable({}, {}, {});".format(name, return "auto {} = Dune::PDELab::makeInstationaryGridFunctionFromCallable({}, {}, {});".format(name,
gv, gv,
lambdaname, lambdaname,
params) lop,
)
@cached @cached
......
...@@ -28,6 +28,15 @@ class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor): ...@@ -28,6 +28,15 @@ class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor):
driver_using_statement("std::min") driver_using_statement("std::min")
return UFL2LoopyVisitor.min_value(self, o) return UFL2LoopyVisitor.min_value(self, o)
def coefficient(self, o):
if o.count() == 2:
from dune.perftool.pdelab.driver import get_form_ident
from dune.perftool.pdelab.driver.gridoperator import name_localoperator
lop = name_localoperator(get_form_ident())
return prim.Call(prim.Variable("{}.getTime".format(lop)), ())
else:
return UFL2LoopyVisitor.coefficient(self, o)
def ufl_to_code(expr, boundary=True): def ufl_to_code(expr, boundary=True):
# So far, we only considered this code branch on boundaries! # So far, we only considered this code branch on boundaries!
......
...@@ -14,6 +14,7 @@ from dune.perftool.generation import (backend, ...@@ -14,6 +14,7 @@ from dune.perftool.generation import (backend,
domain, domain,
dump_accumulate_timer, dump_accumulate_timer,
end_of_file, end_of_file,
function_mangler,
generator_factory, generator_factory,
get_backend, get_backend,
get_global_context_value, get_global_context_value,
...@@ -270,16 +271,7 @@ def boundary_predicates(expr, measure, subdomain_id): ...@@ -270,16 +271,7 @@ def boundary_predicates(expr, measure, subdomain_id):
visitor = get_visitor(measure, subdomain_id) visitor = get_visitor(measure, subdomain_id)
cond = visitor(subdomain_data, do_predicates=True) cond = visitor(subdomain_data, do_predicates=True)
else: else:
# Determine the name of the parameter function raise NotImplementedError("Only UFL expressions allowed in subdomain_data right now.")
cond = get_global_context_value("data").object_names[id(subdomain_data)]
# Trigger the generation of code for this thing in the parameter class
from ufl.checks import is_cellwise_constant
cellwise_constant = is_cellwise_constant(expr)
from dune.perftool.pdelab.parameter import intersection_parameter_function
intersection_parameter_function(cond, subdomain_data, cellwise_constant, t='int32')
cond = prim.Variable(cond)
predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)]) predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)])
...@@ -709,14 +701,10 @@ def generate_localoperator_kernels(operator): ...@@ -709,14 +701,10 @@ def generate_localoperator_kernels(operator):
lop_template_ansatz_gfs() lop_template_ansatz_gfs()
lop_template_test_gfs() lop_template_test_gfs()
lop_template_range_field() lop_template_range_field()
from dune.perftool.pdelab.parameter import parameterclass_basename
parameterclass_basename(operator)
# Make sure there is always the same constructor arguments (even if parameter class is empty) # Make sure there is always the same constructor arguments (even if parameter class is empty)
from dune.perftool.pdelab.localoperator import name_initree_member from dune.perftool.pdelab.localoperator import name_initree_member
name_initree_member() name_initree_member()
from dune.perftool.pdelab.parameter import name_paramclass
name_paramclass()
# Set some options! # Set some options!
from dune.perftool.pdelab.driver import isQuadrilateral from dune.perftool.pdelab.driver import isQuadrilateral
...@@ -734,10 +722,6 @@ def generate_localoperator_kernels(operator): ...@@ -734,10 +722,6 @@ def generate_localoperator_kernels(operator):
base_class('Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>' base_class('Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>'
.format(rf), classtag="operator") .format(rf), classtag="operator")
# Create set time method in parameter class
from dune.perftool.pdelab.parameter import define_set_time_method
define_set_time_method()
# Have a data structure collect the generated kernels # Have a data structure collect the generated kernels
operator_kernels = {} operator_kernels = {}
...@@ -871,7 +855,6 @@ def generate_localoperator_file(kernels, filename): ...@@ -871,7 +855,6 @@ def generate_localoperator_file(kernels, filename):
# Write the file! # Write the file!
from dune.perftool.file import generate_file from dune.perftool.file import generate_file
param = cgen_class_from_cache("parameterclass")
# TODO take the name of this thing from the UFL file # TODO take the name of this thing from the UFL file
lop = cgen_class_from_cache("operator", members=operator_methods) lop = cgen_class_from_cache("operator", members=operator_methods)
generate_file(filename, "operatorfile", [param, lop]) generate_file(filename, "operatorfile", [lop])
""" Generators for parameter functions """
from dune.perftool.generation import (class_basename,
class_member,
constructor_parameter,
generator_factory,
get_backend,
get_global_context_value,
initializer_list,
kernel_cached,
preamble,
temporary_variable
)
from dune.perftool.pdelab.geometry import (name_cell,
name_intersection,
)
from dune.perftool.pdelab.quadrature import quadrature_preamble
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.cgen.clazz import AccessModifier
from dune.perftool.pdelab.localoperator import (class_type_from_cache,
localoperator_basename,
)
from dune.perftool.loopy.target import type_floatingpoint
from dune.perftool.options import get_form_option
from loopy.match import Writes
@class_basename(classtag="parameterclass")
def parameterclass_basename(form_ident):
lopbase = get_form_option("classname", form_ident)
return "{}Params".format(lopbase)
@class_member(classtag="operator")
def define_parameterclass(name):
_, t = class_type_from_cache("parameterclass")
constructor_parameter("{}&".format(t), name + "_", classtag="operator")
initializer_list(name, [name + "_"], classtag="operator")
return "{}& {};".format(t, name)
def name_paramclass():
ident = get_global_context_value("form_identifier")
from dune.perftool.pdelab.driver.gridoperator import name_parameters
name = name_parameters(ident, define=False)
define_parameterclass(name)
return name
@class_member(classtag="parameterclass")
def define_time(name):
initializer_list(name, ["0.0"], classtag="parameterclass")
ftype = type_floatingpoint()
return "{} {};".format(ftype, name)
def name_time():
define_time("t")
return "t"
def define_set_time_method():
define_set_time_method_parameterclass()
define_set_time_method_operator()
@class_member(classtag="operator")
def define_set_time_method_operator():
time_name = name_time()
param = name_paramclass()
ftype = type_floatingpoint()
result = ["// Set time in instationary case",
"void setTime ({} t_)".format(ftype),
"{",
" Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>::setTime(t_);".format(ftype),
" {}.setTime(t_);".format(param),
"}"
]
return result
@class_member(classtag="parameterclass")
def define_set_time_method_parameterclass():
time_name = name_time()
ftype = type_floatingpoint()
result = ["// Set time in instationary case",
"void setTime ({} t_)".format(ftype),
"{",
" {} = t_;".format(time_name),
"}"
]
return result
def combine_tree_path_argnumber(element, tree_path_int):
# Return string combining tree_path and argnumber.
subel = element.extract_subelement_component(tree_path_int)
def _flatten(x):
if isinstance(x, tuple):
return '_'.join(_flatten(i) for i in x if i != ())
else:
return str(x)
return _flatten(subel)
@class_member(classtag="parameterclass")
def define_parameter_function_class_member(name, expr, baset, shape, cell):
t = construct_nested_fieldvector(baset, shape)
geot = "E" if cell else "I"
geo = geot.lower()
result = ["template<typename {}, typename X>".format(geot),
"{} {}(const {}& {}, const X& local) const".format(t, name, geot, geo),
"{",
]
# In the case of a non-scalar parameter function, recurse into leafs
if expr.element.value_shape():
# Check that this is a VectorElement, as I have no idea how a parameter function
# over a non-vector mixed element should be well-defined in PDELab.
from ufl import VectorElement
assert isinstance(expr.element, VectorElement)
result.append(" {} result(0.0);".format(t))
from dune.perftool.ufl.execution import split_expression
for i, subexpr in enumerate(split_expression(expr)):
child_name = "{}_{}".format(name, combine_tree_path_argnumber(expr.element, i))
result.append(" result[{}] = {}({}, local);".format(i, child_name, geo))
define_parameter_function_class_member(child_name, subexpr, baset, shape[1:], cell)
result.append(" return result;")
else:
# Evaluate a scalar parameter function
if expr.is_global:
result.append(" auto x = {}.geometry().global(local);".format(geo))
else:
result.append(" auto x = local;")
result.append(" " + expr.c_expr[0])
result.append("}")
return result
@preamble
def evaluate_cellwise_constant_parameter_function(name, restriction):
param = name_paramclass()
entity = name_cell(restriction)
from dune.perftool.pdelab.geometry import name_localcenter
pos = name_localcenter()
from dune.perftool.generation.loopy import valuearg
import numpy
valuearg(name)
return 'auto {} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
entity,
pos,
)
@preamble
def evaluate_intersectionwise_constant_parameter_function(name):
# Check that this is not a volume term, as that would not be well-defined
from dune.perftool.generation import get_global_context_value
it = get_global_context_value("integral_type")
assert it is not 'cell'
param = name_paramclass()
intersection = name_intersection()
pos = name_localcenter()
from dune.perftool.generation.loopy import valuearg
import numpy
valuearg(name)
return 'auto {} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
intersection,
pos,
)
def evaluate_cell_parameter_function(name, restriction):
param = name_paramclass()
entity = name_cell(restriction)
pos = get_backend(interface="qp_in_cell")(restriction)
return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
entity,
str(pos),
),
assignees=frozenset({name}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
depends_on=frozenset({Writes(get_pymbolic_basename(pos))}),
)
def evaluate_intersection_parameter_function(name):
# Check that this is not a volume term, as that would not be well-defined
from dune.perftool.generation import get_global_context_value
it = get_global_context_value("integral_type")
assert it is not 'cell'
param = name_paramclass()
intersection = name_intersection()
pos = get_backend("quad_pos")()
return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
intersection,
str(pos),
),
assignees=frozenset({name}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
depends_on=frozenset({Writes(get_pymbolic_basename(pos))}),
)
def construct_nested_fieldvector(t, shape):
if len(shape) == 0:
return t
return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0])
@kernel_cached
def cell_parameter_function(name, expr, restriction, cellwise_constant, t='float64'):
shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape)
from dune.perftool.loopy.target import numpy_to_cpp_dtype
t = numpy_to_cpp_dtype(t)
define_parameter_function_class_member(name, expr, t, shape, True)
if cellwise_constant:
evaluate_cellwise_constant_parameter_function(name, restriction)
else:
temporary_variable(name, shape=shape, shape_impl=shape_impl)
evaluate_cell_parameter_function(name, restriction)
@kernel_cached
def intersection_parameter_function(name, expr, cellwise_constant, t='float64'):
shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape)
from dune.perftool.loopy.target import numpy_to_cpp_dtype
t = numpy_to_cpp_dtype(t)
define_parameter_function_class_member(name, expr, t, shape, False)
if cellwise_constant:
evaluate_intersectionwise_constant_parameter_function(name)
else:
temporary_variable(name, shape=shape, shape_impl=shape_impl)
evaluate_intersection_parameter_function(name)
...@@ -14,7 +14,6 @@ from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker, ...@@ -14,7 +14,6 @@ from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
) )
from dune.perftool.tools import maybe_wrap_subscript from dune.perftool.tools import maybe_wrap_subscript
from dune.perftool.options import get_form_option from dune.perftool.options import get_form_option
from dune.perftool.pdelab.parameter import name_paramclass, name_time
from loopy import Reduction from loopy import Reduction
from pymbolic.primitives import (Call, from pymbolic.primitives import (Call,
...@@ -161,31 +160,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -161,31 +160,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# In this case it represents the time variable # In this case it represents the time variable
elif o.count() == 2: elif o.count() == 2:
param = name_paramclass() # The base class 'InstationaryLocalOperatorDefaultMethods' stores the time
time = name_time() # and exports it through a getter method 'getTime'
name = param + "." + time return prim.Call(prim.Variable("getTime"), ())
valuearg(name)
return Variable(name)
# Check if this is a parameter function
else: else:
raise NotImplementedError("Handling non-symbolic parameter functions is currently reevaluated!") raise NotImplementedError("General Coefficients")
# We expect all coefficients to be of type Expression!
assert isinstance(o, Expression)
# Determine the name of the parameter function
name = get_global_context_value("data").object_names[id(o)]
cellwise_constant = is_cellwise_constant(o)
# Trigger the generation of code for this thing in the parameter class
if o.on_intersection:
self.interface.intersection_parameter_function(name, o, cellwise_constant)
else:
self.interface.cell_parameter_function(name, o, self.restriction, cellwise_constant)
# And return a symbol
return Variable(name)
# #
# Handlers for all indexing related stuff # Handlers for all indexing related stuff
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment