Skip to content
Snippets Groups Projects
Commit 63886624 authored by Dominic Kempf
Browse files

Fix FastDG

parent a6fcd9ad
No related branches found
No related tags found
No related merge requests found
...@@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): ...@@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
tags=frozenset({"sumfact_stage{}".format(sf.stage)}), tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
) )
@property
def direct_input(self):
if get_option("fastdg"):
return self.coeff_func(self.restriction)
else:
return None
def _basis_functions_per_direction(element, component): def _basis_functions_per_direction(element, component):
"""Number of basis functions per direction of a given component of an element""" """Number of basis functions per direction of a given component of an element"""
......
...@@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf): ...@@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf):
), ),
})) }))
# Set up the input for stage 1 direct_input = sf.input.direct_input
if sf.stage == 1 and not get_option("fastdg"):
assert sf.input
# Set up the input for stage 1
if direct_input is None:
if sf.vectorized: if sf.vectorized:
for i, inputsf in enumerate(sf.kernels): for i, inputsf in enumerate(sf.kernels):
inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep)) inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
...@@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf): ...@@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf):
insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))})) insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
# Construct the direct_input for the FastDG case
direct_input = None
if get_option('fastdg') and sf.stage == 1:
direct_input = sf.input.coeff_func(sf.input.restriction)
direct_output = None direct_output = None
if get_option('fastdg') and sf.stage == 3: if get_option('fastdg') and sf.stage == 3:
direct_output = sf.accumvar + ".data()" direct_output = sf.accumvar + ".data()"
......
...@@ -16,8 +16,11 @@ import inspect ...@@ -16,8 +16,11 @@ import inspect
class SumfactKernelInputBase(object): class SumfactKernelInputBase(object):
@property @property
def flat_shape(self): def direct_input(self):
return False return None
def realize(self, sf, i, dep):
pass
class SumfactKernelBase(object): class SumfactKernelBase(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment