""" The package that provides generating methods for all parts of the pdelab driver. Currently, these are hardcoded as strings. It would be possible to switch these to cgen expression. OTOH, there is not much to be gained there. NB: Previously this __init__.py was a module driver.py. As it was growing, we made it a package. Some content could and should be separated into new modules within this package! """ from dune.codegen.error import CodegenCodegenError from dune.codegen.generation import (generator_factory, get_global_context_value, global_context, include_file, cached, pre_include, preamble, ) from dune.codegen.options import (get_form_option, get_option, ) from ufl import TensorProductCell # # The following functions are not doing anything useful, but providing easy access # to quantities that are needed throughout the process of generating the driver! # def get_form_ident(): idents = [i.strip() for i in get_option("operators").split(",")] if len(idents) == 2: idents.remove("mass") assert(len(idents) == 1) return idents[0] def get_form(): data = get_global_context_value("data") form = get_form_option("form") if form is None: form = get_form_ident() return data.object_by_name[form] def get_dimension(): return get_form().ufl_cell().geometric_dimension() def get_cell(): return get_form().ufl_cell().cellname() def get_test_element(): return get_form().arguments()[0].ufl_element() def get_trial_element(): return get_form().coefficients()[0].ufl_element() def is_stationary(): return "mass" not in [i.strip() for i in get_option("operators").split(",")] def is_linear(form=None): '''Test if form is linear in trial function''' if form is None: form = get_form() from ufl import derivative from ufl.algorithms import expand_derivatives jacform = expand_derivatives(derivative(form, form.coefficients()[0])) for coeff in jacform.coefficients(): if 0 == coeff.count(): return False return True def isLagrange(fem): return fem._short_name is 'CG' def isSimplical(cell): if isinstance(cell, TensorProductCell): return False # Cells can be identified through strings *or* ufl objects if not isinstance(cell, str): cell = cell.cellname() return cell in ["vertex", "interval", "triangle", "tetrahedron"] def isQuadrilateral(cell): if isinstance(cell, TensorProductCell): return all(tuple(isSimplical(c) for c in cell.sub_cells())) # Cells can be identified through strings *or* ufl objects if not isinstance(cell, str): cell = cell.cellname() return cell in ["vertex", "interval", "quadrilateral", "hexahedron"] def isPk(fem): return isLagrange(fem) and isSimplical(fem.cell()) def isQk(fem): return isLagrange(fem) and isQuadrilateral(fem.cell()) def isDG(fem): return fem._short_name is 'DG' def FEM_name_mangling(fem): from ufl import MixedElement, VectorElement, FiniteElement, TensorElement, TensorProductElement if isinstance(fem, VectorElement): return FEM_name_mangling(fem.sub_elements()[0]) + "_" + str(fem.num_sub_elements()) if isinstance(fem, TensorElement): return FEM_name_mangling(fem.sub_elements()[0]) + "_" + "_".join(str(i) for i in fem.value_shape()) if isinstance(fem, MixedElement): name = "" for elem in fem.sub_elements(): if name is not "": name = name + "_" name = name + FEM_name_mangling(elem) return name if isinstance(fem, FiniteElement): return "{}{}".format(fem._short_name, fem.degree()) if isinstance(fem, TensorProductElement): assert(len(set(subel._short_name for subel in fem.sub_elements())) == 1) return "TP_{}".format("_".join(FEM_name_mangling(subel) for subel in fem.sub_elements())) raise NotImplementedError("FEM NAME MANGLING") def _flatten_list(l): if isinstance(l, (tuple, list)): for i in l: for ni in _flatten_list(i): yield ni else: yield l def _unroll_list_tensors(expr): from ufl.classes import ListTensor if isinstance(expr, ListTensor): for op in expr.ufl_operands: yield op else: yield expr def unroll_list_tensors(data): for expr in data: for e in _unroll_list_tensors(expr): yield e def preprocess_leaf_data(element, data, applyZeroDefault=True): data = get_global_context_value("data").object_by_name.get(data, None) if data is None and not applyZeroDefault: return None from ufl import MixedElement if isinstance(element, MixedElement): # data is None -> use 0 default if data is None: data = (0,) * element.value_size() # Flatten nested lists data = tuple(i for i in _flatten_list(data)) # Expand any list tensors data = tuple(i for i in unroll_list_tensors(data)) assert len(data) == element.value_size() return data else: # Do not return lists for non-MixedElement if not isinstance(data, (tuple, list)): return (data,) else: assert len(data) == 1 return data def name_inifile(): # TODO pass some other option here. return "argv[1]" @preamble(section="init") def parse_initree(varname): include_file("dune/common/parametertree.hh", filetag="driver") include_file("dune/common/parametertreeparser.hh", filetag="driver") filename = name_inifile() return ["Dune::ParameterTree initree;", "Dune::ParameterTreeParser::readINITree({}, {});".format(filename, varname)] def name_initree(): parse_initree("initree") # TODO we can get some other ini file here. return "initree" @preamble(section="init") def define_mpihelper(name): include_file("dune/common/parallel/mpihelper.hh", filetag="driver") if get_option("with_mpi"): return "Dune::MPIHelper& {} = Dune::MPIHelper::instance(argc, argv);".format(name) else: return "Dune::FakeMPIHelper& {} = Dune::FakeMPIHelper::instance(argc, argv);".format(name) def name_mpihelper(): name = "mpihelper" define_mpihelper(name) return name @preamble(section="grid") def check_parallel_execution(): from dune.codegen.pdelab.driver.gridfunctionspace import name_leafview gv = name_leafview() return ["if ({}.comm().size()==1){{".format(gv), ' std::cout << "This program should be run in parallel!" << std::endl;', " return 1;", "}"] def generate_driver(): # Guarantee that config.h is the very first include in the generated file include_file("config.h", filetag="driver") # Make sure that the MPI helper is instantiated name_mpihelper() # Add check to c++ file if this program should only be used in parallel mode if get_option("parallel"): check_parallel_execution() # Entrypoint for driver generation if get_option("opcounter") or get_option("performance_measuring"): if get_option("performance_measuring"): assert(not get_option("opcounter")) assert(isQuadrilateral(get_cell())) # In case of operator counting we only assemble the matrix and evaluate the residual # assemble_matrix_timer() from dune.codegen.pdelab.driver.timings import apply_jacobian_timer, evaluate_residual_timer from dune.codegen.loopy.target import type_floatingpoint pre_include("#define HP_TIMER_OPCOUNTER {}".format(type_floatingpoint()), filetag="driver") evaluate_residual_timer() if get_form_option("generate_jacobian_apply"): apply_jacobian_timer() elif is_stationary(): from dune.codegen.pdelab.driver.solve import dune_solve vec = dune_solve() from dune.codegen.pdelab.driver.vtk import vtkoutput vtkoutput() else: from dune.codegen.pdelab.driver.instationary import solve_instationary solve_instationary() from dune.codegen.pdelab.driver.error import compare_L2_squared if get_option("compare_l2errorsquared"): compare_L2_squared() # Make sure that timestream is declared before retrieving chache items if get_option("instrumentation_level") >= 1: from dune.codegen.pdelab.driver.timings import setup_timer setup_timer() from dune.codegen.pdelab.driver.error import return_statement return_statement() from dune.codegen.generation import retrieve_cache_items from cgen import FunctionDeclaration, FunctionBody, Block, Value, LineComment, Line, Generable driver_signature = FunctionDeclaration(Value('int', 'main'), [Value('int', 'argc'), Value('char**', 'argv')]) contents = [] # Assert that this program was called with ini file contents += ['if (argc != 2){', ' std::cerr << "This program needs to be called with an ini file" << std::endl;', ' return 1;', '}', ''] def add_section(tag, comment): tagcontents = [i for i in retrieve_cache_items("preamble and {}".format(tag), make_generable=True)] if tagcontents: contents.append(LineComment(comment)) contents.append(Line("\n")) contents.extend(tagcontents) contents.append(Line("\n")) add_section("init", "Initialize basic stuff...") if get_option("instrumentation_level") >= 1: init_contents = contents contents = [] add_section("grid", "Setup grid (view)...") add_section("fem", "Set up finite element maps...") add_section("gfs", "Set up grid function spaces...") add_section("constraints", "Set up constraints container...") add_section("gridoperator", "Set up grid grid operators...") add_section("vector", "Set up solution vectors...") add_section("timings", "Maybe take performance measurements...") add_section("solver", "Set up (non)linear solvers...") add_section("vtk", "Do visualization...") add_section("instat", "Set up instationary stuff...") add_section("printing", "Maybe print residuals and matrices to stdout...") add_section("error", "Maybe calculate errors for test results...") if get_option("instrumentation_level") >= 1: from dune.codegen.pdelab.driver.timings import timed_region contents = init_contents + timed_region('driver', contents) add_section("end", "Stuff that should happen at the end...") add_section("return_stmt", "Return statement...") contents.insert(0, "\n") driver_body = Block([c if isinstance(c, Generable) else Line(c + '\n') for c in contents]) # Wrap a try/catch block around the driver body from dune.codegen.cgen import CatchBlock, TryCatchBlock, Value, Block, Line catch_blocks = [CatchBlock(Value("Dune::Exception&", "e"), Block([Line("std::cerr << \"Dune reported error: \" << e << std::endl;\n"), Line("return 1;\n"), ]) ), CatchBlock(Value("std::exception&", "e"), Block([Line("std::cerr << \"Unknown exception thrown!\" << std::endl;\n"), Line("return 1;\n"), ]) ) ] driver_body = Block([TryCatchBlock(driver_body, catch_blocks)]) driver = FunctionBody(driver_signature, driver_body) filename = get_option("driver_file") from dune.codegen.file import generate_file generate_file(filename, "driver", [driver], headerguard=False) # Reset the caching data structure from dune.codegen.generation import delete_cache_items delete_cache_items()