Skip to content
Snippets Groups Projects
Commit 37a95a8e authored by Marcel Koch's avatar Marcel Koch
Browse files

use mixin to get quadrature inames

parent 2449e1d8
No related branches found
No related tags found
No related merge requests found
...@@ -307,7 +307,7 @@ def name_jacobian_inverse_transposed(restriction): ...@@ -307,7 +307,7 @@ def name_jacobian_inverse_transposed(restriction):
# translate a point in the micro element into macro coordinates # translate a point in the micro element into macro coordinates
def define_point_in_macro(name, point_in_micro): def define_point_in_macro(name, point_in_micro, visitor):
dim = local_dimension() dim = local_dimension()
if get_form_option('vectorization_blockstructured'): if get_form_option('vectorization_blockstructured'):
temporary_variable(name, shape=(dim,), managed=True) temporary_variable(name, shape=(dim,), managed=True)
...@@ -327,17 +327,17 @@ def define_point_in_macro(name, point_in_micro): ...@@ -327,17 +327,17 @@ def define_point_in_macro(name, point_in_micro):
# TODO relax within inames # TODO relax within inames
instruction(assignee=prim.Subscript(prim.Variable(name), (i,)), instruction(assignee=prim.Subscript(prim.Variable(name), (i,)),
expression=expr, expression=expr,
within_inames=frozenset(subelem_inames + get_backend(interface="quad_inames")()), within_inames=frozenset(subelem_inames + visitor.quadrature_inames()),
tags=frozenset({subelem_inames[i]}) tags=frozenset({subelem_inames[i]})
) )
# TODO add subelem inames if this function gets called # TODO add subelem inames if this function gets called
# TODO change input parameter to string # TODO change input parameter to string
def name_point_in_macro(point_in_micro): def name_point_in_macro(point_in_micro, visitor):
assert isinstance(point_in_micro, prim.Expression) assert isinstance(point_in_micro, prim.Expression)
name = get_pymbolic_basename(point_in_micro) + "_macro" name = get_pymbolic_basename(point_in_micro) + "_macro"
define_point_in_macro(name, point_in_micro) define_point_in_macro(name, point_in_micro, visitor)
return name return name
......
...@@ -8,9 +8,13 @@ import pymbolic.primitives as prim ...@@ -8,9 +8,13 @@ import pymbolic.primitives as prim
@quadrature_mixin("blockstructured") @quadrature_mixin("blockstructured")
class BlockstructuredQuadratureMixin(GenericQuadratureMixin): class BlockstructuredQuadratureMixin(GenericQuadratureMixin):
def quadrature_position(self): def quadrature_position(self, index=None):
original = GenericQuadratureMixin.quadrature_position(self) original = GenericQuadratureMixin.quadrature_position(self)
return prim.Variable(name_point_in_macro(original)) qp = prim.Variable(name_point_in_macro(original, self), )
if index is not None:
return prim.Subscript(qp, (index,))
else:
return qp
# #
# @backend(interface="quad_pos", name='blockstructured') # @backend(interface="quad_pos", name='blockstructured')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment