Skip to content
Snippets Groups Projects
Commit e3195d96 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reorganize imports to have all backend implentations included

The backend selecting mechanism needs all modules to be imported.
This commit restructures the package to be able to do that. Placement
of some functions felt weird, but this can still be changed.
No more placements in __init__.py anymore though, as this will
result in cyclic dependencies.
parent 31fb9e3d
No related branches found
No related tags found
No related merge requests found
Showing
with 74 additions and 62 deletions
class Restriction:
NONE = 0
NEGATIVE = 1
POSITIVE = 2
from dune.perftool.options import get_option
# Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms
import dune.perftool.pdelab
import dune.perftool.sumfact
""" The pdelab specific parts of the code generation process """
from dune.perftool.generation import (preamble,
cached,
)
# Now define some commonly used generators that do not fall into a specific category
@cached
def name_index(index):
from ufl.classes import MultiIndex, Index
if isinstance(index, Index):
# This failed for index > 9 because ufl placed curly brackets around
# return str(index)
return "i_{}".format(index.count())
if isinstance(index, MultiIndex):
assert len(index) == 1
# return str(index._indices[0])
return "i_{}".format(index._indices[0].count())
raise NotImplementedError
def restricted_name(name, restriction):
from dune.perftool import Restriction
if restriction == Restriction.NONE:
return name
if restriction == Restriction.POSITIVE:
return name + '_n'
if restriction == Restriction.NEGATIVE:
return name + '_s'
# Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms
import dune.perftool.pdelab.argument
import dune.perftool.pdelab.basis
import dune.perftool.pdelab.driver
import dune.perftool.pdelab.geometry
import dune.perftool.pdelab.localoperator
import dune.perftool.pdelab.parameter
import dune.perftool.pdelab.quadrature
import dune.perftool.pdelab.signatures
import dune.perftool.pdelab.spaces
\ No newline at end of file
......@@ -8,15 +8,14 @@ Namely:
from dune.perftool.options import get_option
from dune.perftool.generation import (domain,
function_mangler,
get_backend,
iname,
pymbolic_expr,
globalarg,
valuearg,
get_global_context_value
)
from dune.perftool.pdelab import (name_index,
restricted_name,
)
from dune.perftool.pdelab.index import name_index
from dune.perftool.pdelab.basis import (evaluate_coefficient,
evaluate_coefficient_gradient,
name_basis,
......@@ -24,7 +23,8 @@ from dune.perftool.pdelab.basis import (evaluate_coefficient,
from dune.perftool.pdelab.spaces import (lfs_iname,
name_lfs_bound,
)
from dune.perftool import Restriction
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.ufl.modified_terminals import Restriction
from pymbolic.primitives import Call, Subscript, Variable
......@@ -95,16 +95,6 @@ def name_trialfunction_gradient(element, restriction, component):
rawname = "gradu" + "_".join(str(c) for c in component)
name = restricted_name(rawname, restriction)
container = name_coefficientcontainer(restriction)
# TODO
#
# This is just a temporary test used to create an A-matrix as
# local operator class member. Right now it doesn't evaluate
# anything.
if get_option("sumfact") and restriction == Restriction.NONE:
from dune.perftool.sumfact import start_sumfactorization
start_sumfactorization(element, container, restriction, component)
evaluate_coefficient_gradient(element, name, container, restriction, component)
return name
......
......@@ -27,8 +27,7 @@ from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs,
lop_template_test_gfs,
)
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab import restricted_name
from dune.perftool.pdelab.restriction import restricted_name
from pymbolic.primitives import Product, Subscript, Variable
from loopy import Reduction
......
from dune.perftool import Restriction
from dune.perftool.pdelab import restricted_name
from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.generation import (cached,
domain,
get_global_context_value,
......
from dune.perftool.generation import cached
from ufl.classes import MultiIndex, Index
# Now define some commonly used generators that do not fall into a specific category
@cached
def name_index(index):
if isinstance(index, Index):
# This failed for index > 9 because ufl placed curly brackets around
# return str(index)
return "i_{}".format(index.count())
if isinstance(index, MultiIndex):
assert len(index) == 1
# return str(index._indices[0])
return "i_{}".format(index._indices[0].count())
raise NotImplementedError
......@@ -20,7 +20,7 @@ from dune.perftool.cgen.clazz import (AccessModifier,
BaseClass,
ClassMember,
)
from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import Restriction
from pymbolic.primitives import Variable
from pytools import Record
......@@ -355,7 +355,7 @@ def boundary_predicates(expr, measure, subdomain_id):
@iname
def grad_iname(index, dim):
from dune.perftool.pdelab import name_index
from dune.perftool.pdelab.index import name_index
name = name_index(index)
domain(name, dim)
return name
......
from dune.perftool import Restriction
from dune.perftool.generation import (cached,
from dune.perftool.generation import (backend,
cached,
domain,
get_global_context_value,
iname,
......@@ -8,6 +8,7 @@ from dune.perftool.generation import (cached,
temporary_variable,
)
from dune.perftool.options import get_option
from dune.perftool.ufl.modified_terminals import Restriction
@iname
......
from dune.perftool.ufl.modified_terminals import Restriction
def restricted_name(name, restriction):
if restriction == Restriction.NONE:
return name
if restriction == Restriction.POSITIVE:
return name + '_n'
if restriction == Restriction.NEGATIVE:
return name + '_s'
\ No newline at end of file
""" Signatures for PDELab local opreator assembly functions """
from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.pdelab.geometry import (name_geometry_wrapper,
type_geometry_wrapper,
)
......
......@@ -6,7 +6,7 @@ from dune.perftool.generation import (domain,
include_file,
preamble,
)
from dune.perftool.pdelab import restricted_name
from dune.perftool.pdelab.restriction import restricted_name
from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier
......
# Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms
import dune.perftool.sumfact.amatrix
import dune.perftool.sumfact.quadrature
import dune.perftool.sumfact.sumfact
from dune.perftool.sumfact.sumfact import start_sumfactorization
from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.options import get_option
......
""" A module mimicking some functionality of uflacs' modified terminals """
from ufl.algorithms import MultiFunction
from dune.perftool import Restriction
from ufl.classes import MultiIndex
from pytools import Record
class Restriction:
NONE = 0
NEGATIVE = 1
POSITIVE = 2
class ModifiedArgument(Record):
def __init__(self,
expr=None,
......
......@@ -3,8 +3,7 @@ This module defines the main visitor algorithm transforming ufl expressions
to pymbolic and loopy.
"""
from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker, Restriction
from dune.perftool.generation import (domain,
get_temporary_name,
global_context,
......@@ -185,7 +184,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
ind = o.ufl_operands[1][0]
redinames = additional_inames + (ind,)
shape = o.ufl_operands[0].ufl_index_dimensions[0]
from dune.perftool.pdelab import name_index
from dune.perftool.pdelab.index import name_index
domain(name_index(ind), shape)
# If the left operand is an index sum to, we do it in one reduction
......@@ -211,7 +210,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return index._value
else:
from pymbolic.primitives import Variable
from dune.perftool.pdelab import name_index
from dune.perftool.pdelab.index import name_index
if index in self.dimension_indices:
from dune.perftool.pdelab.geometry import dimension_iname
self.inames.append(self.dimension_indices[index])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment