Skip to content
Snippets Groups Projects
Commit 1529cb73 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Precompute quadrature weights on a localoperator level

parent 6fb27fd0
No related branches found
No related tags found
No related merge requests found
......@@ -65,6 +65,7 @@ class PerftoolOptionsArray(ImmutableRecord):
# Arguments that are mainly to be set by logic depending on other options
max_vector_width = PerftoolOption(default=256, helpstr=None)
unroll_dimension_loops = PerftoolOption(default=False, helpstr="whether loops over the gemetric dimension should be unrolled.")
precompute_quadrature_info = PerftoolOption(default=True, helpstr="whether loops over the gemetric dimension should be unrolled.")
# Until more sophisticated logic is needed, we keep the actual option data in this module
......
......@@ -414,19 +414,6 @@ def visit_integrals(integrals):
if name.startswith("cse"):
set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight
domain = accterm.term.ufl_domain()
if measure == "cell":
set_subst_rule("integration_factor_cell1",
uc.QuadratureWeight(domain) * uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2",
uc.Abs(uc.JacobianDeterminant(domain)) * uc.QuadratureWeight(domain))
else:
set_subst_rule("integration_factor_facet1",
uc.FacetJacobianDeterminant(domain) * uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2",
uc.QuadratureWeight(domain) * uc.FacetJacobianDeterminant(domain))
get_backend(interface="accum_insn")(visitor, accterm, measure, subdomain_id)
......
......@@ -4,6 +4,7 @@ from dune.perftool.generation import (backend,
get_global_context_value,
iname,
instruction,
loopy_class_member,
temporary_variable,
)
......@@ -15,6 +16,7 @@ from dune.perftool.pdelab.argument import name_accumulation_variable
from dune.perftool.pdelab.geometry import (dimension_iname,
local_dimension,
)
from dune.perftool.options import get_option
from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier
......@@ -26,7 +28,8 @@ from pymbolic.primitives import (Call,
Variable,
)
import numpy
import pymbolic.primitives as prim
import numpy as np
def nest_quadrature_loops(kernel, inames):
......@@ -58,7 +61,7 @@ class BaseWeight(FunctionIdentifier):
@function_mangler
def base_weight_function_mangler(target, func, dtypes):
if isinstance(func, BaseWeight):
return CallMangleInfo(func.name, (NumpyType(numpy.float64),), ())
return CallMangleInfo(func.name, (NumpyType(np.float64),), ())
def pymbolic_base_weight():
......@@ -82,6 +85,17 @@ def quadrature_inames():
return tuple(sumfact_quad_iname(d, quadrature_points_per_direction()) for d in range(local_dimension()))
@iname(kernel="operator")
def constructor_quad_iname(name, d, bound):
name = "{}_{}".format(name, d)
domain(name, quadrature_points_per_direction(), kernel="operator")
return name
def constructor_quadrature_inames(name):
return tuple(constructor_quad_iname(name, d, quadrature_points_per_direction()) for d in range(local_dimension()))
def define_recursive_quadrature_weight(name, dir):
iname = quadrature_inames()[dir]
temporary_variable(name, shape=(), shape_impl=())
......@@ -107,7 +121,32 @@ def recursive_quadrature_weight(dir=0):
def quadrature_weight():
return recursive_quadrature_weight()
# Return non-precomputed version
if not get_option("precompute_quadrature_info"):
return recursive_quadrature_weight()
dim = local_dimension()
num1d = quadrature_points_per_direction()
name = "quad_weights_dim{}_num{}".format(dim, num1d)
# Add a class member
loopy_class_member(name,
dtype=np.float64,
shape=(num1d,) * dim,
classtag="operator",
dim_tags=",".join(["c"] * dim),
managed=True,
potentially_vectorized=True,
)
# Precompute it in the constructor
instruction(assignee=prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in constructor_quadrature_inames(name))),
expression=prim.Product(tuple(Subscript(Variable(name_oned_quadrature_weights()), (prim.Variable(i),)) for i in constructor_quadrature_inames(name))),
within_inames=frozenset(constructor_quadrature_inames(name)),
kernel="operator",
)
return prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in quadrature_inames()))
def define_quadrature_position(name):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment