Skip to content
Snippets Groups Projects
Commit f1d736fd authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Introduce a context that alters options and resets them on exit

A pattern that we have been manually implementing in quite some places!
parent ff344d8d
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
from argparse import ArgumentParser
from os.path import abspath
from pytools import ImmutableRecord, memoize
from contextlib import contextmanager
from dune.testtools.parametertree.parser import parse_ini_file
......@@ -295,3 +296,41 @@ def get_form_option(key, form=None):
form = get_option("operators").split(",")[form].strip()
processed_form_opts = process_form_options(_form_options[form], form)
return getattr(processed_form_opts, key)
@contextmanager
def option_context(conditional=True, **opts):
""" A context manager that sets a given option and restores it on exit. """
# Backup old values and set to new ones
if conditional:
backup = {}
for k, v in opts.items():
backup[k] = get_option(k)
set_option(k, v)
yield
if conditional:
# Restore old values
for k in opts.keys():
set_option(k, backup[k])
@contextmanager
def form_option_context(conditional=True, **opts):
""" A context manager that sets a given form option and restores it on exit """
if conditional:
form = opts.pop("form", None)
# Backup old values and set to new ones
backup = {}
for k, v in opts.items():
backup[k] = get_form_option(k, form=form)
set_form_option(k, v, form=form)
yield
# Restore old values
if conditional:
for k in opts.keys():
set_form_option(k, backup[k], form=form)
......@@ -7,7 +7,8 @@ import numpy as np
from dune.codegen.options import (get_form_option,
get_option,
set_form_option)
form_option_context,
)
from dune.codegen.generation import (accumulation_mixin,
base_class,
class_basename,
......@@ -1024,13 +1025,9 @@ def generate_jacobian_kernels(form, original_form):
operator_kernels[(it, 'jacobian_apply')] = [LoopyKernelMethod(assembly_routine_signature(), kernel=None)]
if get_form_option("generate_jacobians"):
with global_context(form_type="jacobian"):
if get_form_option("generate_jacobians"):
if get_form_option("sumfact"):
was_sumfact = True
if get_form_option("sumfact_regular_jacobians"):
old_geometry_mixins = get_form_option("geometry_mixins")
set_form_option("geometry_mixins", "generic")
set_form_option("sumfact", False)
with form_option_context(conditional=get_form_option("sumfact") and get_form_option("sumfact_regular_jacobians"),
geometry_mixins="generic",
sumfact=False):
for measure in set(i.integral_type() for i in jacform.integrals()):
if not measure_is_enabled(measure):
continue
......@@ -1057,10 +1054,6 @@ def generate_jacobian_kernels(form, original_form):
with global_context(integral_type=it):
from dune.codegen.pdelab.signatures import assembly_routine_signature
operator_kernels[(it, 'jacobian')] = [LoopyKernelMethod(assembly_routine_signature(), kernel=None)]
if get_form_option("sumfact_regular_jacobians"):
if was_sumfact:
set_form_option("sumfact", True)
set_form_option("geometry_mixins", old_geometry_mixins)
return operator_kernels
......
This diff is collapsed.
......@@ -50,8 +50,13 @@ import numpy as np
import loopy as lp
class SumfactGeometryMixinBase(GenericPDELabGeometryMixin):
def nonsumfact_fallback(self):
return None
@geometry_mixin("sumfact_multilinear")
class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
class SumfactMultiLinearGeometryMixin(SumfactGeometryMixinBase):
def nonsumfact_fallback(self):
return "generic"
......@@ -241,7 +246,7 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
@geometry_mixin("sumfact_axiparallel")
class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin):
class SumfactAxiParallelGeometryMixin(SumfactGeometryMixinBase, AxiparallelGeometryMixin):
def nonsumfact_fallback(self):
return "axiparallel"
......
......@@ -12,7 +12,7 @@ from dune.codegen.pdelab.signatures import (assembly_routine_args,
assembly_routine_signature,
kernel_name,
)
from dune.codegen.options import get_form_option, get_option, set_form_option
from dune.codegen.options import get_form_option, get_option, form_option_context
from dune.codegen.cgen.clazz import ClassMember
......@@ -26,21 +26,12 @@ def sumfact_generate_kernels_per_integral(integrals):
if measure == "exterior_facet":
# Maybe skip sum factorization on boundary integrals
if not get_form_option("sumfact_on_boundary"):
set_form_option("sumfact", False)
# Try to find a fallback for sum factorized geometry mixins
geometry_backup = get_form_option("geometry_mixins")
mixin = construct_from_mixins(mixins=[geometry_backup])()
if hasattr(mixin, "nonsumfact_fallback"):
set_form_option("geometry_mixins", mixin.nonsumfact_fallback())
for k in generate_kernels_per_integral(integrals):
yield k
# Reset state
set_form_option("geometry_mixins", geometry_backup)
set_form_option("sumfact", True)
return
mixin = construct_from_mixins(mixins=[get_form_option("geometry_mixins")])()
geometry = mixin.nonsumfact_fallback() or get_form_option("geometry_mixins")
with form_option_context(sumfact=False, geometry_mixins=geometry):
for k in generate_kernels_per_integral(integrals):
yield k
return
# Generate all necessary kernels
for facedir in range(dim):
......
......@@ -20,7 +20,7 @@ from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction,
set_quadrature_points,
)
from dune.codegen.error import CodegenVectorizationError
from dune.codegen.options import get_form_option, get_option, set_form_option
from dune.codegen.options import get_form_option, get_option, form_option_context
from dune.codegen.tools import add_to_frozendict, round_to_multiple, list_diff
from pymbolic.mapper.flop_counter import FlopCounter
......@@ -331,17 +331,16 @@ def level1_optimal_vectorization_strategy(sumfacts, width):
# If we are using the 'target' strategy, we might want to log some information.
if get_form_option("vectorization_strategy") == "target":
# Print the achieved cost and the target cost on the screen
set_form_option("vectorization_strategy", "model")
target = float(get_form_option("vectorization_target"))
qp = min(optimal_strategies, key=lambda qp: abs(strategy_cost((qp, optimal_strategies[qp])) - target))
cost = strategy_cost((qp, optimal_strategies[qp]))
print("The target cost was: {}".format(target))
print("The achieved cost was: {}".format(cost))
optimum = level1_optimal_vectorization_strategy(sumfacts, width)
print("The optimal cost would be: {}".format(strategy_cost(optimum)))
set_form_option("vectorization_strategy", "target")
print("The score in 'target' logic was: {}".format(strategy_cost((qp, optimal_strategies[qp]))))
with form_option_context(vectorization_strategy="model"):
target = float(get_form_option("vectorization_target"))
qp = min(optimal_strategies, key=lambda qp: abs(strategy_cost((qp, optimal_strategies[qp])) - target))
cost = strategy_cost((qp, optimal_strategies[qp]))
print("The target cost was: {}".format(target))
print("The achieved cost was: {}".format(cost))
optimum = level1_optimal_vectorization_strategy(sumfacts, width)
print("The optimal cost would be: {}".format(strategy_cost(optimum)))
print("The score in 'target' logic was: {}".format(strategy_cost((qp, optimal_strategies[qp]))))
# Print the employed vectorization strategy into a file
suffix = ""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment