Skip to content
Snippets Groups Projects
Commit 9c12ecc4 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Remodel the generic_context

parent d95b1387
No related branches found
No related tags found
No related merge requests found
...@@ -27,6 +27,6 @@ from dune.perftool.generation.loopy import (domain, ...@@ -27,6 +27,6 @@ from dune.perftool.generation.loopy import (domain,
) )
from dune.perftool.generation.context import (cache_context, from dune.perftool.generation.context import (cache_context,
generic_context, global_context,
get_generic_context_value, get_global_context_value,
) )
...@@ -31,25 +31,34 @@ def get_context_tags(): ...@@ -31,25 +31,34 @@ def get_context_tags():
return result return result
_generic_context_cache = {} _global_context_cache = {}
class _GenericContext(object): class _GlobalContext(object):
def __init__(self, key, value): def __init__(self, **kwargs):
self.key = key self.kw = kwargs
self.value = value
assert key not in _generic_context_cache
def __enter__(self): def __enter__(self):
_generic_context_cache[self.key] = self.value self.old_kw = {}
for k, v in self.kw.items():
# First store existing values of the same keys
if k in _global_context_cache:
self.old_kw[k] = v
# Now replace the value with the new one
_global_context_cache[k] = v
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
del _generic_context_cache[self.key] # Delete all the entries from this context
for k in self.kw.keys():
del _global_context_cache[k]
# and restore previously overwritten values
for k, v in self.old_kw.items():
_global_context_cache[k] = v
def generic_context(key, value): def global_context(**kwargs):
return _GenericContext(key, value) return _GlobalContext(**kwargs)
def get_generic_context_value(key): def get_global_context_value(key):
return _generic_context_cache[key] return _global_context_cache[key]
...@@ -44,8 +44,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper): ...@@ -44,8 +44,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
assert isinstance(o, Expression) assert isinstance(o, Expression)
# Determine the name of the parameter function # Determine the name of the parameter function
from dune.perftool.generation import get_generic_context_value from dune.perftool.generation import get_global_context_value
name = get_generic_context_value("namedata")[id(o)] name = get_global_context_value("namedata")[id(o)]
# Trigger the generation of code for this thing in the parameter class # Trigger the generation of code for this thing in the parameter class
from dune.perftool.pdelab.parameter import parameter_function from dune.perftool.pdelab.parameter import parameter_function
...@@ -197,8 +197,8 @@ def transform_accumulation_term(term, measure, subdomain_id): ...@@ -197,8 +197,8 @@ def transform_accumulation_term(term, measure, subdomain_id):
# modified measure per integral type. # modified measure per integral type.
# Get the original form and inspect the present measures # Get the original form and inspect the present measures
from dune.perftool.generation import get_generic_context_value from dune.perftool.generation import get_global_context_value
original_form = get_generic_context_value("formdata").original_form original_form = get_global_context_value("formdata").original_form
sd = original_form.subdomain_data() sd = original_form.subdomain_data()
assert len(sd) == 1 assert len(sd) == 1
subdomains, = list(sd.values()) subdomains, = list(sd.values())
...@@ -208,10 +208,11 @@ def transform_accumulation_term(term, measure, subdomain_id): ...@@ -208,10 +208,11 @@ def transform_accumulation_term(term, measure, subdomain_id):
del subdomains[k] del subdomains[k]
# Finally extract the original subdomain_data (which needs to be present!) # Finally extract the original subdomain_data (which needs to be present!)
assert measure in subdomains
subdomain_data = subdomains[measure] subdomain_data = subdomains[measure]
# Determine the name of the parameter function # Determine the name of the parameter function
name = get_generic_context_value("namedata")[id(subdomain_data)] name = get_global_context_value("namedata")[id(subdomain_data)]
# Trigger the generation of code for this thing in the parameter class # Trigger the generation of code for this thing in the parameter class
from dune.perftool.pdelab.parameter import parameter_function from dune.perftool.pdelab.parameter import parameter_function
......
...@@ -243,16 +243,11 @@ def generate_localoperator_kernels(formdata, namedata): ...@@ -243,16 +243,11 @@ def generate_localoperator_kernels(formdata, namedata):
# Have a data structure collect the generated kernels # Have a data structure collect the generated kernels
operator_kernels = {} operator_kernels = {}
import functools from dune.perftool.generation import global_context
from dune.perftool.generation import generic_context with global_context(formdata=formdata, namedata=namedata):
namedata_context = functools.partial(generic_context, "namedata")
formdata_context = functools.partial(generic_context, "formdata")
with formdata_context(formdata):
# Generate the necessary residual methods # Generate the necessary residual methods
for integral in form.integrals(): for integral in form.integrals():
with namedata_context(namedata): kernel = generate_kernel(integral)
kernel = generate_kernel(integral)
operator_kernels[(integral.integral_type(), 'residual')] = kernel operator_kernels[(integral.integral_type(), 'residual')] = kernel
# Generate the necessary jacobian methods # Generate the necessary jacobian methods
...@@ -265,8 +260,7 @@ def generate_localoperator_kernels(formdata, namedata): ...@@ -265,8 +260,7 @@ def generate_localoperator_kernels(formdata, namedata):
jacform = expand_derivatives(derivative(form, form.coefficients()[0])) jacform = expand_derivatives(derivative(form, form.coefficients()[0]))
for integral in jacform.integrals(): for integral in jacform.integrals():
with namedata_context(namedata): kernel = generate_kernel(integral)
kernel = generate_kernel(integral)
operator_kernels[(integral.integral_type(), 'jacobian')] = kernel operator_kernels[(integral.integral_type(), 'jacobian')] = kernel
# TODO: JacobianApply for matrix-free computations. # TODO: JacobianApply for matrix-free computations.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment