Skip to content
Snippets Groups Projects
Commit 3c269b93 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Draft of sumfac trial function evaluation

parent f2d9e158
No related branches found
No related tags found
No related merge requests found
...@@ -29,5 +29,5 @@ def global_context(**kwargs): ...@@ -29,5 +29,5 @@ def global_context(**kwargs):
return _GlobalContext(**kwargs) return _GlobalContext(**kwargs)
def get_global_context_value(key): def get_global_context_value(key, default=None):
return _global_context_cache[key] return _global_context_cache.get(key, default)
from __future__ import print_function from __future__ import print_function
from functools import partial from functools import partial
from os import system
from dune.perftool.generation import global_context from dune.perftool.generation import global_context
from dune.perftool.loopy.transformations import get_loopy_transformations from dune.perftool.loopy.transformations import get_loopy_transformations
from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod
import os
# Use the builtin 'input' in python2 and 'raw_input' in python3 # Use the builtin 'input' in python2 and 'raw_input' in python3
try: try:
...@@ -15,7 +16,7 @@ except: ...@@ -15,7 +16,7 @@ except:
def clear(): def clear():
system('cls' if os.name == 'nt' else 'clear') os.system('cls' if os.name == 'nt' else 'clear')
def kernel_name(v): def kernel_name(v):
......
...@@ -2,12 +2,21 @@ from dune.perftool.sumfact.quadrature import (quadrature_inames, ...@@ -2,12 +2,21 @@ from dune.perftool.sumfact.quadrature import (quadrature_inames,
quadrature_weight, quadrature_weight,
) )
from dune.perftool.sumfact.basis import pymbolic_trialfunction from dune.perftool.sumfact.basis import (lfs_inames,
pymbolic_basis,
pymbolic_trialfunction,
)
from dune.perftool.pdelab import PDELabInterface from dune.perftool.pdelab import PDELabInterface
class SumFactInterface(PDELabInterface): class SumFactInterface(PDELabInterface):
def lfs_inames(self, element, restriction, number=None, context=''):
return lfs_inames(element, restriction, number, context)
def pymbolic_basis(self, element, restriction, number):
return pymbolic_basis(element, restriction, number)
def pymbolic_trialfunction(self, element, restriction, component): def pymbolic_trialfunction(self, element, restriction, component):
return pymbolic_trialfunction(element, restriction, component) return pymbolic_trialfunction(element, restriction, component)
......
...@@ -226,13 +226,8 @@ def type_theta(): ...@@ -226,13 +226,8 @@ def type_theta():
@class_member("operator") @class_member("operator")
def define_theta(name, transpose): def define_theta(name, shape, transpose):
theta_type = type_theta() theta_type = type_theta()
if transpose:
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
else:
shape = (quadrature_points_per_direction(), basis_functions_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
initializer_list(name, [str(axis) for axis in shape], classtag="operator") initializer_list(name, [str(axis) for axis in shape], classtag="operator")
construct_theta(name, transpose) construct_theta(name, transpose)
return "{} {};".format(theta_type, name) return "{} {};".format(theta_type, name)
...@@ -240,11 +235,15 @@ def define_theta(name, transpose): ...@@ -240,11 +235,15 @@ def define_theta(name, transpose):
def name_theta(): def name_theta():
name = "Theta" name = "Theta"
define_theta(name, False) shape = (quadrature_points_per_direction(), basis_functions_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
define_theta(name, shape, False)
return name return name
def name_theta_transposed(): def name_theta_transposed():
name = "ThetaT" name = "ThetaT"
define_theta(name, True) shape = (basis_functions_per_direction(), quadrature_points_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
define_theta(name, shape, True)
return name return name
...@@ -5,6 +5,9 @@ multiplication withthe test function is part of the sum factorization kernel. ...@@ -5,6 +5,9 @@ multiplication withthe test function is part of the sum factorization kernel.
""" """
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
cached, cached,
get_global_context_value,
instruction,
temporary_variable,
) )
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction, basis_functions_per_direction,
...@@ -13,9 +16,12 @@ from dune.perftool.sumfact.amatrix import (AMatrix, ...@@ -13,9 +16,12 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
) )
from dune.perftool.sumfact.sumfact import (setup_theta, from dune.perftool.sumfact.sumfact import (setup_theta,
sum_factorization_kernel, sum_factorization_kernel,
sumfact_iname,
) )
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.loopy.buffer import initialize_buffer from dune.perftool.loopy.buffer import initialize_buffer
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name
from pytools import product from pytools import product
...@@ -42,23 +48,38 @@ def pymbolic_trialfunction(element, restriction, component): ...@@ -42,23 +48,38 @@ def pymbolic_trialfunction(element, restriction, component):
return p.Subscript(p.Variable(var), tuple(p.Variable(i) for i in quadrature_inames())) return p.Subscript(p.Variable(var), tuple(p.Variable(i) for i in quadrature_inames()))
def lfs_inames(leaf_element): @backend(interface="lfs_inames", name="sumfact")
return () def lfs_inames(element, restriction, number=1, context=''):
assert number == 1
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
return tuple(sumfact_iname(basis_functions_per_direction(),'lfsdim{}'.format(d)) for d in range(dim))
@backend(interface="evaluate_basis") @backend(interface="evaluate_basis")
@cached @cached
def evaluate_basis(leaf_element, name, restriction): def evaluate_basis(element, name, restriction):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
theta = name_theta() theta = name_theta()
quad_inames = quadrature_inames() quad_inames = quadrature_inames()
lfs_inames = lfs_inames() inames = lfs_inames(element, restriction)
assert(len(quad_inames) == len(lfs_inames)) assert(len(quad_inames) == len(inames))
instruction(expression=p.Product(tuple(p.Subscript(p.Variable(theta), (p.Variable(i), p.Variable(j))) instruction(expression=p.Product(tuple(p.Subscript(p.Variable(theta), (p.Variable(i), p.Variable(j)))
for (i,j) in zip(quad_inames, lfs_inames)) for (i,j) in zip(quad_inames, inames))
), ),
assignee=p.Variable(name), assignee=p.Variable(name),
forced_iname_deps=frozenset(quad_inames + lfs_inames), forced_iname_deps=frozenset(quad_inames + inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
) )
def pymbolic_basis(element, restriction, number):
assert number == 1
assert element.num_sub_elements() == 0
name = "phi_{}".format(FEM_name_mangling(element))
name = restricted_name(name, restriction)
evaluate_basis(element, name, restriction)
return p.Variable(name)
\ No newline at end of file
from dune.perftool.pdelab.argument import (name_coefficientcontainer, from dune.perftool.pdelab.argument import (name_accumulation_variable,
name_coefficientcontainer,
pymbolic_coefficient, pymbolic_coefficient,
) )
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
domain, domain,
function_mangler, function_mangler,
get_counter, get_counter,
get_global_context_value,
globalarg, globalarg,
iname, iname,
instruction, instruction,
...@@ -88,7 +88,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -88,7 +88,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
temporary_variable(name, shape=(quadrature_points_per_direction(),) * dim, managed=True) temporary_variable(name, shape=(quadrature_points_per_direction(),) * dim, managed=True)
instruction(assignee=Subscript(Variable(name), tuple(Variable(i) for i in quadrature_inames())), instruction(assignee=Subscript(Variable(name), tuple(Variable(i) for i in quadrature_inames())),
expression=pymbolic_expr, expression=pymbolic_expr,
forced_iname_deps=frozenset(quadrature_inames()), forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=frozenset({stage_insn(1)}), depends_on=frozenset({stage_insn(1)}),
) )
...@@ -105,7 +105,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -105,7 +105,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
num=2 num=2
) )
result = sum_factorization_kernel(a_matrices, "reffub", 2) result = sum_factorization_kernel(a_matrices, "reffub", 2, additional_inames=visitor.inames)
# Now write all this into the correct residual # Now write all this into the correct residual
lfs = name_lfs(accterm.argument.argexpr.ufl_element(), lfs = name_lfs(accterm.argument.argexpr.ufl_element(),
...@@ -113,21 +113,22 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -113,21 +113,22 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
accterm.argument.component, accterm.argument.component,
) )
inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
globalarg("r", accum = name_accumulation_variable()
globalarg(accum,
shape=(basis_functions_per_direction(),) * dim, shape=(basis_functions_per_direction(),) * dim,
argtype=AccumulationArg, argtype=AccumulationArg,
lfs=lfs, lfs=lfs,
) )
instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)), instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)),
assignee=Subscript(Variable('r'), tuple(Variable(i) for i in inames)), assignee=Subscript(Variable(accum), tuple(Variable(i) for i in inames)),
forced_iname_deps=frozenset(inames), forced_iname_deps=frozenset(inames + visitor.inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=frozenset({stage_insn(3)}), depends_on=frozenset({stage_insn(3)}),
) )
def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})): def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({}), additional_inames=frozenset({})):
""" """
Calculate a sum factorization matrix product. Calculate a sum factorization matrix product.
...@@ -183,7 +184,7 @@ def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})): ...@@ -183,7 +184,7 @@ def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})):
# at the same time store the instruction ID for the next instruction to depend on # at the same time store the instruction ID for the next instruction to depend on
insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))), insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
expression=Reduction("sum", k, prod), expression=Reduction("sum", k, prod),
forced_iname_deps=frozenset({i, j}), forced_iname_deps=frozenset({i, j}).union(additional_inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=insn_dep.union(frozenset({stage_insn(stage)})), depends_on=insn_dep.union(frozenset({stage_insn(stage)})),
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment