From a36d8de4bc41b14645a05d01edd2ce31e2241b3e Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Thu, 4 Apr 2019 09:59:29 +0200
Subject: [PATCH] fix alias preamble

---
 .../codegen/blockstructured/accumulation.py   | 25 ++-------------
 .../dune/codegen/blockstructured/argument.py  | 25 ++-------------
 python/dune/codegen/blockstructured/tools.py  | 32 ++++++++++++++++++-
 3 files changed, 37 insertions(+), 45 deletions(-)

diff --git a/python/dune/codegen/blockstructured/accumulation.py b/python/dune/codegen/blockstructured/accumulation.py
index 1f26d3ca..4c0317c2 100644
--- a/python/dune/codegen/blockstructured/accumulation.py
+++ b/python/dune/codegen/blockstructured/accumulation.py
@@ -1,12 +1,12 @@
-from dune.codegen.blockstructured.tools import sub_element_inames
-from dune.codegen.generation import accumulation_mixin, instruction, preamble
+from dune.codegen.blockstructured.tools import sub_element_inames, name_accumulation_alias
+from dune.codegen.generation import accumulation_mixin, instruction
 from dune.codegen.loopy.target import dtype_floatingpoint
 from dune.codegen.options import get_form_option
 from dune.codegen.pdelab.geometry import world_dimension, name_intersection_geometry_wrapper
 from dune.codegen.pdelab.localoperator import determine_accumulation_space, GenericAccumulationMixin
 from dune.codegen.pdelab.argument import name_accumulation_variable
 from dune.codegen.pdelab.localoperator import boundary_predicates
-from dune.codegen.generation.loopy import function_mangler, globalarg, temporary_variable
+from dune.codegen.generation.loopy import function_mangler, temporary_variable
 import loopy as lp
 import pymbolic.primitives as prim
 
@@ -22,25 +22,6 @@ class BlockStructuredAccumulationMixin(GenericAccumulationMixin):
             return generate_accumulation_instruction(expr, self)
 
 
-def name_accumulation_alias(container, accumspace):
-    name = container + "_" + accumspace.lfs.name + "_alias"
-    name_tail = container + "_" + accumspace.lfs.name + "_alias_tail"
-    k = get_form_option("number_of_blocks")
-    p = accumspace.element.degree()
-
-    @preamble
-    def _add_alias_insn(name):
-        dim = world_dimension()
-        element_stride = tuple(p * (p * k + 1)**i for i in range(0, dim))
-        index_stride = tuple((p * k + 1)**i for i in range(0, dim))
-        globalarg(name, shape=(k,) * dim + (p + 1,) * dim, strides=element_stride + index_stride, managed=True)
-        return "auto {} = &{}.container()({},0);".format(name, container, accumspace.lfs.name)
-
-    _add_alias_insn(name)
-    _add_alias_insn(name_tail)
-    return name
-
-
 @function_mangler
 def residual_weight_mangler(knl, func, arg_dtypes):
     if isinstance(func, str) and func.endswith('.weight'):
diff --git a/python/dune/codegen/blockstructured/argument.py b/python/dune/codegen/blockstructured/argument.py
index 0cc7bd17..a96630cf 100644
--- a/python/dune/codegen/blockstructured/argument.py
+++ b/python/dune/codegen/blockstructured/argument.py
@@ -1,30 +1,11 @@
-from dune.codegen.generation import (kernel_cached,
-                                     valuearg, globalarg, preamble)
+from dune.codegen.generation import kernel_cached, valuearg
 from dune.codegen.options import get_form_option
 from dune.codegen.pdelab.argument import CoefficientAccess
-from dune.codegen.blockstructured.tools import micro_index_to_macro_index, sub_element_inames
-from dune.codegen.pdelab.geometry import world_dimension
+from dune.codegen.blockstructured.tools import micro_index_to_macro_index, sub_element_inames, name_container_alias
 from loopy.types import NumpyType
 import pymbolic.primitives as prim
 
 
-def name_alias(container, lfs, element):
-    name = container + "_" + lfs.name + "_alias"
-
-    @preamble
-    def _add_alias_insn(name):
-        k = get_form_option("number_of_blocks")
-        p = element.degree()
-        dim = world_dimension()
-        element_stride = tuple(p * (p * k + 1)**i for i in range(0, dim))
-        index_stride = tuple((p * k + 1)**i for i in range(0, dim))
-        globalarg(name, shape=(k,) * dim + (p + 1,) * dim, strides=element_stride + index_stride, managed=True)
-        return "const auto {} = &{}({},0);".format(name, container, lfs.name)
-
-    _add_alias_insn(name)
-    return name
-
-
 # TODO remove the need for element
 @kernel_cached
 def pymbolic_coefficient(container, lfs, element, index):
@@ -39,7 +20,7 @@ def pymbolic_coefficient(container, lfs, element, index):
     # use higher order FEM index instead of Q1 index
     if get_form_option("vectorization_blockstructured"):
         subelem_inames = sub_element_inames()
-        coeff_alias = name_alias(container, lfs, element)
+        coeff_alias = name_container_alias(container, lfs, element)
         return prim.Subscript(prim.Variable(coeff_alias), tuple(prim.Variable(i) for i in subelem_inames + index))
     else:
         return prim.Call(CoefficientAccess(container), (lfs, micro_index_to_macro_index(element, index),))
diff --git a/python/dune/codegen/blockstructured/tools.py b/python/dune/codegen/blockstructured/tools.py
index 802b819f..13cd2b84 100644
--- a/python/dune/codegen/blockstructured/tools.py
+++ b/python/dune/codegen/blockstructured/tools.py
@@ -2,7 +2,7 @@ from dune.codegen.tools import get_pymbolic_basename
 from dune.codegen.generation import (iname,
                                      domain,
                                      temporary_variable,
-                                     instruction)
+                                     instruction, globalarg, preamble)
 from dune.codegen.pdelab.geometry import world_dimension
 from dune.codegen.options import get_form_option
 import pymbolic.primitives as prim
@@ -74,3 +74,54 @@ def name_point_in_macro(point_in_micro, visitor):
     name = get_pymbolic_basename(point_in_micro) + "_macro"
     define_point_in_macro(name, point_in_micro, visitor)
     return name
+
+
+@preamble
+def define_container_alias(name, container, lfs, element, is_const):
+    """Return the C++ preamble statement that defines the pointer alias ``name``.
+
+    Calling this also registers ``name`` as a managed global argument with
+    shape ``(k,) * dim + (p + 1,) * dim``, where ``k`` is the form option
+    ``number_of_blocks``, ``p`` is ``element.degree()`` and ``dim`` is
+    ``world_dimension()``.
+
+    If ``is_const`` is true the alias is declared ``const auto`` and points at
+    ``container(lfs.name, 0)``; otherwise it is a plain ``auto`` pointing at
+    ``container.container()(lfs.name, 0)``.
+    """
+    k = get_form_option("number_of_blocks")
+    p = element.degree()
+    dim = world_dimension()
+    # Strides line up with the shape below: element_stride belongs to the
+    # k-sized axes, index_stride to the (p + 1)-sized axes.
+    element_stride = tuple(p * (p * k + 1)**i for i in range(0, dim))
+    index_stride = tuple((p * k + 1)**i for i in range(0, dim))
+    globalarg(name, shape=(k,) * dim + (p + 1,) * dim, strides=element_stride + index_stride, managed=True)
+    if is_const:
+        return "const auto {} = &{}({},0);".format(name, container, lfs.name)
+    return "auto {} = &{}.container()({},0);".format(name, container, lfs.name)
+
+
+def name_accumulation_alias(container, accumspace):
+    """Return the alias name for ``container`` restricted to ``accumspace.lfs``.
+
+    Two mutable aliases are defined as preambles: the returned name and a
+    second one carrying the ``_tail`` suffix. Only the first name is returned.
+    """
+    name = container + "_" + accumspace.lfs.name + "_alias"
+    name_tail = name + "_tail"
+
+    for alias in (name, name_tail):
+        define_container_alias(alias, container, accumspace.lfs, accumspace.element, is_const=False)
+    return name
+
+
+def name_container_alias(container, lfs, element):
+    """Return the name of the const alias for ``container`` and ``lfs``.
+
+    The alias definition itself is registered as a preamble.
+    """
+    name = container + "_" + lfs.name + "_alias"
+
+    define_container_alias(name, container, lfs, element, is_const=True)
+    return name
-- 
GitLab