From f57ec6b269498220d2e37ac26785760cf0ec867c Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 12 Oct 2016 15:38:36 +0200 Subject: [PATCH] Fix some stokes kernels --- python/dune/perftool/pdelab/localoperator.py | 12 +++++++----- .../transformations/extract_accumulation_terms.py | 11 +++++++---- .../ufl/transformations/identitypropagation.py | 11 ++++++----- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 068cae0e..71f552cc 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -353,11 +353,10 @@ def boundary_predicates(expr, measure, subdomain_id): @iname -def grad_iname(ma): +def grad_iname(index, dim): from dune.perftool.pdelab import name_index - from ufl.domain import find_geometric_dimension - name = name_index(ma.index) - domain(name, find_geometric_dimension(ma.expr)) + name = name_index(index) + domain(name, dim) return name @@ -368,7 +367,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # If this is a gradient, we generate an iname additional_inames = frozenset({}) if accterm.argument.index: - additional_inames = frozenset({grad_iname(accterm.argument)}) + from ufl.domain import find_geometric_dimension + dim = find_geometric_dimension(accterm.argument.expr) + for i in accterm.argument.index._indices: + additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)})) # It may happen that an entire accumulation term vanishes. We do nothing in that case if pymbolic_expr == 0: diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 7fdf46b3..3d53f9c2 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -44,10 +44,13 @@ def split_into_accumulation_terms(expr): replace_expr = replace_expression(expr, replacemap=replacement) # 2) Cut the test function itself from the expression - if test_arg.reference_grad: - newi = indices(1) - replacement = {test_arg.expr: Indexed(Identity(2), MultiIndex(newi + test_arg.index._indices))} - test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={test_arg.index[0]: newi[0]})) + if test_arg.index: + newi = indices(len(test_arg.index)) + identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices)) + from dune.perftool.ufl.flatoperators import construct_binary_operator + from ufl.classes import Product + replacement = {test_arg.expr: construct_binary_operator(identities, Product)} + test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)})) else: replacement = {test_arg.expr: IntValue(1)} replace_expr = replace_expression(replace_expr, replacemap=replacement) diff --git a/python/dune/perftool/ufl/transformations/identitypropagation.py b/python/dune/perftool/ufl/transformations/identitypropagation.py index 2e7bbab5..aeb21a10 100644 --- a/python/dune/perftool/ufl/transformations/identitypropagation.py +++ b/python/dune/perftool/ufl/transformations/identitypropagation.py @@ -14,7 +14,7 @@ class GetIndexMap(MultiFunction): call = MultiFunction.__call__ def __call__(self, o): - self.free_index = Index(o.ufl_free_indices[0]) + self.free_indices = frozenset(Index(i) for i in o.ufl_free_indices) self.replacemap = {} self.call(o) return self.replacemap @@ -26,10 +26,11 @@ class GetIndexMap(MultiFunction): def indexed(self, o): op, i = o.ufl_operands if isinstance(op, Identity): - assert(len(i) == 2) - assert(self.free_index in i) - ind, = set(i) - {self.free_index} - self.replacemap[ind] = self.free_index + free_index = self.free_indices.intersection(frozenset(i)) + assert(len(free_index) == 1) + ind, = frozenset(i) - free_index + free_index, = free_index + self.replacemap[ind] = free_index else: self.call(op) -- GitLab