diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py index 16ab4fd16d4271b6d403506a577770efd0291441..93240b5e35f9b41f72881afbd0b6f930addd62b7 100644 --- a/python/dune/perftool/sumfact/__init__.py +++ b/python/dune/perftool/sumfact/__init__.py @@ -4,6 +4,7 @@ from dune.perftool.sumfact.quadrature import (quadrature_inames, from dune.perftool.sumfact.basis import (lfs_inames, pymbolic_basis, + pymbolic_reference_gradient, pymbolic_trialfunction, pymbolic_trialfunction_gradient, ) @@ -18,6 +19,9 @@ class SumFactInterface(PDELabInterface): def pymbolic_basis(self, element, restriction, number): return pymbolic_basis(element, restriction, number) + def pymbolic_reference_gradient(self, element, restriction, number): + return pymbolic_reference_gradient(element, restriction, number) + def pymbolic_trialfunction_gradient(self, element, restriction, component): return pymbolic_trialfunction_gradient(element, restriction, component) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 122479f9db729eb9ff84e690a9ea23e97f9bd9cd..3f44cdb0f37c15ec3de3b57c410063837c7090b5 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -198,3 +198,58 @@ def pymbolic_basis(element, restriction, number): evaluate_basis(element, name, restriction) return prim.Variable(name) + + +@backend(interface="evaluate_grad") +@cached +def evaluate_reference_gradient(element, name, restriction): + # from dune.perftool.pdelab.basis import name_leaf_lfs + # lfs = name_leaf_lfs(element, restriction) + # from dune.perftool.pdelab.spaces import name_lfs_bound + from dune.perftool.pdelab.geometry import name_dimension + temporary_variable( + name, + shape=(name_dimension(),)) + # shape=()) + quad_inames = quadrature_inames() + inames = lfs_inames(element, restriction) + assert(len(quad_inames) == len(inames)) + + theta = name_theta() + dtheta = name_dtheta() + + + # TODO WIP! + # Get geometric dimension + formdata = get_global_context_value('formdata') + dim = formdata.geometric_dimension + + for i in range(dim): + calls = [prim.Call(ColMajorAccess(theta), (prim.Variable(i), prim.Variable(j))) + for (i, j) in zip(quad_inames, inames)] + calls[i] = prim.Call(ColMajorAccess(dtheta), (prim.Variable(quad_inames[i]), prim.Variable(inames[i]))) + calls = tuple(calls) + + # assignee = prim.Subscript(prim.Variable(name), tuple(prim.Variable(0))) + assignee = prim.Subscript(prim.Variable(name), (i,)) + # assignee = prim.Variable(name) + expression = prim.Product(calls) + + instruction(assignee=assignee, + expression=expression, + forced_iname_deps=frozenset(quad_inames + inames), + forced_iname_deps_is_final=True, + ) + + +def pymbolic_reference_gradient(element, restriction, number): + assert number == 1 + # TODO ? + #assert element.num_sum_elements() == 0 + + # TODO: Change name? + name = "js_{}".format(FEM_name_mangling(element)) + name = restricted_name(name, restriction) + evaluate_reference_gradient(element, name, restriction) + + return prim.Variable(name) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 49adf32b0f8f668df7bf297f56c469f41edaf727..8e7f48105cc35643fe6b2e4de155eee42dd0e043 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -79,6 +79,7 @@ def name_test_function_contribution(test): return restricted_name("contrib_{}phi_{}".format(grad, str(count)), test.restriction) +# palpo TODO move somewhere else from pymbolic.mapper import IdentityMapper class IndexReplaceMapper(IdentityMapper): """Replace indices in pymbolic expression using a dictionary lookup""" diff --git a/test/sumfact/poisson/poisson_dg_only_volume.mini b/test/sumfact/poisson/poisson_dg_only_volume.mini index 737d576bc7fc90a4304f2d91c76034a48e09b534..c0153667d89981fe2310cffa23beef6d34e4079a 100644 --- a/test/sumfact/poisson/poisson_dg_only_volume.mini +++ b/test/sumfact/poisson/poisson_dg_only_volume.mini @@ -6,6 +6,7 @@ __sumfact_suffix = normal, sumfact | expand sumf cells = 1 1 extension = 1. 1. +printresidual = 1 printmatrix = 1 [wrapper.vtkcompare]