Skip to content
Snippets Groups Projects
Commit 27e8905b authored by Marcel Koch's avatar Marcel Koch
Browse files

Adds localBasis function object to operator

parent 043452fa
No related branches found
No related tags found
No related merge requests found
...@@ -2,16 +2,17 @@ from dune.perftool.generation import (backend, ...@@ -2,16 +2,17 @@ from dune.perftool.generation import (backend,
kernel_cached, kernel_cached,
get_backend, get_backend,
instruction, instruction,
temporary_variable) temporary_variable,
globalarg,
class_member,
initializer_list)
from dune.perftool.tools import get_pymbolic_basename from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab.basis import (declare_cache_temporary, from dune.perftool.pdelab.basis import (declare_cache_temporary,
name_localbasis_cache) name_localbasis_cache,
type_localbasis
)
from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.pdelab.quadrature import pymbolic_quadrature_position_in_cell from dune.perftool.pdelab.quadrature import pymbolic_quadrature_position_in_cell
from dune.perftool.blockstructured.spaces import lfs_inames
import pymbolic.primitives as prim
@backend(interface="evaluate_basis", name="blockstructured") @backend(interface="evaluate_basis", name="blockstructured")
...@@ -20,11 +21,9 @@ def evaluate_basis(leaf_element, name, restriction): ...@@ -20,11 +21,9 @@ def evaluate_basis(leaf_element, name, restriction):
temporary_variable(name, shape=(4,), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function')) temporary_variable(name, shape=(4,), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function'))
cache = name_localbasis_cache(leaf_element) cache = name_localbasis_cache(leaf_element)
qp = pymbolic_quadrature_position_in_cell(restriction) qp = pymbolic_quadrature_position_in_cell(restriction)
localbasis = name_localbasis(leaf_element)
instruction(inames=get_backend("quad_inames")(), instruction(inames=get_backend("quad_inames")(),
code='{} = {}.evaluateFunction({}, lfs.finiteElement().localBasis());'.format(name, code='{} = {}.evaluateFunction({}, {});'.format(name, cache, str(qp), localbasis),
cache,
str(qp),
),
assignees=frozenset({name}), assignees=frozenset({name}),
read_variables=frozenset({get_pymbolic_basename(qp)}), read_variables=frozenset({get_pymbolic_basename(qp)}),
) )
...@@ -33,14 +32,26 @@ def evaluate_basis(leaf_element, name, restriction): ...@@ -33,14 +32,26 @@ def evaluate_basis(leaf_element, name, restriction):
@backend(interface="evaluate_grad", name="blockstructured") @backend(interface="evaluate_grad", name="blockstructured")
@kernel_cached @kernel_cached
def evaluate_reference_gradient(leaf_element, name, restriction): def evaluate_reference_gradient(leaf_element, name, restriction):
temporary_variable(name, shape=(4,1,world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian')) temporary_variable(name, shape=(4, 1, world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
cache = name_localbasis_cache(leaf_element) cache = name_localbasis_cache(leaf_element)
qp = pymbolic_quadrature_position_in_cell(restriction) qp = pymbolic_quadrature_position_in_cell(restriction)
localbasis = name_localbasis(leaf_element)
instruction(inames=get_backend("quad_inames")(), instruction(inames=get_backend("quad_inames")(),
code='{} = {}.evaluateJacobian({}, lfs.finiteElement().localBasis());'.format(name, code='{} = {}.evaluateJacobian({}, {});'.format(name, cache, str(qp), localbasis),
cache,
str(qp),
),
assignees=frozenset({name}), assignees=frozenset({name}),
read_variables=frozenset({get_pymbolic_basename(qp)}), read_variables=frozenset({get_pymbolic_basename(qp)}),
) )
\ No newline at end of file
@class_member(classtag="operator")
def define_localbasis(leaf_element, name):
localBasis_type = type_localbasis(leaf_element)
initializer_list(name, (), classtag="operator")
return "const {} {};".format(localBasis_type, name)
def name_localbasis(leaf_element):
name = "microElementBasis"
globalarg(name)
define_localbasis(leaf_element, name)
return name
...@@ -32,18 +32,32 @@ import pymbolic.primitives as prim ...@@ -32,18 +32,32 @@ import pymbolic.primitives as prim
from loopy import Reduction from loopy import Reduction
def type_localbasis(element): @class_member(classtag="operator")
def typedef_localbasis(element, name):
df = "typename {}::Traits::GridView::ctype".format(type_gfs(element)) df = "typename {}::Traits::GridView::ctype".format(type_gfs(element))
r = basetype_range() r = basetype_range()
dim = world_dimension() dim = world_dimension()
# if isPk(element): # if isPk(element):
# include_file("dune/localfunctions/lagrange/pk/pklocalbasis.hh", filetag="operatorfile") # include_file("dune/localfunctions/lagrange/pk/pklocalbasis.hh", filetag="operatorfile")
# return "Dune::PkLocalBasis<{}, {}, {}, {}>".format(df, r, dim, element._degree) # return "Dune::PkLocalBasis<{}, {}, {}, {}>".format(df, r, dim, element._degree)
#TODO add dg support
if isQk(element): if isQk(element):
include_file("dune/localfunctions/lagrange/qk/qklocalbasis.hh", filetag="operatorfile") include_file("dune/localfunctions/lagrange/qk/qklocalbasis.hh", filetag="operatorfile")
return "Dune::QkLocalBasis<{}, {}, {}, {}>".format(df, r, dim, element._degree) basis_type = "QkLocalBasis<{}, {}, {}, {}>".format(df, r, element._degree, dim)
#TODO add dg support else:
raise NotImplementedError("Element type not known in code generation") raise NotImplementedError("Element type not known in code generation")
return "using {} = Dune::{};".format(name, basis_type)
def type_localbasis(element):
if isPk(element):
name = "P{}_LocalBasis".format(element._degree)
elif isQk(element):
name = "Q{}_LocalBasis".format(element._degree)
else:
raise NotImplementedError("Element type not known in code generation")
typedef_localbasis(element, name)
return name
def type_localbasis_cache(element): def type_localbasis_cache(element):
......
from dune.perftool.options import get_option
from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import ReplaceExpression
from ufl.algorithms import MultiFunction
from ufl import as_ufl
from ufl.classes import JacobianInverse, JacobianDeterminant, Product, Division, Indexed
class ReplaceReferenceTransformation(MultiFunction):
def __init__(self, k):
MultiFunction.__init__(self)
self.k = k
self.visited_jit = False
def expr(self, o):
return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
#TODO abs uses c abs -> only works for ints!!!
def abs(self,o):
if isinstance(o.ufl_operands[0], JacobianDeterminant):
return Division(o, as_ufl(self.k**2))
else:
return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
def jacobian_determinant(self,o):
return Division(o, as_ufl(self.k**2))
def indexed(self, o):
expr = o.ufl_operands[0]
multiindex = o.ufl_operands[1]
if isinstance(expr, JacobianInverse):
return Product(as_ufl(self.k), Indexed(expr, multiindex))
else:
return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
@ufl_transformation(name="blockstructured")
def blockstructured(expr):
return ReplaceReferenceTransformation(get_option("number_of_blocks"))(expr)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment