From 896cb5e429cad84b1cec870cb815106c20abf6a4 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 15 Feb 2018 09:15:35 +0100
Subject: [PATCH] Add a temporary solution for extraction of vectorized output
 properties

---
 python/dune/perftool/sumfact/realization.py | 36 ++++++++----
 python/dune/perftool/sumfact/symbolic.py    | 62 +++++++++------------
 2 files changed, 50 insertions(+), 48 deletions(-)

diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index cc6704af..20949b20 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -193,37 +193,51 @@ def _realize_sum_factorization_kernel(sf):
         # of the Sumfactorization into some global data structure.
         if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
             ft = get_global_context_value("form_type")
-            if sf.test_element_index is None:
-                direct_output = "{}_access".format(sf.accumvar)
+            # TODO This one will break super-hard!
+            if sf.vectorized:
+                test_element = sf.output.outputs[0].test_element
+                test_element_index = sf.output.outputs[0].test_element_index
+                trial_element = sf.output.outputs[0].trial_element
+                trial_element_index = sf.output.outputs[0].trial_element_index
+                accumvar = sf.output.outputs[0].accumvar
             else:
-                direct_output = "{}_access_comp{}".format(sf.accumvar, sf.test_element_index)
+                test_element = sf.output.test_element
+                test_element_index = sf.output.test_element_index
+                trial_element = sf.output.trial_element
+                trial_element_index = sf.output.trial_element_index
+                accumvar = sf.output.accumvar
+
+            if test_element_index is None:
+                direct_output = "{}_access".format(accumvar)
+            else:
+                direct_output = "{}_access_comp{}".format(accumvar, test_element_index)
             if ft == 'residual' or ft == 'jacobian_apply':
                 globalarg(direct_output,
                           shape=output_shape,
                           dim_tags=novec_ftags,
-                          offset=_dof_offset(sf.test_element, sf.test_element_index),
+                          offset=_dof_offset(test_element, test_element_index),
                           )
-                alias_data_array(direct_output, sf.accumvar)
+                alias_data_array(direct_output, accumvar)
 
                 assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
             else:
                 assert ft == 'jacobian'
 
-                direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
-                rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
-                element = sf.trial_element
+                direct_output = "{}x{}".format(direct_output, trial_element_index)
+                rowsize = sum(tuple(s for s in _local_sizes(trial_element)))
+                element = trial_element
                 if isinstance(element, MixedElement):
-                    element = element.extract_component(sf.trial_element_index)[1]
+                    element = element.extract_component(trial_element_index)[1]
                 other_shape = tuple(element.degree() + 1 for e in range(sf.length))
                 from pytools import product
                 manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length))
                 dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides))
                 globalarg(direct_output,
                           shape=other_shape + output_shape,
-                          offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index),
+                          offset=rowsize * _dof_offset(test_element, test_element_index) + _dof_offset(trial_element, trial_element_index),
                           dim_tags=dim_tags,
                           )
-                alias_data_array(direct_output, sf.accumvar)
+                alias_data_array(direct_output, accumvar)
                 # TODO: It is at least questionnable, whether using the *order* of the inames in here
                 #       for indexing is a good idea. Then again, it is hard to find an alternative.
                 _ansatz_inames = tuple(prim.Variable(i) for i in sf.within_inames)
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index eb9df931..322806fa 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -30,22 +30,6 @@ class SumfactKernelInputBase(object):
         raise NotImplementedError
 
 
-class SumfactKernelOutputBase(object):
-    @property
-    def direct_output_is_possible(self):
-        return False
-
-    @property
-    def within_inames(self):
-        return ()
-
-    def realize(self, sf, dep):
-        return dep
-
-    def realize_direct(self):
-        raise NotImplementedError
-
-
 class VectorSumfactKernelInput(SumfactKernelInputBase):
     def __init__(self, inputs):
         assert(isinstance(inputs, tuple))
@@ -88,6 +72,27 @@ class VectorSumfactKernelInput(SumfactKernelInputBase):
             raise NotImplementedError("SIMD loads from scalars not implemented!")
 
 
+class SumfactKernelOutputBase(object):
+    @property
+    def direct_output_is_possible(self):
+        return False
+
+    @property
+    def within_inames(self):
+        return ()
+
+    def realize(self, sf, dep):
+        return dep
+
+    def realize_direct(self):
+        raise NotImplementedError
+
+
+class VectorSumfactKernelOutput(SumfactKernelOutputBase):
+    def __init__(self, outputs):
+        self.outputs = outputs
+
+
 class SumfactKernelBase(object):
     pass
 
@@ -508,31 +513,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     def within_inames(self):
         return self.kernels[0].within_inames
 
-    @property
-    def test_element(self):
-        return self.kernels[0].test_element
-
-    @property
-    def test_element_index(self):
-        return self.kernels[0].test_element_index
-
-    @property
-    def trial_element(self):
-        return self.kernels[0].trial_element
-
-    @property
-    def trial_element_index(self):
-        return self.kernels[0].trial_element_index
-
     @property
     def predicates(self):
         return self.kernels[0].predicates
 
-    @property
-    def accumvar(self):
-        assert len(set(k.accumvar for k in self.kernels)) == 1
-        return self.kernels[0].accumvar
-
     @property
     def transposed(self):
         return self.kernels[0].transposed
@@ -556,6 +540,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     def input(self):
         return VectorSumfactKernelInput(tuple(k.input for k in self.kernels))
 
+    @property
+    def output(self):
+        return VectorSumfactKernelOutput(tuple(k.output for k in self.kernels))
+
     @property
     def cache_key(self):
         return (tuple(k.cache_key for k in self.kernels), self.buffer)
-- 
GitLab