Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
dune-codegen
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
This is an archived project. Repository and other project resources are read-only.
Show more breadcrumbs
Christian Heinigk
dune-codegen
Commits
2b043378
Commit
2b043378
authored
7 years ago
by
Dominic Kempf
Browse files
Options
Downloads
Patches
Plain Diff
Move function_name generation onto symbolic representation
parent
61adf2f4
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
python/dune/perftool/sumfact/realization.py
+11
-38
11 additions, 38 deletions
python/dune/perftool/sumfact/realization.py
python/dune/perftool/sumfact/symbolic.py
+20
-11
20 additions, 11 deletions
python/dune/perftool/sumfact/symbolic.py
with
31 additions
and
49 deletions
python/dune/perftool/sumfact/realization.py
+
11
−
38
View file @
2b043378
...
@@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
...
@@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
permute_backward
,
permute_backward
,
permute_forward
,
permute_forward
,
)
)
from
dune.perftool.sumfact.quadrature
import
quadrature_points_per_direction
from
dune.perftool.sumfact.symbolic
import
(
get_input_output_tuple
,
from
dune.perftool.sumfact.symbolic
import
(
get_input_output_tuple
,
SumfactKernel
,
SumfactKernel
,
VectorizedSumfactKernel
,
VectorizedSumfactKernel
,
...
@@ -47,39 +48,9 @@ import numpy as np
...
@@ -47,39 +48,9 @@ import numpy as np
import
pymbolic.primitives
as
prim
import
pymbolic.primitives
as
prim
necessary_kernel_implementations
=
generator_factory
(
item_tags
=
(
"
kernelimpl
"
,),
no_deco
=
True
)
# Have a generator function store the necessary sum factorization kernel implementations
# This way then can easily be extracted at the end of the form visiting process
necessary_kernel_implementations
=
generator_factory
(
item_tags
=
(
"
kernelimpl
"
,),
cache_key_generator
=
lambda
a
:
a
[
0
].
function_name
,
no_deco
=
True
)
@generator_factory
(
cache_key_generator
=
lambda
s
,
qp
:
(
s
.
function_key
,
qp
))
def
_name_kernel_implementation_function
(
sf
,
qp
):
name
=
"
sfimpl_{}
"
.
format
(
"
_
"
.
join
(
str
(
m
)
for
m
in
sf
.
matrix_sequence
))
if
get_form_option
(
"
fastdg
"
):
if
sf
.
stage
==
1
:
if
isinstance
(
sf
,
SumfactKernel
):
fastdg
=
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
sf
.
input
.
element
),
sf
.
input
.
element_index
)
if
isinstance
(
sf
,
VectorizedSumfactKernel
):
fastdg
=
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
element
),
i
.
element_index
)
for
i
in
remove_duplicates
(
sf
.
input
.
inputs
))
if
sf
.
stage
==
3
:
if
isinstance
(
sf
,
SumfactKernel
):
fastdg
=
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
sf
.
output
.
test_element
),
sf
.
output
.
test_element_index
)
if
sf
.
within_inames
:
fastdg
=
"
{}x{}comp{}
"
.
format
(
fastdg
,
FEM_name_mangling
(
sf
.
output
.
trial_element
),
sf
.
output
.
trial_element_index
)
if
isinstance
(
sf
,
VectorizedSumfactKernel
):
fastdg
=
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
test_element
),
i
.
test_element_index
)
for
i
in
remove_duplicates
(
sf
.
output
.
outputs
))
if
sf
.
within_inames
:
fastdg
=
"
{}x{}
"
.
format
(
fastdg
,
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
trial_element
),
i
.
trial_element_index
)
for
i
in
remove_duplicates
(
sf
.
output
.
outputs
))
)
name
=
"
{}_fastdg{}_{}
"
.
format
(
name
,
sf
.
stage
,
fastdg
)
necessary_kernel_implementations
((
sf
,
qp
))
return
name
def
name_kernel_implementation_function
(
sf
):
from
dune.perftool.sumfact.quadrature
import
quadrature_points_per_direction
qp
=
quadrature_points_per_direction
()
return
_name_kernel_implementation_function
(
sf
,
qp
)
def
realize_sum_factorization_kernel
(
sf
,
**
kwargs
):
def
realize_sum_factorization_kernel
(
sf
,
**
kwargs
):
...
@@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf):
...
@@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf):
insn_dep
=
insn_dep
.
union
(
timer_dep
)
insn_dep
=
insn_dep
.
union
(
timer_dep
)
# Get all the necessary pieces for a function call
# Get all the necessary pieces for a function call
funcname
=
name_kernel_implementation_function
(
sf
)
buffers
=
tuple
(
name_buffer_storage
(
sf
.
buffer
,
i
)
for
i
in
range
(
2
))
buffers
=
tuple
(
name_buffer_storage
(
sf
.
buffer
,
i
)
for
i
in
range
(
2
))
# Make sure that the storage is allocated and has a certain minimum size
# Make sure that the storage is allocated and has a certain minimum size
...
@@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf):
...
@@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf):
if
sf
.
stage
==
3
:
if
sf
.
stage
==
3
:
fastdg_args
=
sf
.
output
.
fastdg_args
fastdg_args
=
sf
.
output
.
fastdg_args
# Trigger generation of the sum factorization kernel function
qp
=
quadrature_points_per_direction
()
necessary_kernel_implementations
((
sf
,
qp
))
# Call the function
# Call the function
code
=
"
{}({});
"
.
format
(
funcname
,
"
,
"
.
join
(
buffers
+
fastdg_args
))
code
=
"
{}({});
"
.
format
(
sf
.
func
tion_
name
,
"
,
"
.
join
(
buffers
+
fastdg_args
))
tag
=
"
sumfact_stage{}
"
.
format
(
sf
.
stage
)
tag
=
"
sumfact_stage{}
"
.
format
(
sf
.
stage
)
insn_dep
=
frozenset
({
instruction
(
code
=
code
,
insn_dep
=
frozenset
({
instruction
(
code
=
code
,
depends_on
=
insn_dep
,
depends_on
=
insn_dep
,
...
@@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf):
...
@@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf):
})
})
# Construct a loopy kernel object
# Construct a loopy kernel object
name
=
name_kernel_implementation_function
(
sf
)
from
dune.perftool.pdelab.localoperator
import
extract_kernel_from_cache
from
dune.perftool.pdelab.localoperator
import
extract_kernel_from_cache
args
=
[
"
const char* buffer0
"
,
"
const char* buffer1
"
]
args
=
[
"
const char* buffer0
"
,
"
const char* buffer1
"
]
if
get_form_option
(
'
fastdg
'
):
if
get_form_option
(
'
fastdg
'
):
...
@@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf):
...
@@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf):
if
sf
.
within_inames
:
if
sf
.
within_inames
:
args
.
append
(
"
unsigned int jacobian_offset{}
"
.
format
(
i
))
args
.
append
(
"
unsigned int jacobian_offset{}
"
.
format
(
i
))
signature
=
"
void {}({}) const
"
.
format
(
name
,
"
,
"
.
join
(
args
))
signature
=
"
void {}({}) const
"
.
format
(
sf
.
function_
name
,
"
,
"
.
join
(
args
))
kernel
=
extract_kernel_from_cache
(
"
kernel_default
"
,
name
,
[
signature
],
add_timings
=
False
)
kernel
=
extract_kernel_from_cache
(
"
kernel_default
"
,
sf
.
function_
name
,
[
signature
],
add_timings
=
False
)
delete_cache_items
(
"
kernel_default
"
)
delete_cache_items
(
"
kernel_default
"
)
return
kernel
return
kernel
This diff is collapsed.
Click to expand it.
python/dune/perftool/sumfact/symbolic.py
+
20
−
11
View file @
2b043378
...
@@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable,
...
@@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable,
subst_rule
,
subst_rule
,
transform
,
transform
,
)
)
from
dune.perftool.pdelab.driver
import
FEM_name_mangling
from
dune.perftool.pdelab.geometry
import
local_dimension
,
world_dimension
from
dune.perftool.pdelab.geometry
import
local_dimension
,
world_dimension
from
dune.perftool.sumfact.quadrature
import
quadrature_inames
from
dune.perftool.sumfact.quadrature
import
quadrature_inames
from
dune.perftool.sumfact.tabulation
import
BasisTabulationMatrixBase
,
BasisTabulationMatrixArray
from
dune.perftool.sumfact.tabulation
import
BasisTabulationMatrixBase
,
BasisTabulationMatrixArray
...
@@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
...
@@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# Watch out for the documentation to see which key is used unter what circumstances
# Watch out for the documentation to see which key is used unter what circumstances
#
#
@property
@property
def
function_
key
(
self
):
def
function_
name
(
self
):
"""
Kernels sharing this key may use the same kernel implementation function
"""
"""
The name of the function that implements this kernel
"""
fastdg
=
(
)
name
=
"
sfimpl_{}
"
.
format
(
"
_
"
.
join
(
str
(
m
)
for
m
in
self
.
matrix_sequence
)
)
if
get_form_option
(
"
fastdg
"
):
if
get_form_option
(
"
fastdg
"
):
if
self
.
stage
==
1
:
if
self
.
stage
==
1
:
fastdg
=
(
self
.
input
.
element
,
self
.
input
.
element_index
)
fastdg
=
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
self
.
input
.
element
)
,
self
.
input
.
element_index
)
if
self
.
stage
==
3
:
if
self
.
stage
==
3
:
fastdg
=
(
self
.
output
.
test_element
,
self
.
output
.
test_element_index
,
self
.
output
.
trial_element
,
self
.
output
.
trial_element_index
)
fastdg
=
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
self
.
output
.
test_element
),
self
.
output
.
test_element_index
)
return
tuple
(
str
(
m
)
for
m
in
self
.
matrix_sequence
)
+
fastdg
if
self
.
within_inames
:
fastdg
=
"
{}x{}comp{}
"
.
format
(
fastdg
,
FEM_name_mangling
(
self
.
output
.
trial_element
),
self
.
output
.
trial_element_index
)
name
=
"
{}_fastdg{}_{}
"
.
format
(
name
,
self
.
stage
,
fastdg
)
return
name
@property
@property
def
parallel_key
(
self
):
def
parallel_key
(
self
):
...
@@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
...
@@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Watch out for the documentation to see which key is used unter what circumstances
# Watch out for the documentation to see which key is used unter what circumstances
#
#
@property
@property
def
function_
key
(
self
):
def
function_
name
(
self
):
fastdg
=
(
)
name
=
"
sfimpl_{}
"
.
format
(
"
_
"
.
join
(
str
(
m
)
for
m
in
self
.
matrix_sequence
)
)
if
get_form_option
(
"
fastdg
"
):
if
get_form_option
(
"
fastdg
"
):
if
self
.
stage
==
1
:
if
self
.
stage
==
1
:
fastdg
=
sum
((
(
i
.
element
,
i
.
element_index
)
for
i
in
remove_duplicates
(
self
.
input
.
inputs
))
,
())
fastdg
=
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
element
)
,
i
.
element_index
)
for
i
in
remove_duplicates
(
self
.
input
.
inputs
))
if
self
.
stage
==
3
:
if
self
.
stage
==
3
:
fastdg
=
sum
(((
o
.
test_element
,
o
.
test_element_index
,
o
.
trial_element
,
o
.
trial_element_index
)
for
o
in
remove_duplicates
(
self
.
output
.
outputs
)),
())
fastdg
=
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
test_element
),
i
.
test_element_index
)
for
i
in
remove_duplicates
(
self
.
output
.
outputs
))
return
tuple
(
str
(
m
)
for
m
in
self
.
matrix_sequence
)
+
fastdg
if
self
.
within_inames
:
fastdg
=
"
{}x{}
"
.
format
(
fastdg
,
"
_
"
.
join
(
"
{}comp{}
"
.
format
(
FEM_name_mangling
(
i
.
trial_element
),
i
.
trial_element_index
)
for
i
in
remove_duplicates
(
self
.
output
.
outputs
))
)
name
=
"
{}_fastdg{}_{}
"
.
format
(
name
,
self
.
stage
,
fastdg
)
return
name
@property
@property
def
cache_key
(
self
):
def
cache_key
(
self
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment