Skip to content
Snippets Groups Projects
Commit e47c7a01 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Update with bounds in continue statement

parent 40d9afad
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ def get_temporary_name():
@generator_factory(item_tags=("temporary",), cache_key_generator=lambda n, **kw: n)
def temporary_variable(name, **kwargs):
from dune.perftool.loopy.temporary import DuneTemporaryVariable
return DuneTemporaryVariable(name, **kwargs)
return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.LOCAL, **kwargs)
# Now define generators for instructions. To ease dependency handling of instructions
......
......@@ -3,6 +3,8 @@
from dune.perftool.loopy.vcl import VCLLoad, VCLStore
from dune.perftool.tools import get_pymbolic_basename
from loopy.symbolic import pw_aff_to_expr
from pymbolic.mapper.dependency import DependencyMapper
from pymbolic.mapper.substitutor import substitute
......@@ -97,18 +99,9 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
# Assert some assumptions on the instructions
#
# All instructions within the given inames are either to be vectorized
# or write a dependent quantity
for insn in other_insns + dep_insns:
len(insn.within_inames.intersection(inames) == 0)
# An instruction occurs in but one of these groups:
assert len(set(insns + write_insns + dep_insns + other_insns)) == len(insns + write_insns + dep_insns + other_insns)
# All the target and write instructions are Assignments
for insn in insns + write_insns:
assert isinstance(insn, lp.Assignment)
# Analyse the inames of the given instructions and identify inames
# that they all have in common. Those inames will also be iname dependencies
# of inserted instruction.
......@@ -122,13 +115,12 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
# Insert a flat consecutive counter 'total_index'
temporaries['total_index'] = lp.TemporaryVariable('total_index', # name
dtype=np.int32,
scope=lp.temp_var_scope.LOCAL,
)
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
depends_on=frozenset(i.id for i in dep_insns),
depends_on_is_final=True,
id="assign_total_index",
))
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
......@@ -143,13 +135,12 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
# Insert a rotating index, that counts 0 , .. , vecsize - 1
temporaries['rotate_index'] = lp.TemporaryVariable('rotate_index', # name
dtype=np.int32,
scope=lp.temp_var_scope.LOCAL,
)
new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
depends_on=frozenset(i.id for i in dep_insns),
depends_on_is_final=True,
id="assign_rotate_index",
))
new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee
......@@ -177,6 +168,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
shape=(vec_size,),
dim_tags="c",
base_storage=quantity + '_base_storage',
scope=lp.temp_var_scope.LOCAL,
)
vecname = quantity + '_buffered_vec'
......@@ -185,20 +177,32 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
shape=(vec_size,),
dim_tags="vec",
base_storage=quantity + '_base_storage',
scope=lp.temp_var_scope.LOCAL,
)
replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
replacemap_vec[quantity_expr] = prim.Variable(vecname)
for insn in write_insns:
new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
)
)
if isinstance(insn, lp.Assignment):
new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
)
)
if isinstance(insn, lp.CInstruction):
# TODO: What do we do about CInstructions?
# Example: detjac = ...
new_insns.append(insn)
# Determine the condition for the continue statement
upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0)
check = prim.LogicalAnd((rotate_check, total_check))
# Insert the 'continue' statement
new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to
"continue;", # the code
predicates=frozenset({"rotate_index != 0", "blubb"}),
predicates=frozenset({check}),
depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset([i.id for i in write_insns])),
depends_on_is_final=True,
within_inames=common_inames.union(inames),
......@@ -234,6 +238,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
shape=(vec_size,),
dim_tags="vec",
base_storage="{}_base".format(basename),
scope=lp.temp_var_scope.LOCAL,
)
new_insns.append(insn.copy(assignee=prim.Variable(name),
expression=substitute(insn.expression, variable_assignments=replacemap_vec),
......
"""
Our extensions to the loopy type system
"""
from dune.perftool.generation import function_mangler
from dune.perftool.generation import (function_mangler,
include_file,
)
from loopy.symbolic import FunctionIdentifier
from loopy.types import NumpyType
......@@ -82,4 +84,5 @@ class VCLStore(FunctionIdentifier):
@function_mangler
def vcl_mangler(target, func, dtypes):
if isinstance(func, (VCLLoad, VCLStore)):
include_file("dune/perftool/vectorclass/vectorclass.h", filetag="operatorfile")
return CallMangleInfo(func.name, (), (NumpyType(np.int32),))
......@@ -200,6 +200,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
transform(nest_quadrature_loops, visitor.inames)
#TODO!!!
from dune.perftool.loopy.collectvector import collect_vector_data
transform(collect_vector_data, [contrib_dep], quadrature_inames())
def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional_inames=frozenset({})):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment