Skip to content
Snippets Groups Projects
Commit ce9b4bc3 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

name_basis -> pymbolic_basis

parent d388acca
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ from dune.perftool.pdelab.argument import (pymbolic_apply_function, ...@@ -7,7 +7,7 @@ from dune.perftool.pdelab.argument import (pymbolic_apply_function,
pymbolic_trialfunction, pymbolic_trialfunction,
pymbolic_trialfunction_gradient, pymbolic_trialfunction_gradient,
) )
from dune.perftool.pdelab.basis import (name_basis, from dune.perftool.pdelab.basis import (pymbolic_basis,
name_reference_gradient, name_reference_gradient,
) )
from dune.perftool.pdelab.geometry import (dimension_iname, from dune.perftool.pdelab.geometry import (dimension_iname,
...@@ -25,7 +25,7 @@ from dune.perftool.pdelab.parameter import (cell_parameter_function, ...@@ -25,7 +25,7 @@ from dune.perftool.pdelab.parameter import (cell_parameter_function,
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight, from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
quadrature_inames, quadrature_inames,
) )
from dune.perftool.pdelab.spaces import (lfs_iname, from dune.perftool.pdelab.spaces import (lfs_inames,
) )
...@@ -45,15 +45,15 @@ class PDELabInterface(object): ...@@ -45,15 +45,15 @@ class PDELabInterface(object):
# Local function space related generator functions # Local function space related generator functions
# #
def lfs_iname(self, element, restriction, number): def lfs_inames(self, element, restriction, number=None, context=''):
return lfs_iname(element, restriction, number) return lfs_inames(element, restriction, number, context)
# #
# Test and trial function related generator functions # Test and trial function related generator functions
# #
def name_basis(self, element, restriction): def pymbolic_basis(self, element, restriction, number):
return name_basis(element, restriction) return pymbolic_basis(element, restriction, number)
def name_reference_gradient(self, element, restriction): def name_reference_gradient(self, element, restriction):
return name_reference_gradient(element, restriction) return name_reference_gradient(element, restriction)
......
...@@ -17,7 +17,6 @@ from dune.perftool.generation import (cached, ...@@ -17,7 +17,6 @@ from dune.perftool.generation import (cached,
from dune.perftool.pdelab.index import name_index from dune.perftool.pdelab.index import name_index
from dune.perftool.pdelab.basis import (evaluate_coefficient, from dune.perftool.pdelab.basis import (evaluate_coefficient,
evaluate_coefficient_gradient, evaluate_coefficient_gradient,
name_basis,
) )
from dune.perftool.pdelab.spaces import (lfs_iname, from dune.perftool.pdelab.spaces import (lfs_iname,
name_lfs_bound, name_lfs_bound,
......
""" Generators for basis evaluations """ """ Generators for basis evaluations """
from dune.perftool.generation import (cached, from dune.perftool.generation import (backend,
cached,
class_member, class_member,
generator_factory, generator_factory,
get_backend, get_backend,
...@@ -11,6 +12,7 @@ from dune.perftool.generation import (cached, ...@@ -11,6 +12,7 @@ from dune.perftool.generation import (cached,
) )
from dune.perftool.pdelab.spaces import (lfs_child, from dune.perftool.pdelab.spaces import (lfs_child,
lfs_iname, lfs_iname,
lfs_inames,
name_leaf_lfs, name_leaf_lfs,
name_lfs, name_lfs,
name_lfs_bound, name_lfs_bound,
...@@ -60,6 +62,7 @@ def declare_cache_temporary(element, restriction, which): ...@@ -60,6 +62,7 @@ def declare_cache_temporary(element, restriction, which):
return decl return decl
@backend(interface="evaluate_basis")
@cached @cached
def evaluate_basis(leaf_element, name, restriction): def evaluate_basis(leaf_element, name, restriction):
lfs = name_leaf_lfs(leaf_element, restriction) lfs = name_leaf_lfs(leaf_element, restriction)
...@@ -77,13 +80,14 @@ def evaluate_basis(leaf_element, name, restriction): ...@@ -77,13 +80,14 @@ def evaluate_basis(leaf_element, name, restriction):
) )
def name_basis(leaf_element, restriction): def pymbolic_basis(leaf_element, restriction, number, context=''):
assert leaf_element.num_sub_elements() == 0 assert leaf_element.num_sub_elements() == 0
# TODO name mangling!
name = "phi_{}".format(FEM_name_mangling(leaf_element)) name = "phi_{}".format(FEM_name_mangling(leaf_element))
name = restricted_name(name, restriction) name = restricted_name(name, restriction)
evaluate_basis(leaf_element, name, restriction) evaluate_basis(leaf_element, name, restriction)
return name iname, = lfs_inames(leaf_element, restriction, number, context=context)
return Subscript(Variable(name), (Variable(iname),))
@cached @cached
...@@ -138,8 +142,9 @@ def evaluate_coefficient(element, name, container, restriction, component): ...@@ -138,8 +142,9 @@ def evaluate_coefficient(element, name, container, restriction, component):
temporary_variable(name, shape=shape, shape_impl=shape_impl) temporary_variable(name, shape=shape, shape_impl=shape_impl)
lfs = name_lfs(element, restriction, component) lfs = name_lfs(element, restriction, component)
index = lfs_iname(leaf_element, restriction, context='trial') basis = pymbolic_basis(leaf_element, restriction, 0, context='trial')
basis = name_basis(leaf_element, restriction) from dune.perftool.tools import get_pymbolic_indices
index, = get_pymbolic_indices(basis)
if isinstance(sub_element, (VectorElement, TensorElement)): if isinstance(sub_element, (VectorElement, TensorElement)):
lfs = lfs_child(lfs, idims, shape=shape_as_pymbolic(shape), symmetry=element.symmetry()) lfs = lfs_child(lfs, idims, shape=shape_as_pymbolic(shape), symmetry=element.symmetry())
...@@ -148,7 +153,7 @@ def evaluate_coefficient(element, name, container, restriction, component): ...@@ -148,7 +153,7 @@ def evaluate_coefficient(element, name, container, restriction, component):
coeff = pymbolic_coefficient(container, lfs, index) coeff = pymbolic_coefficient(container, lfs, index)
assignee = Subscript(Variable(name), tuple(Variable(i) for i in idims)) assignee = Subscript(Variable(name), tuple(Variable(i) for i in idims))
reduction_expr = Product((coeff, Subscript(Variable(basis), Variable(index)))) reduction_expr = Product((coeff, basis))
instruction(expression=Reduction("sum", index, reduction_expr, allow_simultaneous=True), instruction(expression=Reduction("sum", index, reduction_expr, allow_simultaneous=True),
assignee=assignee, assignee=assignee,
forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)), forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)),
......
...@@ -212,6 +212,10 @@ def lfs_iname(element, restriction, count=None, context=''): ...@@ -212,6 +212,10 @@ def lfs_iname(element, restriction, count=None, context=''):
return _lfs_iname(element, restriction, context) return _lfs_iname(element, restriction, context)
def lfs_inames(element, restriction, count=None, context=''):
return (lfs_iname(element, restriction, count, context),)
class LFSLocalIndex(FunctionIdentifier): class LFSLocalIndex(FunctionIdentifier):
def __init__(self, lfs): def __init__(self, lfs):
self.lfs = lfs self.lfs = lfs
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
NB: Basis evaluation is only needed for the trial function argument in jacobians, as the NB: Basis evaluation is only needed for the trial function argument in jacobians, as the
multiplication withthe test function is part of the sum factorization kernel. multiplication withthe test function is part of the sum factorization kernel.
""" """
from dune.perftool.generation import (backend,
cached,
)
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction, basis_functions_per_direction,
name_theta, name_theta,
...@@ -16,9 +19,10 @@ from dune.perftool.loopy.buffer import initialize_buffer ...@@ -16,9 +19,10 @@ from dune.perftool.loopy.buffer import initialize_buffer
from pytools import product from pytools import product
from pymbolic.primitives import Subscript, Variable import pymbolic.primitives as p
@cached
def pymbolic_trialfunction(element, restriction, component): def pymbolic_trialfunction(element, restriction, component):
theta = name_theta() theta = name_theta()
rows = quadrature_points_per_direction() rows = quadrature_points_per_direction()
...@@ -35,4 +39,26 @@ def pymbolic_trialfunction(element, restriction, component): ...@@ -35,4 +39,26 @@ def pymbolic_trialfunction(element, restriction, component):
insn_dep = setup_theta(element, restriction, component, a_matrices) insn_dep = setup_theta(element, restriction, component, a_matrices)
var = sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep})) var = sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
return Subscript(Variable(var), tuple(Variable(i) for i in quadrature_inames())) return p.Subscript(p.Variable(var), tuple(p.Variable(i) for i in quadrature_inames()))
def lfs_inames(leaf_element):
return ()
@backend(interface="evaluate_basis")
@cached
def evaluate_basis(leaf_element, name, restriction):
temporary_variable(name, shape=())
theta = name_theta()
quad_inames = quadrature_inames()
lfs_inames = lfs_inames()
assert(len(quad_inames) == len(lfs_inames))
instruction(expression=p.Product(tuple(p.Subscript(p.Variable(theta), (p.Variable(i), p.Variable(j)))
for (i,j) in zip(quad_inames, lfs_inames))
),
assignee=p.Variable(name),
forced_iname_deps=frozenset(quad_inames + lfs_inames),
forced_iname_deps_is_final=True,
)
""" Some grabbag tools """ """ Some grabbag tools """
from pymbolic.primitives import Expression, Variable, Subscript import pymbolic.primitives as p
def get_pymbolic_basename(expr): def get_pymbolic_basename(expr):
assert isinstance(expr, Expression), "Type: {}, expr: {}".format(type(expr), expr) assert isinstance(expr, p.Expression), "Type: {}, expr: {}".format(type(expr), expr)
if isinstance(expr, Variable): if isinstance(expr, p.Variable):
return expr.name return expr.name
if isinstance(expr, Subscript): if isinstance(expr, p.Subscript):
return get_pymbolic_basename(expr.aggregate) return get_pymbolic_basename(expr.aggregate)
raise NotImplementedError("Cannot determine basename of {}".format(expr)) raise NotImplementedError("Cannot determine basename of {}".format(expr))
def get_pymbolic_indices(expr):
if not isinstance(expr, p.Subscript):
return ()
if not isinstance(expr.index, tuple):
return (get_pymbolic_basename(expr.index),)
return tuple(get_pymbolic_basename(i) for i in expr.index)
\ No newline at end of file
...@@ -50,7 +50,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -50,7 +50,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Reset some state variables that are reinitialized for each accumulation term # Reset some state variables that are reinitialized for each accumulation term
self.argshape = 0 self.argshape = 0
self.transpose_necessary = False self.transpose_necessary = False
self.inames = [] self.inames = ()
return self.call(o) return self.call(o)
...@@ -80,7 +80,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -80,7 +80,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# If this is a vector element, we need add an additional accumulation loop iname # If this is a vector element, we need add an additional accumulation loop iname
for i in range(self.argshape): for i in range(self.argshape):
self.inames.append(self.interface.dimension_iname(context='arg', count=i)) self.inames = self.inames + (self.interface.dimension_iname(context='arg', count=i),)
# For the purpose of basis evaluation, we need to take the leaf element # For the purpose of basis evaluation, we need to take the leaf element
leaf_element = element.sub_elements()[0] leaf_element = element.sub_elements()[0]
...@@ -89,13 +89,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -89,13 +89,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
raise ValueError("Gradients should have been transformed to reference gradients!!!") raise ValueError("Gradients should have been transformed to reference gradients!!!")
# Have the issued instruction depend on the iname for this localfunction space # Have the issued instruction depend on the iname for this localfunction space
iname = self.interface.lfs_iname(leaf_element, restriction, o.number()) inames = self.interface.lfs_inames(leaf_element, restriction, o.number())
self.inames.append(iname) self.inames = self.inames + inames
iname, = inames
if self.reference_grad: if self.reference_grad:
return Subscript(Variable(self.interface.name_reference_gradient(leaf_element, restriction)), (Variable(iname), 0)) return Subscript(Variable(self.interface.name_reference_gradient(leaf_element, restriction)), (Variable(iname), 0))
else: else:
return Subscript(Variable(self.interface.name_basis(leaf_element, restriction)), (Variable(iname),)) return self.interface.pymbolic_basis(leaf_element, restriction, o.number())
def coefficient(self, o): def coefficient(self, o):
# Do something different for trial function and coefficients from jacobian apply # Do something different for trial function and coefficients from jacobian apply
...@@ -198,7 +199,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -198,7 +199,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return index._value return index._value
else: else:
if index in self.dimension_indices: if index in self.dimension_indices:
self.inames.append(self.dimension_indices[index]) self.inames = self.inames + (self.dimension_indices[index],)
return Variable(self.dimension_indices[index]) return Variable(self.dimension_indices[index])
else: else:
return Variable(self.interface.name_index(index)) return Variable(self.interface.name_index(index))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment