From 401b002da159f248996a5755c4dc32ae2f6afbca Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 26 Apr 2017 13:29:24 +0200 Subject: [PATCH] Fix vectorized access to class member precomputations --- .../loopy/transformations/collect_rotate.py | 21 ++++++++++++++++++- python/dune/perftool/sumfact/quadrature.py | 3 ++- python/dune/perftool/sumfact/tabulation.py | 6 +----- python/dune/perftool/tools.py | 4 ++++ 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index b64503b4..20bc4130 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -3,6 +3,7 @@ is filled and then does vector computations """ from dune.perftool.generation import (function_mangler, include_file, + loopy_class_member, ) from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, @@ -11,7 +12,7 @@ from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_v ) from dune.perftool.loopy.symbolic import substitute from dune.perftool.sumfact.quadrature import quadrature_inames -from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag +from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag, ceildiv from dune.perftool.options import get_option from loopy.kernel.creation import parse_domains @@ -19,6 +20,7 @@ from loopy.symbolic import pw_aff_to_expr from loopy.match import Tagged from loopy.symbolic import DependencyMapper +from pytools import product import pymbolic.primitives as prim import loopy as lp @@ -234,6 +236,23 @@ def collect_vector_data_rotate(knl): replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (vector_indices.get(1), prim.Variable(new_iname)), ) + elif quantity in [a.name for a in knl.args]: + arg, = [a for a in knl.args if a.name == quantity] + tags = set(get_pymbolic_tag(expr) for expr in quantity_exprs) + if tags and tags.pop() == "operator_precomputed": + expr, = quantity_exprs + shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size) + name = loopy_class_member(quantity, + shape=shape, + dim_tags="c,vec", + potentially_vectorized=True, + classtag="operator", + dtype=np.float64, + ) + knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)]) + replacemap_vec[expr] = prim.Subscript(prim.Variable(name), + (vector_indices.get(1), prim.Variable(new_iname)), + ) new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns] diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py index e3d58fba..0cff5516 100644 --- a/python/dune/perftool/sumfact/quadrature.py +++ b/python/dune/perftool/sumfact/quadrature.py @@ -29,6 +29,7 @@ from pymbolic.primitives import (Call, ) import pymbolic.primitives as prim +import loopy as lp import numpy as np @@ -146,7 +147,7 @@ def quadrature_weight(): kernel="operator", ) - return prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in quadrature_inames())) + return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames())) def define_quadrature_position(name): diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index 9ccd5891..c0405c6a 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -25,7 +25,7 @@ from dune.perftool.pdelab.localoperator import (name_domain_field, lop_template_range_field, ) from dune.perftool.pdelab.quadrature import quadrature_order -from dune.perftool.tools import maybe_wrap_subscript +from dune.perftool.tools import maybe_wrap_subscript, ceildiv from loopy import CallMangleInfo from loopy.symbolic import FunctionIdentifier from loopy.types import NumpyType @@ -37,10 +37,6 @@ import loopy as lp import numpy as np -def ceildiv(a, b): - return -(-a // b) - - class BasisTabulationMatrixBase(object): pass diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py index e6259e44..a0dfaf29 100644 --- a/python/dune/perftool/tools.py +++ b/python/dune/perftool/tools.py @@ -59,3 +59,7 @@ def get_pymbolic_tag(expr): return get_pymbolic_tag(expr.aggregate) raise NotImplementedError("Cannot determine tag on {}".format(expr)) + + +def ceildiv(a, b): + return -(-a // b) -- GitLab