Skip to content
Snippets Groups Projects
Commit 401b002d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix vectorized access to class member precomputations

parent 1529cb73
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ is filled and then does vector computations """ ...@@ -3,6 +3,7 @@ is filled and then does vector computations """
from dune.perftool.generation import (function_mangler, from dune.perftool.generation import (function_mangler,
include_file, include_file,
loopy_class_member,
) )
from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
...@@ -11,7 +12,7 @@ from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_v ...@@ -11,7 +12,7 @@ from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_v
) )
from dune.perftool.loopy.symbolic import substitute from dune.perftool.loopy.symbolic import substitute
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag, ceildiv
from dune.perftool.options import get_option from dune.perftool.options import get_option
from loopy.kernel.creation import parse_domains from loopy.kernel.creation import parse_domains
...@@ -19,6 +20,7 @@ from loopy.symbolic import pw_aff_to_expr ...@@ -19,6 +20,7 @@ from loopy.symbolic import pw_aff_to_expr
from loopy.match import Tagged from loopy.match import Tagged
from loopy.symbolic import DependencyMapper from loopy.symbolic import DependencyMapper
from pytools import product
import pymbolic.primitives as prim import pymbolic.primitives as prim
import loopy as lp import loopy as lp
...@@ -234,6 +236,23 @@ def collect_vector_data_rotate(knl): ...@@ -234,6 +236,23 @@ def collect_vector_data_rotate(knl):
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(1), prim.Variable(new_iname)), (vector_indices.get(1), prim.Variable(new_iname)),
) )
elif quantity in [a.name for a in knl.args]:
arg, = [a for a in knl.args if a.name == quantity]
tags = set(get_pymbolic_tag(expr) for expr in quantity_exprs)
if tags and tags.pop() == "operator_precomputed":
expr, = quantity_exprs
shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size)
name = loopy_class_member(quantity,
shape=shape,
dim_tags="c,vec",
potentially_vectorized=True,
classtag="operator",
dtype=np.float64,
)
knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)])
replacemap_vec[expr] = prim.Subscript(prim.Variable(name),
(vector_indices.get(1), prim.Variable(new_iname)),
)
new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns] new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns]
......
...@@ -29,6 +29,7 @@ from pymbolic.primitives import (Call, ...@@ -29,6 +29,7 @@ from pymbolic.primitives import (Call,
) )
import pymbolic.primitives as prim import pymbolic.primitives as prim
import loopy as lp
import numpy as np import numpy as np
...@@ -146,7 +147,7 @@ def quadrature_weight(): ...@@ -146,7 +147,7 @@ def quadrature_weight():
kernel="operator", kernel="operator",
) )
return prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in quadrature_inames())) return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames()))
def define_quadrature_position(name): def define_quadrature_position(name):
......
...@@ -25,7 +25,7 @@ from dune.perftool.pdelab.localoperator import (name_domain_field, ...@@ -25,7 +25,7 @@ from dune.perftool.pdelab.localoperator import (name_domain_field,
lop_template_range_field, lop_template_range_field,
) )
from dune.perftool.pdelab.quadrature import quadrature_order from dune.perftool.pdelab.quadrature import quadrature_order
from dune.perftool.tools import maybe_wrap_subscript from dune.perftool.tools import maybe_wrap_subscript, ceildiv
from loopy import CallMangleInfo from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier from loopy.symbolic import FunctionIdentifier
from loopy.types import NumpyType from loopy.types import NumpyType
...@@ -37,10 +37,6 @@ import loopy as lp ...@@ -37,10 +37,6 @@ import loopy as lp
import numpy as np import numpy as np
def ceildiv(a, b):
return -(-a // b)
class BasisTabulationMatrixBase(object): class BasisTabulationMatrixBase(object):
pass pass
......
...@@ -59,3 +59,7 @@ def get_pymbolic_tag(expr): ...@@ -59,3 +59,7 @@ def get_pymbolic_tag(expr):
return get_pymbolic_tag(expr.aggregate) return get_pymbolic_tag(expr.aggregate)
raise NotImplementedError("Cannot determine tag on {}".format(expr)) raise NotImplementedError("Cannot determine tag on {}".format(expr))
def ceildiv(a, b):
return -(-a // b)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment