Skip to content
Snippets Groups Projects
Commit e3195d96 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reorganize imports to have all backend implentations included

The backend selecting mechanism needs all modules to be imported.
This commit restructures the package to be able to do that. Placement
of some functions felt weird, but this can still be changed.
No more placements in __init__.py anymore though, as this will
result in cyclic dependencies.
parent 31fb9e3d
No related branches found
No related tags found
No related merge requests found
Showing
with 74 additions and 62 deletions
class Restriction: from dune.perftool.options import get_option
NONE = 0
NEGATIVE = 1 # Trigger some imports that are needed to have all backend implementations visible
POSITIVE = 2 # to the selection mechanisms
import dune.perftool.pdelab
import dune.perftool.sumfact
""" The pdelab specific parts of the code generation process """ """ The pdelab specific parts of the code generation process """
from dune.perftool.generation import (preamble, # Trigger some imports that are needed to have all backend implementations visible
cached, # to the selection mechanisms
) import dune.perftool.pdelab.argument
import dune.perftool.pdelab.basis
import dune.perftool.pdelab.driver
# Now define some commonly used generators that do not fall into a specific category import dune.perftool.pdelab.geometry
@cached import dune.perftool.pdelab.localoperator
def name_index(index): import dune.perftool.pdelab.parameter
from ufl.classes import MultiIndex, Index import dune.perftool.pdelab.quadrature
if isinstance(index, Index): import dune.perftool.pdelab.signatures
# This failed for index > 9 because ufl placed curly brackets around import dune.perftool.pdelab.spaces
# return str(index) \ No newline at end of file
return "i_{}".format(index.count())
if isinstance(index, MultiIndex):
assert len(index) == 1
# return str(index._indices[0])
return "i_{}".format(index._indices[0].count())
raise NotImplementedError
def restricted_name(name, restriction):
from dune.perftool import Restriction
if restriction == Restriction.NONE:
return name
if restriction == Restriction.POSITIVE:
return name + '_n'
if restriction == Restriction.NEGATIVE:
return name + '_s'
...@@ -8,15 +8,14 @@ Namely: ...@@ -8,15 +8,14 @@ Namely:
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.generation import (domain, from dune.perftool.generation import (domain,
function_mangler, function_mangler,
get_backend,
iname, iname,
pymbolic_expr, pymbolic_expr,
globalarg, globalarg,
valuearg, valuearg,
get_global_context_value get_global_context_value
) )
from dune.perftool.pdelab import (name_index, from dune.perftool.pdelab.index import name_index
restricted_name,
)
from dune.perftool.pdelab.basis import (evaluate_coefficient, from dune.perftool.pdelab.basis import (evaluate_coefficient,
evaluate_coefficient_gradient, evaluate_coefficient_gradient,
name_basis, name_basis,
...@@ -24,7 +23,8 @@ from dune.perftool.pdelab.basis import (evaluate_coefficient, ...@@ -24,7 +23,8 @@ from dune.perftool.pdelab.basis import (evaluate_coefficient,
from dune.perftool.pdelab.spaces import (lfs_iname, from dune.perftool.pdelab.spaces import (lfs_iname,
name_lfs_bound, name_lfs_bound,
) )
from dune.perftool import Restriction from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.ufl.modified_terminals import Restriction
from pymbolic.primitives import Call, Subscript, Variable from pymbolic.primitives import Call, Subscript, Variable
...@@ -95,16 +95,6 @@ def name_trialfunction_gradient(element, restriction, component): ...@@ -95,16 +95,6 @@ def name_trialfunction_gradient(element, restriction, component):
rawname = "gradu" + "_".join(str(c) for c in component) rawname = "gradu" + "_".join(str(c) for c in component)
name = restricted_name(rawname, restriction) name = restricted_name(rawname, restriction)
container = name_coefficientcontainer(restriction) container = name_coefficientcontainer(restriction)
# TODO
#
# This is just a temporary test used to create an A-matrix as
# local operator class member. Right now it doesn't evaluate
# anything.
if get_option("sumfact") and restriction == Restriction.NONE:
from dune.perftool.sumfact import start_sumfactorization
start_sumfactorization(element, container, restriction, component)
evaluate_coefficient_gradient(element, name, container, restriction, component) evaluate_coefficient_gradient(element, name, container, restriction, component)
return name return name
......
...@@ -27,8 +27,7 @@ from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs, ...@@ -27,8 +27,7 @@ from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs,
lop_template_test_gfs, lop_template_test_gfs,
) )
from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab import restricted_name
from pymbolic.primitives import Product, Subscript, Variable from pymbolic.primitives import Product, Subscript, Variable
from loopy import Reduction from loopy import Reduction
......
from dune.perftool import Restriction from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.pdelab import restricted_name from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.generation import (cached, from dune.perftool.generation import (cached,
domain, domain,
get_global_context_value, get_global_context_value,
......
from dune.perftool.generation import cached
from ufl.classes import MultiIndex, Index
# Now define some commonly used generators that do not fall into a specific category
@cached
def name_index(index):
if isinstance(index, Index):
# This failed for index > 9 because ufl placed curly brackets around
# return str(index)
return "i_{}".format(index.count())
if isinstance(index, MultiIndex):
assert len(index) == 1
# return str(index._indices[0])
return "i_{}".format(index._indices[0].count())
raise NotImplementedError
...@@ -20,7 +20,7 @@ from dune.perftool.cgen.clazz import (AccessModifier, ...@@ -20,7 +20,7 @@ from dune.perftool.cgen.clazz import (AccessModifier,
BaseClass, BaseClass,
ClassMember, ClassMember,
) )
from dune.perftool import Restriction from dune.perftool.ufl.modified_terminals import Restriction
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
from pytools import Record from pytools import Record
...@@ -355,7 +355,7 @@ def boundary_predicates(expr, measure, subdomain_id): ...@@ -355,7 +355,7 @@ def boundary_predicates(expr, measure, subdomain_id):
@iname @iname
def grad_iname(index, dim): def grad_iname(index, dim):
from dune.perftool.pdelab import name_index from dune.perftool.pdelab.index import name_index
name = name_index(index) name = name_index(index)
domain(name, dim) domain(name, dim)
return name return name
......
from dune.perftool import Restriction from dune.perftool.generation import (backend,
from dune.perftool.generation import (cached, cached,
domain, domain,
get_global_context_value, get_global_context_value,
iname, iname,
...@@ -8,6 +8,7 @@ from dune.perftool.generation import (cached, ...@@ -8,6 +8,7 @@ from dune.perftool.generation import (cached,
temporary_variable, temporary_variable,
) )
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.ufl.modified_terminals import Restriction
@iname @iname
......
from dune.perftool.ufl.modified_terminals import Restriction
def restricted_name(name, restriction):
if restriction == Restriction.NONE:
return name
if restriction == Restriction.POSITIVE:
return name + '_n'
if restriction == Restriction.NEGATIVE:
return name + '_s'
\ No newline at end of file
""" Signatures for PDELab local opreator assembly functions """ """ Signatures for PDELab local opreator assembly functions """
from dune.perftool import Restriction from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.pdelab.geometry import (name_geometry_wrapper, from dune.perftool.pdelab.geometry import (name_geometry_wrapper,
type_geometry_wrapper, type_geometry_wrapper,
) )
......
...@@ -6,7 +6,7 @@ from dune.perftool.generation import (domain, ...@@ -6,7 +6,7 @@ from dune.perftool.generation import (domain,
include_file, include_file,
preamble, preamble,
) )
from dune.perftool.pdelab import restricted_name from dune.perftool.pdelab.restriction import restricted_name
from loopy import CallMangleInfo from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier from loopy.symbolic import FunctionIdentifier
......
# Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms
import dune.perftool.sumfact.amatrix
import dune.perftool.sumfact.quadrature
import dune.perftool.sumfact.sumfact
from dune.perftool.sumfact.sumfact import start_sumfactorization from dune.perftool.sumfact.sumfact import start_sumfactorization
from dune.perftool import Restriction from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.options import get_option from dune.perftool.options import get_option
......
""" A module mimicking some functionality of uflacs' modified terminals """ """ A module mimicking some functionality of uflacs' modified terminals """
from ufl.algorithms import MultiFunction from ufl.algorithms import MultiFunction
from dune.perftool import Restriction
from ufl.classes import MultiIndex from ufl.classes import MultiIndex
from pytools import Record from pytools import Record
class Restriction:
NONE = 0
NEGATIVE = 1
POSITIVE = 2
class ModifiedArgument(Record): class ModifiedArgument(Record):
def __init__(self, def __init__(self,
expr=None, expr=None,
......
...@@ -3,8 +3,7 @@ This module defines the main visitor algorithm transforming ufl expressions ...@@ -3,8 +3,7 @@ This module defines the main visitor algorithm transforming ufl expressions
to pymbolic and loopy. to pymbolic and loopy.
""" """
from dune.perftool import Restriction from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker, Restriction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
from dune.perftool.generation import (domain, from dune.perftool.generation import (domain,
get_temporary_name, get_temporary_name,
global_context, global_context,
...@@ -185,7 +184,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -185,7 +184,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
ind = o.ufl_operands[1][0] ind = o.ufl_operands[1][0]
redinames = additional_inames + (ind,) redinames = additional_inames + (ind,)
shape = o.ufl_operands[0].ufl_index_dimensions[0] shape = o.ufl_operands[0].ufl_index_dimensions[0]
from dune.perftool.pdelab import name_index from dune.perftool.pdelab.index import name_index
domain(name_index(ind), shape) domain(name_index(ind), shape)
# If the left operand is an index sum to, we do it in one reduction # If the left operand is an index sum to, we do it in one reduction
...@@ -211,7 +210,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -211,7 +210,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return index._value return index._value
else: else:
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
from dune.perftool.pdelab import name_index from dune.perftool.pdelab.index import name_index
if index in self.dimension_indices: if index in self.dimension_indices:
from dune.perftool.pdelab.geometry import dimension_iname from dune.perftool.pdelab.geometry import dimension_iname
self.inames.append(self.dimension_indices[index]) self.inames.append(self.dimension_indices[index])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment