Skip to content
Snippets Groups Projects
Commit e4467b6f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

fixup

parent 0eb8cbba
No related branches found
No related tags found
No related merge requests found
...@@ -33,30 +33,14 @@ import numpy as np ...@@ -33,30 +33,14 @@ import numpy as np
import pymbolic.primitives as prim import pymbolic.primitives as prim
@generator_factory(item_tags=("sumfactkernel",), def realize_sum_factorization_kernel(sf, **kwargs):
context_tags=("kernel",), insn_dep = kwargs.pop('insn_dep', frozenset())
cache_key_generator=lambda s, **kw: s.cache_key) if not get_global_context_value("dry_run", False):
def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_input=None, direct_output=None): insn_dep = _realize_input(sf, insn_dep)
# Unify the insn_dep parameter to be a frozenset return _realize_sum_factorization_kernel(sf, insn_dep=insn_dep, **kwargs)
if isinstance(insn_dep, str):
insn_dep = frozenset({insn_dep})
assert isinstance(insn_dep, frozenset)
# Get the vectorization info. During dry run, this is a now op
# sf = attach_vectorization_info(sf)
if get_global_context_value("dry_run", False):
# During the dry run, we just return the kernel as passed into this
# function. After the dry run, it can be used to attach information
# about vectorization.
return sf, insn_dep
# Get the instruction dependencies of the sumfact kernel. This variable will be
# updated throughout this function.
insn_dep = insn_dep.union(sf.insn_dep)
# Define some helper constructs that make it easier to write generic code later
vecindex = () if sf.index is None else (sf.index,)
def _realize_input(sf, insn_dep):
# Set up the input for stage 1 # Set up the input for stage 1
if sf.stage == 1 and not get_option("fastdg"): if sf.stage == 1 and not get_option("fastdg"):
assert sf.coeff_func assert sf.coeff_func
...@@ -71,10 +55,30 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di ...@@ -71,10 +55,30 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis") basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = sf.coeff_func(sf.restriction) container = sf.coeff_func(sf.restriction)
coeff = pymbolic_coefficient(container, lfs, basisiname) coeff = pymbolic_coefficient(container, lfs, basisiname)
vecindex = () if sf.index is None else (sf.index,)
assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + vecindex) assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + vecindex)
insn_dep = instruction(assignee=assignee, insn_dep = instruction(assignee=assignee,
expression=coeff, expression=coeff,
depends_on = insn_dep,
) )
return insn_dep
@generator_factory(item_tags=("sumfactkernel",),
context_tags=("kernel",),
cache_key_generator=lambda s, **kw: s.cache_key)
def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_input=None, direct_output=None):
# Get the vectorization info. During dry run, this is a now op
# sf = attach_vectorization_info(sf)
if get_global_context_value("dry_run", False):
# During the dry run, we just return the kernel as passed into this
# function. After the dry run, it can be used to attach information
# about vectorization.
return sf, insn_dep
# Get the instruction dependencies of the sumfact kernel. This variable will be
# updated throughout this function.
insn_dep = insn_dep.union(sf.insn_dep)
# Prepare some dim_tags/shapes for later use # Prepare some dim_tags/shapes for later use
ftags = ",".join(["f"] * sf.length) ftags = ",".join(["f"] * sf.length)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment