diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 794c3467d286374b2ba3894c2ca40c43792f1dc5..b346a2f9ab1c529f0b1c674f241f7b799f3180c8 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -29,7 +29,9 @@ from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.switch import (get_facedir, get_facemod, ) -from dune.perftool.pdelab.geometry import world_dimension +from dune.perftool.pdelab.geometry import (local_dimension, + world_dimension, + ) from dune.perftool.loopy.buffer import initialize_buffer from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.restriction import restricted_name @@ -206,7 +208,7 @@ def evaluate_basis(element, name, restriction): ) # Add the missing direction on facedirs by evaluating at either 0 or 1 - if facedir: + if facedir is not None: facemod = get_facemod(restriction) from dune.perftool.sumfact.amatrix import PolynomialLookup, name_polynomials prod = prod + (prim.Call(PolynomialLookup(name_polynomials(), False), @@ -238,20 +240,32 @@ def evaluate_reference_gradient(element, name, restriction): temporary_variable(name, shape=(dim,)) quad_inames = quadrature_inames() inames = lfs_inames(element, restriction) - assert(len(quad_inames) == len(inames)) - - # Matrices for sumfactorization - theta = name_theta() - dtheta = name_theta(derivative=True) - - for i in range(dim): - calls = [prim.Subscript(prim.Variable(theta), (prim.Variable(m), prim.Variable(n))) - for (m, n) in zip(quad_inames, inames)] - calls[i] = prim.Subscript(prim.Variable(dtheta), (prim.Variable(quad_inames[i]), prim.Variable(inames[i]))) - calls = tuple(calls) + facedir = get_facedir(restriction) - assignee = prim.Subscript(prim.Variable(name), (i,)) - expression = prim.Product(calls) + # Map the direction to a quadrature iname + quadinamemapping = {} + i = 0 + for d in range(local_dimension()): + if d == facedir: + i = i+1 + quadinamemapping[i] = quad_inames[d] + i = i+1 + + for d in range(dim): + prod = [] + for i in range(dim): + if i != facedir: + prod.append(prim.Subscript(prim.Variable(name_theta(derivative=d == i)), + (prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])) + )) + if facedir is not None: + facemod = get_facemod(restriction) + from dune.perftool.sumfact.amatrix import PolynomialLookup, name_polynomials + prod.append(prim.Call(PolynomialLookup(name_polynomials(), facedir==d), + (prim.Variable(inames[facedir]), facemod)),) + + assignee = prim.Subscript(prim.Variable(name), (d,)) + expression = prim.Product(tuple(prod)) instruction(assignee=assignee, expression=expression, diff --git a/test/sumfact/poisson/poisson_dg.mini b/test/sumfact/poisson/poisson_dg.mini index c59d3c054a4f4528a8ec677e1826fbd9609e1b26..eeca1791a3d2b3f65abf7e7c41d010f4f327e7e4 100644 --- a/test/sumfact/poisson/poisson_dg.mini +++ b/test/sumfact/poisson/poisson_dg.mini @@ -1,6 +1,5 @@ __name = sumfact_poisson_dg_{__exec_suffix} __exec_suffix = numdiff, symdiff | expand num -{__exec_suffix} == symdiff | exclude cells = 1 1 extension = 1. 1.