Skip to content
Snippets Groups Projects
Commit b0573c70 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Remove interface

parent d7621d9f
No related branches found
No related tags found
No related merge requests found
......@@ -2,37 +2,8 @@
# Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms
from dune.codegen.generation import (get_backend)
from dune.codegen.options import option_switch
from dune.codegen.pdelab.argument import pymbolic_coefficient
from dune.codegen.pdelab.function import pymbolic_gridfunction
from dune.codegen.pdelab.index import name_index
from dune.codegen.pdelab.spaces import (lfs_inames,
)
from dune.codegen.pdelab.tensors import (pymbolic_list_tensor,
pymbolic_identity,
pymbolic_matrix_inverse,
)
class PDELabInterface(object):
def __init__(self):
# The visitor instance will be registered by its init method
self.visitor = None
def initialize_function_spaces(self, expr, visitor):
from dune.codegen.pdelab.spaces import initialize_function_spaces
return initialize_function_spaces(expr, visitor)
#
# Tensor expression related generator functions
#
def pymbolic_list_tensor(self, o):
return pymbolic_list_tensor(o, self.visitor)
def pymbolic_identity(self, o):
return pymbolic_identity(o)
def pymbolic_matrix_inverse(self, o):
return pymbolic_matrix_inverse(o, self.visitor)
......@@ -23,6 +23,7 @@ from dune.codegen.pdelab.spaces import (lfs_iname,
name_lfs_bound,
type_gfs,
type_leaf_gfs,
initialize_function_spaces,
)
from dune.codegen.pdelab.geometry import (component_iname,
world_dimension,
......@@ -51,6 +52,9 @@ from loopy import Reduction
@basis_mixin("base")
class BasisMixinBase(object):
def initialize_function_spaces(self, expr):
pass
def lfs_inames(self, element, restriction, number):
raise NotImplementedError("Basis Mixins should implement local function space inames")
......@@ -78,6 +82,9 @@ class BasisMixinBase(object):
@basis_mixin("generic")
class GenericBasisMixin(BasisMixinBase):
def initialize_function_spaces(self, expr):
return initialize_function_spaces(expr, self)
def lfs_inames(self, element, restriction, number, context=""):
return (lfs_iname(element, restriction, number, context),)
......
......@@ -14,61 +14,6 @@ import loopy as lp
import itertools as it
def define_list_tensor(name, expr, visitor, stack=()):
for i, child in enumerate(expr.ufl_operands):
from ufl.classes import ListTensor
if isinstance(child, ListTensor):
define_list_tensor(name, child, visitor, stack=stack + (i,))
else:
visexpr = visitor.call(child)
from loopy.symbolic import DependencyMapper
deps = DependencyMapper(include_subscripts=False, include_lookups=False, include_calls=False)(visexpr)
instruction(assignee=prim.Subscript(prim.Variable(name), stack + (i,)),
expression=visitor.call(child),
forced_iname_deps=frozenset(visitor.interface.quadrature_inames()),
depends_on=frozenset({lp.match.Tagged("sumfact_stage1")}),
tags=frozenset({"quad"}),
)
@kernel_cached
def pymbolic_list_tensor(expr, visitor):
name = get_counted_variable("listtensor")
temporary_variable(name,
shape=expr.ufl_shape,
managed=True,
)
define_list_tensor(name, expr, visitor)
return prim.Variable(name)
@iname
def identity_iname(name, bound):
name = "id_{}_{}".format(name, bound)
domain(name, bound)
return name
def define_identity(name, expr):
i = identity_iname("i", expr.ufl_shape[0])
j = identity_iname("j", expr.ufl_shape[1])
instruction(assignee=prim.Subscript(prim.Variable(name), (prim.Variable(i), prim.Variable(j))),
expression=prim.If(prim.Comparison(prim.Variable(i), "==", prim.Variable(j)), 1, 0),
forced_iname_deps_is_final=True,
)
@kernel_cached
def pymbolic_identity(expr):
name = "identity_{}_{}".format(expr.ufl_shape[0], expr.ufl_shape[1])
temporary_variable(name,
shape=expr.ufl_shape,
shape_impl=('fm',),
)
define_identity(name, expr)
return prim.Variable(name)
def define_assembled_tensor(name, expr, visitor):
temporary_variable(name,
shape=expr.ufl_shape,
......
......@@ -94,7 +94,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
#
def argument(self, o):
self.interface.initialize_function_spaces(o, self)
self.initialize_function_spaces(o)
# Update the information on where to accumulate this
info = self.get_accumulation_info(o)
if o.number() == 0:
......@@ -143,7 +143,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Do something different for trial function and coefficients from jacobian apply
if o.count() == 0 or o.count() == 1:
self.interface.initialize_function_spaces(o, self)
self.initialize_function_spaces(o)
index = None
if isinstance(o.ufl_element(), MixedElement):
......@@ -248,7 +248,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
self.indices = None
return self.call(o.ufl_operands[index])
else:
return self.interface.pymbolic_list_tensor(o)
raise CodegenUFLError("Index should have been unrolled!")
def component_tensor(self, o):
assert len(self.indices) == len(o.ufl_operands[1])
......@@ -265,9 +265,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return ret
def identity(self, o):
return self.interface.pymbolic_identity(o)
i, j = self.indices
assert isinstance(i, int) and isinstance(j, int)
return 1 if i == j else 0
def inverse(self, o):
from dune.codegen.pdelab.tensors import pymbolic_matrix_inverse
return self.interface.pymbolic_matrix_inverse(o)
#
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment