Skip to content
Snippets Groups Projects
Commit 896cb5e4 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add a temporary solution for extraction of vectorized output properties

parent 7c6bbade
No related branches found
No related tags found
No related merge requests found
......@@ -193,37 +193,51 @@ def _realize_sum_factorization_kernel(sf):
# of the Sumfactorization into some global data structure.
if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
ft = get_global_context_value("form_type")
if sf.test_element_index is None:
direct_output = "{}_access".format(sf.accumvar)
# TODO This one will break super-hard!
if sf.vectorized:
test_element = sf.output.outputs[0].test_element
test_element_index = sf.output.outputs[0].test_element_index
trial_element = sf.output.outputs[0].trial_element
trial_element_index = sf.output.outputs[0].trial_element_index
accumvar = sf.output.outputs[0].accumvar
else:
direct_output = "{}_access_comp{}".format(sf.accumvar, sf.test_element_index)
test_element = sf.output.test_element
test_element_index = sf.output.test_element_index
trial_element = sf.output.trial_element
trial_element_index = sf.output.trial_element_index
accumvar = sf.output.accumvar
if test_element_index is None:
direct_output = "{}_access".format(accumvar)
else:
direct_output = "{}_access_comp{}".format(accumvar, test_element_index)
if ft == 'residual' or ft == 'jacobian_apply':
globalarg(direct_output,
shape=output_shape,
dim_tags=novec_ftags,
offset=_dof_offset(sf.test_element, sf.test_element_index),
offset=_dof_offset(test_element, test_element_index),
)
alias_data_array(direct_output, sf.accumvar)
alias_data_array(direct_output, accumvar)
assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
else:
assert ft == 'jacobian'
direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
element = sf.trial_element
direct_output = "{}x{}".format(direct_output, trial_element_index)
rowsize = sum(tuple(s for s in _local_sizes(trial_element)))
element = trial_element
if isinstance(element, MixedElement):
element = element.extract_component(sf.trial_element_index)[1]
element = element.extract_component(trial_element_index)[1]
other_shape = tuple(element.degree() + 1 for e in range(sf.length))
from pytools import product
manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length))
dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides))
globalarg(direct_output,
shape=other_shape + output_shape,
offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index),
offset=rowsize * _dof_offset(test_element, test_element_index) + _dof_offset(trial_element, trial_element_index),
dim_tags=dim_tags,
)
alias_data_array(direct_output, sf.accumvar)
alias_data_array(direct_output, accumvar)
# TODO: It is at least questionable whether using the *order* of the inames in here
# for indexing is a good idea. Then again, it is hard to find an alternative.
_ansatz_inames = tuple(prim.Variable(i) for i in sf.within_inames)
......
......@@ -30,22 +30,6 @@ class SumfactKernelInputBase(object):
raise NotImplementedError
class SumfactKernelOutputBase(object):
    """Default behaviour for sum factorization kernel outputs.

    Every member here is a fallback: no direct output, no extra inames,
    and a ``realize`` that hands its dependency back unchanged.
    """

    @property
    def direct_output_is_possible(self):
        """Whether direct output is possible; always ``False`` here."""
        return False

    @property
    def within_inames(self):
        """Inames contributed by this output; an empty tuple here."""
        return tuple()

    def realize(self, sf, dep):
        """Return ``dep`` untouched; ``sf`` is not used by this default."""
        return dep

    def realize_direct(self):
        """Not provided by the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class VectorSumfactKernelInput(SumfactKernelInputBase):
def __init__(self, inputs):
assert(isinstance(inputs, tuple))
......@@ -88,6 +72,27 @@ class VectorSumfactKernelInput(SumfactKernelInputBase):
raise NotImplementedError("SIMD loads from scalars not implemented!")
class SumfactKernelOutputBase(object):
    """Default behaviour for sum factorization kernel outputs.

    Every member here is a fallback: no direct output, no extra inames,
    and a ``realize`` that hands its dependency back unchanged.
    """

    @property
    def direct_output_is_possible(self):
        """Whether direct output is possible; always ``False`` here."""
        return False

    @property
    def within_inames(self):
        """Inames contributed by this output; an empty tuple here."""
        return tuple()

    def realize(self, sf, dep):
        """Return ``dep`` untouched; ``sf`` is not used by this default."""
        return dep

    def realize_direct(self):
        """Not provided by the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
class VectorSumfactKernelOutput(SumfactKernelOutputBase):
    """Container bundling the output objects of several kernels."""

    def __init__(self, outputs):
        """Store ``outputs`` unmodified as ``self.outputs``.

        No validation is performed on ``outputs``.
        """
        self.outputs = outputs
class SumfactKernelBase(object):
    """Empty base class; it defines no attributes or methods of its own."""
......@@ -508,31 +513,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def within_inames(self):
    """Return the ``within_inames`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.within_inames
@property
def test_element(self):
    """Return the ``test_element`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.test_element
@property
def test_element_index(self):
    """Return the ``test_element_index`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.test_element_index
@property
def trial_element(self):
    """Return the ``trial_element`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.trial_element
@property
def trial_element_index(self):
    """Return the ``trial_element_index`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.trial_element_index
@property
def predicates(self):
    """Return the ``predicates`` of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.predicates
@property
def accumvar(self):
    """Return the ``accumvar`` shared by all kernels in ``self.kernels``.

    Raises:
        AssertionError: if the kernels do not all report exactly one
            common ``accumvar`` value.
    """
    distinct_accumvars = {kernel.accumvar for kernel in self.kernels}
    assert len(distinct_accumvars) == 1
    return self.kernels[0].accumvar
@property
def transposed(self):
    """Return the ``transposed`` attribute of the first kernel in ``self.kernels``.

    The remaining kernels are not consulted.
    """
    first_kernel = self.kernels[0]
    return first_kernel.transposed
......@@ -556,6 +540,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def input(self):
    """Wrap the ``input`` of every kernel in a ``VectorSumfactKernelInput``.

    The per-kernel inputs are collected into a tuple in kernel order.
    """
    per_kernel_inputs = tuple(kernel.input for kernel in self.kernels)
    return VectorSumfactKernelInput(per_kernel_inputs)
@property
def output(self):
    """Wrap the ``output`` of every kernel in a ``VectorSumfactKernelOutput``.

    The per-kernel outputs are collected into a tuple in kernel order.
    """
    per_kernel_outputs = tuple(kernel.output for kernel in self.kernels)
    return VectorSumfactKernelOutput(per_kernel_outputs)
@property
def cache_key(self):
    """Return a two-element tuple identifying this vectorized kernel.

    The first element is the tuple of the kernels' ``cache_key`` values in
    kernel order; the second is ``self.buffer``.
    """
    kernel_keys = tuple(kernel.cache_key for kernel in self.kernels)
    return (kernel_keys, self.buffer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment