diff --git a/python/dune/codegen/blockstructured/geometry.py b/python/dune/codegen/blockstructured/geometry.py index f19202ec8ed3a5b527302edab638d841b336663e..a2bfd1de9bed8363e693b44e4270f4de5916db90 100644 --- a/python/dune/codegen/blockstructured/geometry.py +++ b/python/dune/codegen/blockstructured/geometry.py @@ -307,7 +307,7 @@ def name_jacobian_inverse_transposed(restriction): # translate a point in the micro element into macro coordinates -def define_point_in_macro(name, point_in_micro): +def define_point_in_macro(name, point_in_micro, visitor): dim = local_dimension() if get_form_option('vectorization_blockstructured'): temporary_variable(name, shape=(dim,), managed=True) @@ -327,17 +327,17 @@ def define_point_in_macro(name, point_in_micro): # TODO relax within inames instruction(assignee=prim.Subscript(prim.Variable(name), (i,)), expression=expr, - within_inames=frozenset(subelem_inames + get_backend(interface="quad_inames")()), + within_inames=frozenset(subelem_inames + visitor.quadrature_inames()), tags=frozenset({subelem_inames[i]}) ) # TODO add subelem inames if this function gets called # TODO change input parameter to string -def name_point_in_macro(point_in_micro): +def name_point_in_macro(point_in_micro, visitor): assert isinstance(point_in_micro, prim.Expression) name = get_pymbolic_basename(point_in_micro) + "_macro" - define_point_in_macro(name, point_in_micro) + define_point_in_macro(name, point_in_micro, visitor) return name diff --git a/python/dune/codegen/blockstructured/quadrature.py b/python/dune/codegen/blockstructured/quadrature.py index d2875435fb72345a3bd87ee7e23983edbd9343fe..f69ddae816d876fd69a45a4e80844d58bee43bf6 100644 --- a/python/dune/codegen/blockstructured/quadrature.py +++ b/python/dune/codegen/blockstructured/quadrature.py @@ -8,9 +8,13 @@ import pymbolic.primitives as prim @quadrature_mixin("blockstructured") class BlockstructuredQuadratureMixin(GenericQuadratureMixin): - def quadrature_position(self): + def quadrature_position(self, index=None): original = GenericQuadratureMixin.quadrature_position(self) - return prim.Variable(name_point_in_macro(original)) + qp = prim.Variable(name_point_in_macro(original, self), ) + if index is not None: + return prim.Subscript(qp, (index,)) + else: + return qp # # @backend(interface="quad_pos", name='blockstructured')