From 401b002da159f248996a5755c4dc32ae2f6afbca Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 26 Apr 2017 13:29:24 +0200
Subject: [PATCH] Fix vectorized access to class member precomputations

---
 .../loopy/transformations/collect_rotate.py   | 21 ++++++++++++++++++-
 python/dune/perftool/sumfact/quadrature.py    |  3 ++-
 python/dune/perftool/sumfact/tabulation.py    |  6 +-----
 python/dune/perftool/tools.py                 |  4 ++++
 4 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index b64503b4..20bc4130 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -3,6 +3,7 @@ is filled and then does vector computations """
 
 from dune.perftool.generation import (function_mangler,
                                       include_file,
+                                      loopy_class_member,
                                       )
 from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
 from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
@@ -11,7 +12,7 @@ from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_v
                                                             )
 from dune.perftool.loopy.symbolic import substitute
 from dune.perftool.sumfact.quadrature import quadrature_inames
-from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag
+from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag, ceildiv
 from dune.perftool.options import get_option
 
 from loopy.kernel.creation import parse_domains
@@ -19,6 +20,7 @@ from loopy.symbolic import pw_aff_to_expr
 from loopy.match import Tagged
 
 from loopy.symbolic import DependencyMapper
+from pytools import product
 
 import pymbolic.primitives as prim
 import loopy as lp
@@ -234,6 +236,23 @@ def collect_vector_data_rotate(knl):
                 replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
                                                       (vector_indices.get(1), prim.Variable(new_iname)),
                                                       )
+        elif quantity in [a.name for a in knl.args]:
+            arg, = [a for a in knl.args if a.name == quantity]
+            tags = set(get_pymbolic_tag(expr) for expr in quantity_exprs)
+            if tags and tags.pop() == "operator_precomputed":
+                expr, = quantity_exprs
+                shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size)
+                name = loopy_class_member(quantity,
+                                          shape=shape,
+                                          dim_tags="c,vec",
+                                          potentially_vectorized=True,
+                                          classtag="operator",
+                                          dtype=np.float64,
+                                          )
+                knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)])
+                replacemap_vec[expr] = prim.Subscript(prim.Variable(name),
+                                                      (vector_indices.get(1), prim.Variable(new_iname)),
+                                                      )
 
     new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns]
 
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index e3d58fba..0cff5516 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -29,6 +29,7 @@ from pymbolic.primitives import (Call,
                                  )
 
 import pymbolic.primitives as prim
+import loopy as lp
 import numpy as np
 
 
@@ -146,7 +147,7 @@ def quadrature_weight():
                 kernel="operator",
                 )
 
-    return prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in quadrature_inames()))
+    return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames()))
 
 
 def define_quadrature_position(name):
diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py
index 9ccd5891..c0405c6a 100644
--- a/python/dune/perftool/sumfact/tabulation.py
+++ b/python/dune/perftool/sumfact/tabulation.py
@@ -25,7 +25,7 @@ from dune.perftool.pdelab.localoperator import (name_domain_field,
                                                 lop_template_range_field,
                                                 )
 from dune.perftool.pdelab.quadrature import quadrature_order
-from dune.perftool.tools import maybe_wrap_subscript
+from dune.perftool.tools import maybe_wrap_subscript, ceildiv
 from loopy import CallMangleInfo
 from loopy.symbolic import FunctionIdentifier
 from loopy.types import NumpyType
@@ -37,10 +37,6 @@ import loopy as lp
 import numpy as np
 
 
-def ceildiv(a, b):
-    return -(-a // b)
-
-
 class BasisTabulationMatrixBase(object):
     pass
 
diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py
index e6259e44..a0dfaf29 100644
--- a/python/dune/perftool/tools.py
+++ b/python/dune/perftool/tools.py
@@ -59,3 +59,7 @@ def get_pymbolic_tag(expr):
         return get_pymbolic_tag(expr.aggregate)
 
     raise NotImplementedError("Cannot determine tag on {}".format(expr))
+
+
+def ceildiv(a, b):  # ceiling division: quotient of a by b, rounded up
+    return -(-a // b)  # `//` rounds down, so negate, floor-divide, negate back to round up
-- 
GitLab