From 97ad1a8b056b049b4d6d91438d61ae14d12cad93 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Fri, 2 Aug 2019 14:06:07 +0200 Subject: [PATCH] Implement metadata option parsing --- python/dune/codegen/pdelab/localoperator.py | 42 +++++++++++++-------- python/dune/codegen/pdelab/quadrature.py | 2 +- test/poisson/poisson_tensor.ufl | 3 ++ 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/python/dune/codegen/pdelab/localoperator.py b/python/dune/codegen/pdelab/localoperator.py index 600617ea..835fe44e 100644 --- a/python/dune/codegen/pdelab/localoperator.py +++ b/python/dune/codegen/pdelab/localoperator.py @@ -41,6 +41,7 @@ from dune.codegen.cgen.clazz import (AccessModifier, ) from dune.codegen.loopy.target import type_floatingpoint from dune.codegen.ufl.modified_terminals import Restriction +from frozendict import frozendict import dune.codegen.loopy.mangler @@ -509,24 +510,35 @@ def visit_integral(integral): def generate_kernel(integrals): logger = logging.getLogger(__name__) - # Visit all integrals once to collect information (dry-run)! - logger.debug('generate_kernel: visit_integrals (dry run)') - with global_context(dry_run=True): + # Assert that metadata for a given measure type agrees. This is a limitation + # of our current approach that is hard to overcome. + def remove_nonuser_metadata(d): + return frozendict({k: v for k, v in d.items() if k != "estimated_polynomial_degree"}) + + meta_dicts = [remove_nonuser_metadata(i.metadata()) for i in integrals] + if len(set(meta_dicts)) > 1: + measure = get_global_context_value("measure") + raise CodegenUFLError("Measure {} used with varying metadata! dune-codegen does not currently support this.") + + with form_option_context(**meta_dicts[0]): + # Visit all integrals once to collect information (dry-run)! + logger.debug('generate_kernel: visit_integrals (dry run)') + with global_context(dry_run=True): + for integral in integrals: + visit_integral(integral) + + # Now perform some checks on what should be done + from dune.codegen.sumfact.vectorization import decide_vectorization_strategy + logger.debug('generate_kernel: decide_vectorization_strategy') + decide_vectorization_strategy() + + # Delete the cache contents and do the real thing! + logger.debug('generate_kernel: visit_integrals (no dry run)') + from dune.codegen.generation import delete_cache_items + delete_cache_items("kernel_default") for integral in integrals: visit_integral(integral) - # Now perform some checks on what should be done - from dune.codegen.sumfact.vectorization import decide_vectorization_strategy - logger.debug('generate_kernel: decide_vectorization_strategy') - decide_vectorization_strategy() - - # Delete the cache contents and do the real thing! - logger.debug('generate_kernel: visit_integrals (no dry run)') - from dune.codegen.generation import delete_cache_items - delete_cache_items("kernel_default") - for integral in integrals: - visit_integral(integral) - from dune.codegen.pdelab.signatures import kernel_name, assembly_routine_signature name = kernel_name() signature = assembly_routine_signature() diff --git a/python/dune/codegen/pdelab/quadrature.py b/python/dune/codegen/pdelab/quadrature.py index 9814f6f2..ca44056c 100644 --- a/python/dune/codegen/pdelab/quadrature.py +++ b/python/dune/codegen/pdelab/quadrature.py @@ -203,7 +203,7 @@ def quadrature_order(): possible to use a different quadrature_order per direction. """ if get_form_option("quadrature_order"): - quadrature_order = tuple(map(int, get_form_option("quadrature_order").split(','))) + quadrature_order = tuple(map(int, str(get_form_option("quadrature_order")).split(','))) else: quadrature_order = _estimate_quadrature_order() diff --git a/test/poisson/poisson_tensor.ufl b/test/poisson/poisson_tensor.ufl index b527d052..3591d4d7 100644 --- a/test/poisson/poisson_tensor.ufl +++ b/test/poisson/poisson_tensor.ufl @@ -12,6 +12,9 @@ V = FiniteElement("CG", cell, 1) u = TrialFunction(V) v = TestFunction(V) +# Test metadata setting of options +dx = dx(metadata={"quadrature_order": 27}) + r= (inner(A*grad(u), grad(v)) + c*u*v -f*v)*dx exact_solution = g is_dirichlet = 1 -- GitLab