Skip to content
Snippets Groups Projects
Commit 41beb839 authored by René Heß's avatar René Heß
Browse files

Generate driver for matrix assembly operator counting

parent 6e64a069
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,7 @@ def get_form_compiler_arguments():
parser.add_argument("--constant-transformation-matrix", action="store_true", help="set option if the jacobian of the transformation is constant on a cell")
parser.add_argument("--ini-file", type=str, help="An inifile to use. A generated driver will be hard-coded to it, a [formcompiler] section will be used as default values to form compiler arguments (use snake case)")
parser.add_argument("--timer", action="store_true", help="measure times")
parser.add_argument("--opcounter", action="store_false", help="Count operations. Should only be used with yaspgrid. Timer should be set.")
parser.add_argument("--opcounter", action="store_true", default=False, help="Count operations. Should only be used with yaspgrid. Timer should be set.")
parser.add_argument("--project-basedir", type=str, help="The base (build) directory of the dune-perftool project")
# TODO at some point this help description should be updated
parser.add_argument("--sumfact", action="store_true", help="Use sumfactorization")
......
......@@ -560,6 +560,12 @@ def define_dofestimate(name):
geo_factor = "6"
gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
ini = name_initree()
# Assure that gfs in initialized
formdata = _driver_data['formdata']
x = name_vector(formdata)
define_vector(x, formdata)
return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs),
"int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]
......@@ -1107,21 +1113,15 @@ def dune_solve():
solve = "{}.apply();".format(snp)
if get_option('timer'):
# Necessary includes and defines
from dune.perftool.generation import pre_include
setup_timer()
from dune.perftool.generation import post_include
# TODO check that we are using YASP?
if get_option('opcounter'):
pre_include("#define ENABLE_COUNTER", filetag="driver")
pre_include("#define ENABLE_HP_TIMERS", filetag="driver")
include_file("dune/perftool/common/timer.hh", filetag="driver")
post_include("HP_DECLARE_TIMER(total);", filetag="driver")
# Print times after solving
from dune.perftool.generation import get_global_context_value
formdatas = get_global_context_value("formdatas")
print_times = []
define_exec()
for formdata in formdatas:
lop_name = name_localoperator(formdata)
timestream = name_timing_stream()
......@@ -1129,7 +1129,6 @@ def dune_solve():
solve = ["HP_TIMER_START(total);",
"{}".format(solve),
"HP_TIMER_STOP(total);",
"char* exec = argv[0];",
"DUMP_TIMER(total, {}, true);".format(timestream),
]
solve.extend(print_times)
......@@ -1331,6 +1330,57 @@ def print_residual():
"}"]
@cached
def setup_timer():
assert(get_option('timer'))
# Necessary includes and defines
from dune.perftool.generation import pre_include
# TODO check that we are using YASP?
if get_option('opcounter'):
pre_include("#define ENABLE_COUNTER", filetag="driver")
pre_include("#define ENABLE_HP_TIMERS", filetag="driver")
include_file("dune/perftool/common/timer.hh", filetag="driver")
@preamble
def define_exec():
return "char* exec = argv[0];"
@preamble
def assemble_matrix_timer():
formdata = _driver_data['formdata']
t_go = type_gridoperator(formdata)
n_go = name_gridoperator(formdata)
v = name_vector(formdata)
t_v = type_vector(formdata)
# Write back times
setup_timer()
from dune.perftool.generation import post_include
post_include("HP_DECLARE_TIMER(matrix_assembly);", filetag="driver")
timestream = name_timing_stream()
define_exec()
print_times = []
from dune.perftool.generation import get_global_context_value
formdatas = get_global_context_value("formdatas")
for formdata in formdatas:
lop_name = name_localoperator(formdata)
print_times.append("{}.dump_timers({}, argv[0], true);".format(lop_name, timestream))
assembly = ["using M = typename {}::Traits::Jacobian;".format(t_go),
"M m({});".format(n_go),
"HP_TIMER_START(matrix_assembly);",
"{}.jacobian({},m);".format(n_go, v),
"HP_TIMER_STOP(matrix_assembly);",
"DUMP_TIMER(matrix_assembly, {}, true);".format(timestream)]
assembly.extend(print_times)
return assembly
@preamble
def print_matrix():
formdata = _driver_data['formdata']
......@@ -1539,9 +1589,12 @@ def generate_driver(formdatas, data):
# The driver module uses a global dictionary for storing necessary data
set_driver_data(formdatas, data)
# The vtkoutput is the generating method that triggers all others.
# Alternatively, one could use the `solve` method.
if is_stationary():
# Entrypoint for driver generation
if get_option("opcounter"):
# In case of operator conunting we only assemble the matrix and evaluate the residual
assemble_matrix_timer()
elif is_stationary():
# We could also use solve if we are not interested in visualization
vtkoutput()
else:
solve_instationary()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment