Skip to content
Snippets Groups Projects
Commit 25807a97 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Shrink the quadrature loop on intersection by one iname

parent 31f426a7
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix, ...@@ -19,7 +19,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
name_theta, name_theta,
quadrature_points_per_direction, quadrature_points_per_direction,
) )
from dune.perftool.sumfact.sumfact import (setup_theta, from dune.perftool.sumfact.sumfact import (get_facedir,
setup_theta,
SumfactKernel, SumfactKernel,
sumfact_iname, sumfact_iname,
sum_factorization_kernel, sum_factorization_kernel,
...@@ -62,7 +63,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -62,7 +63,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
insn_dep = None insn_dep = None
for i in range(dim): for i in range(dim):
# Construct the matrix sequence for this sum factorization # Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(derivative=i) a_matrices = construct_amatrix_sequence(derivative=i, face=get_facedir(restriction))
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
...@@ -126,7 +127,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor): ...@@ -126,7 +127,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
dim = world_dimension() dim = world_dimension()
# Construct the matrix sequence for this sum factorization # Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence() a_matrices = construct_amatrix_sequence(face=get_facedir(restriction))
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
...@@ -153,6 +154,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor): ...@@ -153,6 +154,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
1, 1,
preferred_position=None, preferred_position=None,
insn_dep=frozenset({Writes(input)}), insn_dep=frozenset({Writes(input)}),
outshape=tuple(mat.rows for mat in a_matrices if mat.rows != 1),
) )
if index: if index:
...@@ -230,9 +232,7 @@ def evaluate_reference_gradient(element, name, restriction): ...@@ -230,9 +232,7 @@ def evaluate_reference_gradient(element, name, restriction):
calls[i] = prim.Subscript(prim.Variable(dtheta), (prim.Variable(quad_inames[i]), prim.Variable(inames[i]))) calls[i] = prim.Subscript(prim.Variable(dtheta), (prim.Variable(quad_inames[i]), prim.Variable(inames[i])))
calls = tuple(calls) calls = tuple(calls)
# assignee = prim.Subscript(prim.Variable(name), tuple(prim.Variable(0)))
assignee = prim.Subscript(prim.Variable(name), (i,)) assignee = prim.Subscript(prim.Variable(name), (i,))
# assignee = prim.Variable(name)
expression = prim.Product(calls) expression = prim.Product(calls)
instruction(assignee=assignee, instruction(assignee=assignee,
......
...@@ -71,15 +71,15 @@ def pymbolic_base_weight(): ...@@ -71,15 +71,15 @@ def pymbolic_base_weight():
@iname @iname
def sumfact_quad_iname(d, context): def sumfact_quad_iname(d, bound):
name = "quad_{}_{}".format(context, d) name = "quad_{}".format(d)
domain(name, quadrature_points_per_direction()) domain(name, quadrature_points_per_direction())
return name return name
@backend(interface="quad_inames", name="sumfact") @backend(interface="quad_inames", name="sumfact")
def quadrature_inames(context=''): def quadrature_inames():
return tuple(sumfact_quad_iname(d, context) for d in range(local_dimension())) return tuple(sumfact_quad_iname(d, quadrature_points_per_direction()) for d in range(local_dimension()))
def define_recursive_quadrature_weight(name, dir): def define_recursive_quadrature_weight(name, dir):
......
...@@ -34,6 +34,9 @@ from dune.perftool.pdelab.restriction import restricted_name ...@@ -34,6 +34,9 @@ from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab.spaces import (name_lfs, from dune.perftool.pdelab.spaces import (name_lfs,
name_lfs_bound, name_lfs_bound,
) )
from dune.perftool.pdelab.geometry import (local_dimension,
world_dimension,
)
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
LargeAMatrix, LargeAMatrix,
quadrature_points_per_direction, quadrature_points_per_direction,
...@@ -59,6 +62,17 @@ import pymbolic.primitives as prim ...@@ -59,6 +62,17 @@ import pymbolic.primitives as prim
from pytools import product from pytools import product
def get_facedir(restriction):
from dune.perftool.pdelab.restriction import Restriction
if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facedir_s")
if restriction == Restriction.POSITIVE:
return get_global_context_value("facedir_n")
if restriction == Restriction.NONE:
return None
assert False
@iname @iname
def _sumfact_iname(bound, _type, count): def _sumfact_iname(bound, _type, count):
name = "sf_{}_{}".format(_type, str(count)) name = "sf_{}_{}".format(_type, str(count))
...@@ -101,9 +115,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -101,9 +115,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
if pymbolic_expr == 0: if pymbolic_expr == 0:
return return
# Get geometric dimension dim = world_dimension()
formdata = get_global_context_value('formdata') facedir = get_facedir(accterm.argument.restriction)
dim = formdata.geometric_dimension
# Collect buffers we need # Collect buffers we need
buffers = [] buffers = []
...@@ -125,6 +138,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -125,6 +138,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Construct the matrix sequence for this sum factorization # Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(transpose=True, a_matrices = construct_amatrix_sequence(transpose=True,
derivative=i if accterm.argument.index else None, derivative=i if accterm.argument.index else None,
face=facedir,
) )
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
...@@ -132,8 +146,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -132,8 +146,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
a_matrices, buffer, input, index = get_vectorization_info(a_matrices) a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
# Initialize a base storage for this buffer and get a temporay pointing to it # Initialize a base storage for this buffer and get a temporay pointing to it
shape = tuple(mat.cols for mat in a_matrices) shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
dim_tags = ",".join(['f'] * dim) dim_tags = ",".join(['f'] * local_dimension())
if index is not None: if index is not None:
shape = shape + (4,) shape = shape + (4,)
dim_tags = dim_tags + ",c" dim_tags = dim_tags + ",c"
...@@ -226,7 +240,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -226,7 +240,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s)) @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None): def sum_factorization_kernel(a_matrices, buf, stage,
insn_dep=frozenset({}),
additional_inames=frozenset({}),
preferred_position=None,
outshape=None):
""" """
Calculate a sum factorization matrix product. Calculate a sum factorization matrix product.
...@@ -310,19 +328,16 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add ...@@ -310,19 +328,16 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add
) )
}) })
# Get geometric dimension if outshape is None:
formdata = get_global_context_value('formdata') outshape = tuple(mat.rows for mat in a_matrices)
dim = formdata.geometric_dimension dim_tags = ",".join(['f'] * len(outshape))
out_shape = tuple(mat.rows for mat in a_matrices)
dim_tags = ",".join(['f'] * dim)
if next(iter(a_matrices)).vectorized: if next(iter(a_matrices)).vectorized:
out_shape = out_shape + vec_shape outshape = outshape + vec_shape
dim_tags = dim_tags + ",c" dim_tags = dim_tags + ",c"
out = get_buffer_temporary(buf, out = get_buffer_temporary(buf,
shape=out_shape, shape=outshape,
dim_tags=dim_tags, dim_tags=dim_tags,
) )
silenced_warning('read_no_write({})'.format(out)) silenced_warning('read_no_write({})'.format(out))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment