From e4ca09625f3fb13f2e95bb1032592502b936f941 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 2 Jan 2018 13:18:51 +0100 Subject: [PATCH] First iteration on supporting single precision! --- dune/perftool/common/muladd_workarounds.hh | 12 + dune/perftool/common/timer_tsc.hh | 2 - dune/perftool/common/vectorclass.hh | 562 ++++++++++++++++++ dune/perftool/sumfact/transposereg.hh | 22 + python/dune/perftool/blockstructured/basis.py | 7 +- python/dune/perftool/generation/loopy.py | 11 +- python/dune/perftool/loopy/target.py | 17 +- python/dune/perftool/loopy/temporary.py | 12 +- .../loopy/transformations/vectorize_quad.py | 10 +- .../loopy/transformations/vectorview.py | 3 +- python/dune/perftool/loopy/vcl.py | 11 +- python/dune/perftool/options.py | 1 + python/dune/perftool/pdelab/argument.py | 7 +- .../dune/perftool/pdelab/driver/__init__.py | 2 + .../pdelab/driver/gridfunctionspace.py | 12 +- .../perftool/pdelab/driver/interpolate.py | 4 +- python/dune/perftool/pdelab/geometry.py | 16 +- python/dune/perftool/pdelab/parameter.py | 20 +- python/dune/perftool/pdelab/quadrature.py | 4 +- python/dune/perftool/pdelab/tensors.py | 2 - python/dune/perftool/sumfact/geometry.py | 6 +- python/dune/perftool/sumfact/quadrature.py | 5 +- python/dune/perftool/sumfact/realization.py | 6 +- python/dune/perftool/sumfact/symbolic.py | 8 +- python/dune/perftool/sumfact/tabulation.py | 10 +- python/dune/perftool/sumfact/vectorization.py | 7 +- python/dune/perftool/ufl/visitor.py | 2 +- 27 files changed, 699 insertions(+), 82 deletions(-) diff --git a/dune/perftool/common/muladd_workarounds.hh b/dune/perftool/common/muladd_workarounds.hh index 1a3e60c5..f05dab7c 100644 --- a/dune/perftool/common/muladd_workarounds.hh +++ b/dune/perftool/common/muladd_workarounds.hh @@ -14,6 +14,13 @@ inline double mul_add(double op1, double& op2, double op3) return op1 * op2 + op3; } + +inline float mul_add(float op1, float& op2, float op3) +{ + return op1 * op2 + op3; +} + + #ifdef ENABLE_COUNTER oc::OpCounter<double> mul_add(oc::OpCounter<double> op1, oc::OpCounter<double>& op2, oc::OpCounter<double> op3) @@ -21,6 +28,11 @@ oc::OpCounter<double> mul_add(oc::OpCounter<double> op1, oc::OpCounter<double>& return op1 * op2 + op3; } +oc::OpCounter<float> mul_add(oc::OpCounter<float> op1, oc::OpCounter<float>& op2, oc::OpCounter<float> op3) +{ + return op1 * op2 + op3; +} + #endif #endif diff --git a/dune/perftool/common/timer_tsc.hh b/dune/perftool/common/timer_tsc.hh index 449badc2..0f9e1e04 100644 --- a/dune/perftool/common/timer_tsc.hh +++ b/dune/perftool/common/timer_tsc.hh @@ -8,8 +8,6 @@ #include <dune/perftool/common/tsc.hh> #include <dune/perftool/common/opcounter.hh> -#define HP_TIMER_OPCOUNTER oc::OpCounter<double> - #define HP_TIMER_DURATION(name) __hp_timer_##name##_duration #define HP_TIMER_STARTTIME(name) __hp_timer_##name##_start #define HP_TIMER_OPCOUNTERS_START(name) __hp_timer_##name##_counters_start diff --git a/dune/perftool/common/vectorclass.hh b/dune/perftool/common/vectorclass.hh index bac161b2..0f5b9b88 100644 --- a/dune/perftool/common/vectorclass.hh +++ b/dune/perftool/common/vectorclass.hh @@ -1011,6 +1011,568 @@ static inline Vec8d blend8d(Vec8d const & a, Vec8d const & b) { #endif // MAC_VECTOR_SIZE >= 512 + +struct Vec8f +{ + oc::OpCounter<float> _d[8]; + + using F = oc::OpCounter<float>; + + Vec8f() + {} + + Vec8f(F d) + { + BARRIER; + std::fill(_d,_d+8,d); + BARRIER; + } + + Vec8f(double d) + { + BARRIER; + std::fill(_d,_d+8,d); + BARRIER; + } + + Vec8f(F d0, F d1, F d2, F d3, F d4, F d5, F d6, F d7) + : _d{d0,d1,d2,d3,d4,d5,d6,d7} + { + BARRIER; + } + + Vec8f& load(const F* p) + { + BARRIER; + std::copy(p,p+8,_d); + BARRIER; + return *this; + } + + Vec8f& load_a(const F* p) + { + BARRIER; + std::copy(p,p+8,_d); + BARRIER; + return *this; + } + + void store(F* p) const + { + BARRIER; + std::copy(_d,_d+8,p); + BARRIER; + } + + void store_a(F* p) const + { + BARRIER; + std::copy(_d,_d+8,p); + BARRIER; + } + + Vec8f const& insert(uint32_t index, F value) + { + BARRIER; + _d[index] = value; + BARRIER; + return *this; + } + + F extract(uint32_t index) const + { + BARRIER; + return _d[index]; + } + + constexpr static int size() + { + return 8; + } + +}; + + +/***************************************************************************** +* +* Operators for Vec4d +* +*****************************************************************************/ + +// vector operator + : add element by element +static inline Vec8f operator + (Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return x + y; }); + BARRIER; + return r; +} + +// vector operator += : add +static inline Vec8f & operator += (Vec8f & a, Vec8f const & b) { + BARRIER; + std::transform(a._d,a._d+8,b._d,a._d,[](auto x, auto y){ return x + y; }); + BARRIER; + return a; +} + +// postfix operator ++ +static inline Vec8f operator ++ (Vec8f & a, int) { + BARRIER; + Vec8f a0 = a; + a = a + 1.0; + BARRIER; + return a0; +} + +// prefix operator ++ +static inline Vec8f & operator ++ (Vec8f & a) { + BARRIER; + a = a + 1.0; + BARRIER; + return a; +} + +// vector operator - : subtract element by element +static inline Vec8f operator - (Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return x - y; }); + BARRIER; + return r; +} + +// vector operator - : unary minus +// Change sign bit, even for 0, INF and NAN +static inline Vec8f operator - (Vec8f const & a) { + BARRIER; + Vec8f r(a); + for (size_t i = 0 ; i < 8 ; ++i) + r._d[i] = -a._d[i]; + BARRIER; + return r; +} + +// vector operator -= : subtract +static inline Vec8f & operator -= (Vec8f & a, Vec8f const & b) { + BARRIER; + std::transform(a._d,a._d+8,b._d,a._d,[](auto x, auto y){ return x - y; }); + BARRIER; + return a; +} + +// postfix operator -- +static inline Vec8f operator -- (Vec8f & a, int) { + BARRIER; + Vec8f a0 = a; + a = a - 1.0; + BARRIER; + return a0; +} + +// prefix operator -- +static inline Vec8f & operator -- (Vec8f & a) { + BARRIER; + a = a - 1.0; + BARRIER; + return a; +} + +// vector operator * : multiply element by element +static inline Vec8f operator * (Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return x * y; }); + BARRIER; + return r; +} + +// vector operator *= : multiply +static inline Vec8f & operator *= (Vec8f & a, Vec8f const & b) { + BARRIER; + std::transform(a._d,a._d+8,b._d,a._d,[](auto x, auto y){ return x * y; }); + BARRIER; + return a; +} + +// vector operator / : divide all elements by same integer +static inline Vec8f operator / (Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return x / y; }); + BARRIER; + return r; +} + +// vector operator /= : divide +static inline Vec8f & operator /= (Vec8f & a, Vec8f const & b) { + BARRIER; + std::transform(a._d,a._d+8,b._d,a._d,[](auto x, auto y){ return x / y; }); + BARRIER; + return a; +} + +// vector operator == : returns true for elements for which a == b +static inline _vcl::Vec8fb operator == (Vec8f const & a, Vec8f const & b) { + BARRIER; + _vcl::Vec8f a_, b_; + BARRIER; + a_.load(a._d[0].data()); + BARRIER; + b_.load(b._d[0].data()); + BARRIER; + Vec8f::F::comparisons(8); + BARRIER; + return a_ == b_; +} + +// vector operator != : returns true for elements for which a != b +static inline _vcl::Vec8fb operator != (Vec8f const & a, Vec8f const & b) { + BARRIER; + _vcl::Vec8f a_, b_; + BARRIER; + a_.load(a._d[0].data()); + BARRIER; + b_.load(b._d[0].data()); + BARRIER; + Vec8f::F::comparisons(8); + BARRIER; + return a_ != b_; +} + +// vector operator < : returns true for elements for which a < b +static inline _vcl::Vec8fb operator < (Vec8f const & a, Vec8f const & b) { + BARRIER; + _vcl::Vec8f a_, b_; + BARRIER; + a_.load(a._d[0].data()); + BARRIER; + b_.load(b._d[0].data()); + BARRIER; + Vec8f::F::comparisons(8); + BARRIER; + return a_ < b_; +} + +// vector operator <= : returns true for elements for which a <= b +static inline _vcl::Vec8fb operator <= (Vec8f const & a, Vec8f const & b) { + BARRIER; + _vcl::Vec8f a_, b_; + BARRIER; + a_.load(a._d[0].data()); + BARRIER; + b_.load(b._d[0].data()); + BARRIER; + Vec8f::F::comparisons(8); + BARRIER; + return a_ <= b_; +} + +// vector operator > : returns true for elements for which a > b +static inline _vcl::Vec8fb operator > (Vec8f const & a, Vec8f const & b) { + return b < a; +} + +// vector operator >= : returns true for elements for which a >= b +static inline _vcl::Vec8fb operator >= (Vec8f const & a, Vec8f const & b) { + return b <= a; +} + +// avoid logical operators for now, I don't think we need them +#if 0 + +// Bitwise logical operators + +// vector operator & : bitwise and +static inline Vec8f operator & (Vec8f const & a, Vec8f const & b) { + return _mm256_and_pd(a, b); +} + +// vector operator &= : bitwise and +static inline Vec8f & operator &= (Vec8f & a, Vec8f const & b) { + a = a & b; + return a; +} + +// vector operator & : bitwise and of Vec8f and Vec8fb +static inline Vec8f operator & (Vec8f const & a, Vec8fb const & b) { + return _mm256_and_pd(a, b); +} +static inline Vec8f operator & (Vec8fb const & a, Vec8f const & b) { + return _mm256_and_pd(a, b); +} + +// vector operator | : bitwise or +static inline Vec8f operator | (Vec8f const & a, Vec8f const & b) { + return _mm256_or_pd(a, b); +} + +// vector operator |= : bitwise or +static inline Vec8f & operator |= (Vec8f & a, Vec8f const & b) { + a = a | b; + return a; +} + +// vector operator ^ : bitwise xor +static inline Vec8f operator ^ (Vec8f const & a, Vec8f const & b) { + return _mm256_xor_pd(a, b); +} + +// vector operator ^= : bitwise xor +static inline Vec8f & operator ^= (Vec8f & a, Vec8f const & b) { + a = a ^ b; + return a; +} + +// vector operator ! : logical not. Returns Boolean vector +static inline Vec8fb operator ! (Vec8f const & a) { + return a == Vec8f(0.0); +} + +#endif + + +// General arithmetic functions, etc. + +// Horizontal add: Calculates the sum of all vector elements. +static inline Vec8f::F horizontal_add (Vec8f const & a) { + BARRIER; + return std::accumulate(a._d,a._d+8,Vec8f::F(0.0)); + BARRIER; +} + +// function max: a > b ? a : b +static inline Vec8f max(Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return max(x,y); }); + BARRIER; + return r; +} + +// function min: a < b ? a : b +static inline Vec8f min(Vec8f const & a, Vec8f const & b) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,b._d,r._d,[](auto x, auto y){ return min(x,y); }); + BARRIER; + return r; +} + +// function abs: absolute value +// Removes sign bit, even for -0.0f, -INF and -NAN +static inline Vec8f abs(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return abs(x); }); + BARRIER; + return r; +} + +// function sqrt: square root +static inline Vec8f sqrt(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return sqrt(x); }); + BARRIER; + return r; +} + +// function square: a * a +static inline Vec8f square(Vec8f const & a) { + return a * a; +} + + +// exponential function +static inline Vec8f exp(Vec8f const & a){ + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return exp(x); }); + BARRIER; + return r; +} + + +// ignore pow() for now +#if 0 + +// pow(Vec8f, int): +template <typename TT> static Vec8f pow(Vec8f const & a, TT n); + +// Raise floating point numbers to integer power n +template <> +inline Vec8f pow<int>(Vec8f const & x0, int n) { + return pow_template_i<Vec8f>(x0, n); +} + +// allow conversion from unsigned int +template <> +inline Vec8f pow<uint32_t>(Vec8f const & x0, uint32_t n) { + return pow_template_i<Vec8f>(x0, (int)n); +} + + +// Raise floating point numbers to integer power n, where n is a compile-time constant +template <int n> +static inline Vec8f pow_n(Vec8f const & a) { + if (n < 0) return Vec8f(1.0) / pow_n<-n>(a); + if (n == 0) return Vec8f(1.0); + if (n >= 256) return pow(a, n); + Vec8f x = a; // a^(2^i) + Vec8f y; // accumulator + const int lowest = n - (n & (n-1));// lowest set bit in n + if (n & 1) y = x; + if (n < 2) return y; + x = x*x; // x^2 + if (n & 2) { + if (lowest == 2) y = x; else y *= x; + } + if (n < 4) return y; + x = x*x; // x^4 + if (n & 4) { + if (lowest == 4) y = x; else y *= x; + } + if (n < 8) return y; + x = x*x; // x^8 + if (n & 8) { + if (lowest == 8) y = x; else y *= x; + } + if (n < 16) return y; + x = x*x; // x^16 + if (n & 16) { + if (lowest == 16) y = x; else y *= x; + } + if (n < 32) return y; + x = x*x; // x^32 + if (n & 32) { + if (lowest == 32) y = x; else y *= x; + } + if (n < 64) return y; + x = x*x; // x^64 + if (n & 64) { + if (lowest == 64) y = x; else y *= x; + } + if (n < 128) return y; + x = x*x; // x^128 + if (n & 128) { + if (lowest == 128) y = x; else y *= x; + } + return y; +} + +template <int n> +static inline Vec8f pow(Vec8f const & a, Const_int_t<n>) { + return pow_n<n>(a); +} + +#endif + +// function round: round to nearest integer (even). (result as double vector) +static inline Vec8f round(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return round(x); }); + BARRIER; + return r; +} + +// function truncate: round towards zero. (result as double vector) +static inline Vec8f truncate(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return trunc(x); }); + BARRIER; + return r; +} + +// function floor: round towards minus infinity. (result as double vector) +static inline Vec8f floor(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return floor(x); }); + BARRIER; + return r; +} + +// function ceil: round towards plus infinity. (result as double vector) +static inline Vec8f ceil(Vec8f const & a) { + BARRIER; + Vec8f r; + std::transform(a._d,a._d+8,r._d,[](auto x){ return ceil(x); }); + BARRIER; + return r; +} + +#if 0 +// function round_to_int: round to nearest integer (even). (result as integer vector) +static inline Vec8i round_to_int(Vec8f const & a) { + // Note: assume MXCSR control register is set to rounding + return _mm256_cvtpd_epi32(a); +} + +// function truncate_to_int: round towards zero. (result as integer vector) +static inline Vec8i truncate_to_int(Vec8f const & a) { + return _mm256_cvttpd_epi32(a); +} +#endif + + +// Fused multiply and add functions + +// Multiply and add +static inline Vec8f mul_add(Vec8f const & a, Vec8f const & b, Vec8f const & c) { + BARRIER; + Vec8f r; + for (size_t i = 0 ; i < 8 ; ++i) + r._d[i] = a._d[i] * b._d[i] + c._d[i]; + BARRIER; + return r; +} + + +// Multiply and subtract +static inline Vec8f mul_sub(Vec8f const & a, Vec8f const & b, Vec8f const & c) { + BARRIER; + Vec8f r; + for (size_t i = 0 ; i < 8 ; ++i) + r._d[i] = a._d[i] * b._d[i] - c._d[i]; + BARRIER; + return r; +} + +// Multiply and inverse subtract +static inline Vec8f nmul_add(Vec8f const & a, Vec8f const & b, Vec8f const & c) { + BARRIER; + Vec8f r; + for (size_t i = 0 ; i < 8 ; ++i) + r._d[i] = - a._d[i] * b._d[i] + c._d[i]; + BARRIER; + return r; +} + + +template <int i0, int i1, int i2, int i3, int i4, int i5, int i6, int i7> +static inline Vec8f blend8f(Vec8f const & a, Vec8f const & b) { + BARRIER; + _vcl::Vec8f a_,b_; + BARRIER; + a_.load(a._d[0].data()); + BARRIER; + b_.load(b._d[0].data()); + BARRIER; + _vcl::Vec8f r_ = _vcl::blend8f<i0,i1,i2,i3,i4,i5,i6,i7>(a_,b_); + BARRIER; + Vec8f::F::blends(1); + BARRIER; + Vec8f r; + BARRIER; + r_.store(r._d[0].data()); + BARRIER; + return r; +} + + #endif // ENABLE_COUNTER #endif // DUNE_PDELAB_COMMON_VECTORCLASS_HH diff --git a/dune/perftool/sumfact/transposereg.hh b/dune/perftool/sumfact/transposereg.hh index 9d924bf9..f73c6a2f 100644 --- a/dune/perftool/sumfact/transposereg.hh +++ b/dune/perftool/sumfact/transposereg.hh @@ -27,6 +27,28 @@ void transpose_reg(Vec4d& a0, Vec4d& a1) a1 = b1; } +void transpose_reg(Vec8f& a0, Vec8f& a1, Vec8f& a2, Vec8f& a3) +{ + Vec8f b0, b1, b2, b3; + b0 = blend8f<0,1,8,9,2,3,10,11>(a0, a1); + b1 = blend8f<4,5,12,13,6,7,14,15>(a0, a1); + b2 = blend8f<0,1,8,9,2,3,10,11>(a2, a3); + b3 = blend8f<4,5,12,13,6,7,14,15>(a2, a3); + a0 = blend8f<0,1,2,3,8,9,10,11>(b0, b2); + a1 = blend8f<4,5,6,7,12,13,14,15>(b0, b2); + a2 = blend8f<0,1,2,3,8,9,10,11>(b1, b3); + a3 = blend8f<4,5,6,7,12,13,14,15>(b1, b3); +} + +void transpose_reg (Vec8f& a0, Vec8f& a1) +{ + Vec8f b0, b1; + b0 = blend8f<0,1,2,3,8,9,10,11>(a0, a1); + b1 = blend8f<4,5,6,7,12,13,14,15>(a0, a1); + a0 = b0; + a1 = b1; +} + #endif #if MAX_VECTOR_SIZE >= 512 diff --git a/python/dune/perftool/blockstructured/basis.py b/python/dune/perftool/blockstructured/basis.py index 77b266d7..2748a722 100644 --- a/python/dune/perftool/blockstructured/basis.py +++ b/python/dune/perftool/blockstructured/basis.py @@ -8,13 +8,14 @@ from dune.perftool.generation import (backend, initializer_list, include_file,) from dune.perftool.tools import get_pymbolic_basename +from dune.perftool.loopy.target import type_floatingpoint from dune.perftool.pdelab.basis import (declare_cache_temporary, name_localbasis_cache, type_localbasis, FEM_name_mangling) from dune.perftool.pdelab.driver import (isPk, - isQk) -from dune.perftool.pdelab.driver.gridfunctionspace import basetype_range + isQk, + ) from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.pdelab.quadrature import pymbolic_quadrature_position_in_cell from dune.perftool.pdelab.spaces import type_leaf_gfs @@ -31,7 +32,7 @@ import pymbolic.primitives as prim @class_member(classtag="operator") def typedef_localbasis(element, name): df = "typename {}::Traits::GridView::ctype".format(type_leaf_gfs(element)) - r = basetype_range() + r = type_floatingpoint() dim = world_dimension() if isPk(element): if dim == 1: diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 96734233..a97df474 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -27,7 +27,8 @@ class DuneGlobalArg(lp.GlobalArg): def globalarg(name, shape=lp.auto, managed=True, **kw): if isinstance(shape, str): shape = (shape,) - dtype = kw.pop("dtype", np.float64) + from dune.perftool.loopy.target import dtype_floatingpoint + dtype = kw.pop("dtype", dtype_floatingpoint()) return DuneGlobalArg(name, dtype=dtype, shape=shape, managed=managed, **kw) @@ -37,7 +38,9 @@ def globalarg(name, shape=lp.auto, managed=True, **kw): def constantarg(name, shape=None, **kw): if isinstance(shape, str): shape = (shape,) - dtype = kw.pop("dtype", np.float64) + + from dune.perftool.loopy.target import dtype_floatingpoint + dtype = kw.pop("dtype", dtype_floatingpoint()) return lp.GlobalArg(name, dtype=dtype, shape=shape, **kw) @@ -45,7 +48,9 @@ def constantarg(name, shape=None, **kw): context_tags="kernel", cache_key_generator=lambda n, **kw: n) def valuearg(name, **kw): - return lp.ValueArg(name, **kw) + from dune.perftool.loopy.target import dtype_floatingpoint + dtype = kw.pop("dtype", dtype_floatingpoint()) + return lp.ValueArg(name, dtype=dtype, **kw) @generator_factory(item_tags=("domain",), context_tags="kernel") diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 96b7923a..5697d773 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -32,6 +32,16 @@ def _type_to_op_counter_type(name): return "oc::OpCounter<{}>".format(name) +def dtype_floatingpoint(): + bits = get_option("precision_bits") + if bits == 32: + return np.float32 + elif bits == 64: + return np.float64 + else: + raise NotImplementedError("{}bit floating point type".format(bits)) + + @pt.memoize def numpy_to_cpp_dtype(key): _registry = {'float32': 'float', @@ -47,6 +57,11 @@ def numpy_to_cpp_dtype(key): return _registry[key] +def type_floatingpoint(): + dtype = dtype_floatingpoint() + return numpy_to_cpp_dtype(NumpyType(dtype).dtype.name) + + class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): def map_subscript(self, expr, type_context): arr = self.find_array(expr) @@ -75,7 +90,7 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ret = ExpressionToCExpressionMapper.map_constant(self, expr, type_context) if get_option('opcounter'): if type_context == "f": - _type = _type_to_op_counter_type('double') + _type = _type_to_op_counter_type('float') ret = Literal("{}({})".format(_type, ret.s)) if type_context == "d": _type = _type_to_op_counter_type('double') diff --git a/python/dune/perftool/loopy/temporary.py b/python/dune/perftool/loopy/temporary.py index b4b5d5d1..d916f6b0 100644 --- a/python/dune/perftool/loopy/temporary.py +++ b/python/dune/perftool/loopy/temporary.py @@ -9,13 +9,13 @@ import numpy def _temporary_type(shape_impl, shape, first=True): - from dune.perftool.loopy.target import numpy_to_cpp_dtype + from dune.perftool.loopy.target import type_floatingpoint if len(shape_impl) == 0: - return numpy_to_cpp_dtype('float64') + return type_floatingpoint() if shape_impl[0] == 'arr': if not first or len(set(shape_impl)) != 1: raise PerftoolLoopyError("We do not allow mixing of C++ containers and plain C arrays, for reasons of mental sanity") - return numpy_to_cpp_dtype('float64') + return type_floatingpoint() if shape_impl[0] == 'vec': return "std::vector<{}>".format(_temporary_type(shape_impl[1:], shape[1:], first=False)) if shape_impl[0] == 'fv': @@ -23,7 +23,7 @@ def _temporary_type(shape_impl, shape, first=True): if shape_impl[0] == 'fm': # For now, no field matrices of weird stuff... assert len(shape) == 2 - _type = numpy_to_cpp_dtype('float64') + _type = type_floatingpoint() return "Dune::FieldMatrix<{}, {}, {}>".format(_type, shape[0], shape[1]) @@ -55,7 +55,9 @@ class DuneTemporaryVariable(TemporaryVariable): if shape_impl is not None: self.decl_method = default_declaration - kwargs.setdefault('dtype', numpy.float64) + + from dune.perftool.loopy.target import dtype_floatingpoint + kwargs.setdefault('dtype', dtype_floatingpoint()) self.custom_declaration = self.decl_method is not None diff --git a/python/dune/perftool/loopy/transformations/vectorize_quad.py b/python/dune/perftool/loopy/transformations/vectorize_quad.py index abc72940..40d2020b 100644 --- a/python/dune/perftool/loopy/transformations/vectorize_quad.py +++ b/python/dune/perftool/loopy/transformations/vectorize_quad.py @@ -5,6 +5,7 @@ from dune.perftool.generation import (function_mangler, include_file, loopy_class_member, ) +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, add_vector_view, @@ -51,7 +52,7 @@ def rotate_function_mangler(knl, func, arg_dtypes): # passing the vector registers as references and have them # changed. Loopy assumes this function to be read-only. include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile") - vcl = lp.types.NumpyType(get_vcl_type(np.float64, vector_width=func.horizontal * func.vertical)) + vcl = lp.types.NumpyType(get_vcl_type(dtype_floatingpoint(), vector_width=func.horizontal * func.vertical)) return lp.CallMangleInfo(func.name, (), (vcl,) * func.horizontal) @@ -103,7 +104,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): # Determine the vector lane width # TODO infer the numpy type here - vec_size = get_vcl_type_size(np.float64) + vec_size = get_vcl_type_size(dtype_floatingpoint()) vector_indices = VectorIndices(suffix) # @@ -155,7 +156,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): new_insns = [] size = product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) - vec_size = get_vcl_type_size(np.float64) + vec_size = get_vcl_type_size(dtype_floatingpoint()) size = ceildiv(size, vec_size) # Add an additional domain to the kernel @@ -232,9 +233,8 @@ def _vectorize_quadrature_loop(knl, inames, suffix): dim_tags="f,vec", potentially_vectorized=True, classtag="operator", - dtype=np.float64, ) - knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)]) + knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=dtype_floatingpoint())]) replacemap[expr] = prim.Subscript(prim.Variable(name), (vector_indices.get(1), prim.Variable(vec_iname)), ) diff --git a/python/dune/perftool/loopy/transformations/vectorview.py b/python/dune/perftool/loopy/transformations/vectorview.py index 7fc3ff8f..1160c8bd 100644 --- a/python/dune/perftool/loopy/transformations/vectorview.py +++ b/python/dune/perftool/loopy/transformations/vectorview.py @@ -4,6 +4,7 @@ One being an ordinary array with proper shape and so on, and one being a an array of SIMD vectors """ +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import get_vcl_type_size import loopy as lp @@ -74,7 +75,7 @@ def add_vector_view(knl, tmpname, pad_to=None, flatview=False): dim_tags=dim_tags, shape=shape, base_storage=bsname, - dtype=np.float64, + dtype=dtype_floatingpoint(), scope=lp.temp_var_scope.PRIVATE, ) diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 34ab80bc..c7c1ffd4 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -83,13 +83,16 @@ def vcl_cast_mangler(knl, func, arg_dtypes): @function_mangler def vcl_function_mangler(knl, func, arg_dtypes): if func == "mul_add": - vcl = lp.types.NumpyType(get_vcl_type(np.float64)) + dtype = arg_dtypes[0] + vcl = lp.types.NumpyType(get_vcl_type(dtype)) return lp.CallMangleInfo("mul_add", (vcl,), (vcl, vcl, vcl)) if func == "select": - vcl = lp.types.NumpyType(get_vcl_type(np.float64)) + dtype = arg_dtypes[0] + vcl = lp.types.NumpyType(get_vcl_type(dtype)) return lp.CallMangleInfo("select", (vcl,), (vcl, vcl, vcl)) if func == "horizontal_add": - vcl = lp.types.NumpyType(get_vcl_type(np.float64)) - return lp.CallMangleInfo("horizontal_add", (lp.types.NumpyType(np.float64),), (vcl,)) + dtype = arg_dtypes[0] + vcl = lp.types.NumpyType(get_vcl_type(dtype)) + return lp.CallMangleInfo("horizontal_add", (lp.types.NumpyType(dtype),), (vcl,)) diff --git a/python/dune/perftool/options.py b/python/dune/perftool/options.py index 9ae717c5..9fe285b9 100644 --- a/python/dune/perftool/options.py +++ b/python/dune/perftool/options.py @@ -64,6 +64,7 @@ class PerftoolOptionsArray(ImmutableRecord): architecture = PerftoolOption(default="haswell", helpstr="The architecture to optimize for. Possible values: haswell|knl") grid_offset = PerftoolOption(default=False, helpstr="Set to true if you want a yasp grid where the lower left corner is not in the origin.") simplify = PerftoolOption(default=True, helpstr="Whether to simplify expressions using sympy") + precision_bits = PerftoolOption(default=64, helpstr="The number of bits for the floating point type") assure_statement_ordering = PerftoolOption(default=False, helpstr="Whether special care should be taken for a good statement ordering in sumfact kernels, runs into a loopy scheduler performance bug, but is necessary for production.") # Arguments that are mainly to be set by logic depending on other options diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index 5c1eef4d..848d9d9a 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -14,6 +14,7 @@ from dune.perftool.generation import (domain, kernel_cached, backend ) +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.pdelab.index import name_index from dune.perftool.pdelab.basis import (evaluate_coefficient, evaluate_coefficient_gradient, @@ -48,7 +49,7 @@ class CoefficientAccess(FunctionIdentifier): @function_mangler def coefficient_mangler(target, func, dtypes): if isinstance(func, CoefficientAccess): - return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(str), NumpyType(numpy.int32))) + return CallMangleInfo(func.name, (NumpyType(dtype_floatingpoint()),), (NumpyType(str), NumpyType(numpy.int32))) class PDELabAccumulationFunction(FunctionIdentifier): @@ -74,7 +75,7 @@ def accumulation_mangler(target, func, dtypes): (), (NumpyType(str), NumpyType(numpy.int32), - NumpyType(numpy.float64), + NumpyType(dtype_floatingpoint()), ) ) if func.rank == 2: @@ -84,7 +85,7 @@ def accumulation_mangler(target, func, dtypes): NumpyType(numpy.int32), NumpyType(str), NumpyType(numpy.int32), - NumpyType(numpy.float64), + NumpyType(dtype_floatingpoint()), ) ) diff --git a/python/dune/perftool/pdelab/driver/__init__.py b/python/dune/perftool/pdelab/driver/__init__.py index 46ced0b3..5471ce75 100644 --- a/python/dune/perftool/pdelab/driver/__init__.py +++ b/python/dune/perftool/pdelab/driver/__init__.py @@ -260,6 +260,8 @@ def generate_driver(formdatas, data): # In case of operator conunting we only assemble the matrix and evaluate the residual # assemble_matrix_timer() from dune.perftool.pdelab.driver.timings import apply_jacobian_timer, evaluate_residual_timer + from dune.perftool.loopy.target import type_floatingpoint + pre_include("#define HP_TIMER_OPCOUNTER {}".format(type_floatingpoint())) evaluate_residual_timer() apply_jacobian_timer() elif is_stationary(): diff --git a/python/dune/perftool/pdelab/driver/gridfunctionspace.py b/python/dune/perftool/pdelab/driver/gridfunctionspace.py index d0dac27d..c1b439ea 100644 --- a/python/dune/perftool/pdelab/driver/gridfunctionspace.py +++ b/python/dune/perftool/pdelab/driver/gridfunctionspace.py @@ -16,6 +16,7 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling, name_initree, preprocess_leaf_data, ) +from dune.perftool.loopy.target import type_floatingpoint from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement @@ -31,18 +32,9 @@ def type_domainfield(): return "DF" -def basetype_range(): - if get_option('opcounter'): - from dune.perftool.pdelab.driver.timings import setup_timer - setup_timer() - return "oc::OpCounter<double>" - else: - return "double" - - @preamble def typedef_range(name): - return "using {} = {};".format(name, basetype_range()) + return "using {} = {};".format(name, type_floatingpoint()) def type_range(): diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py index f3cc2afb..ebbdbfd3 100644 --- a/python/dune/perftool/pdelab/driver/interpolate.py +++ b/python/dune/perftool/pdelab/driver/interpolate.py @@ -101,10 +101,10 @@ def define_boundary_lambda(name, boundary): if isinstance(boundary, (int, float)): return "auto {} = [&](const auto& x){{ return {}; }};".format(name, float(boundary)) elif isinstance(boundary, Expr): - from dune.perftool.loopy.target import numpy_to_cpp_dtype + from dune.perftool.loopy.target import type_floatingpoint from dune.perftool.pdelab.driver.visitor import ufl_to_code return "auto {} = [&](const auto& x){{ return ({}){}; }};".format(name, - numpy_to_cpp_dtype("float64"), + type_floatingpoint(), ufl_to_code(boundary)) else: raise NotImplementedError("What is this?") diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py index 03af2447..0bf2f914 100644 --- a/python/dune/perftool/pdelab/geometry.py +++ b/python/dune/perftool/pdelab/geometry.py @@ -16,6 +16,7 @@ from dune.perftool.generation import (backend, from dune.perftool.options import (get_option, option_switch, ) +from dune.perftool.loopy.target import dtype_floatingpoint, type_floatingpoint from dune.perftool.pdelab.quadrature import quadrature_preamble from dune.perftool.tools import get_pymbolic_basename from ufl.algorithms import MultiFunction @@ -270,7 +271,7 @@ def pymbolic_unit_outer_normal(): evaluate_unit_outer_normal(name) else: declare_normal(name, None, None) - globalarg(name, shape=(world_dimension(),), dtype=np.float64) + globalarg(name, shape=(world_dimension(),)) return prim.Variable(name) @@ -292,7 +293,8 @@ def pymbolic_unit_inner_normal(): def type_jacobian_inverse_transposed(restriction): if get_option('turn_off_diagonal_jacobian'): dim = world_dimension() - return "typename Dune::FieldMatrix<double,{},{}>".format(dim, dim) + ftype = type_floatingpoint() + return "typename Dune::FieldMatrix<{},{},{}>".format(ftype, dim, dim) else: geo = type_cell_geometry(restriction) return "typename {}::JacobianInverseTransposed".format(geo) @@ -330,7 +332,7 @@ def define_constant_jacobian_inverse_transposed(name): def name_constant_jacobian_inverse_transposed(restriction): name = "jit" dim = world_dimension() - globalarg(name, dtype=np.float64, shape=(dim, dim), managed=False) + globalarg(name, shape=(dim, dim), managed=False) define_constant_jacobian_inverse_transposed(name) return name @@ -383,7 +385,7 @@ def _define_constant_jacobian_determinant(name): @backend(interface="detjac", name="constant_transformation_matrix") def define_constant_jacobian_determinant(name): - valuearg(name, dtype=np.float64) + valuearg(name) _define_constant_jacobian_determinant(name) @@ -411,7 +413,7 @@ def define_facet_jacobian_determinant(name): geo = name_geometry() pos = name_localcenter() - valuearg(name, dtype=np.float64) + valuearg(name) return "auto {} = {}.integrationElement({});".format(name, geo, @@ -468,7 +470,7 @@ def to_global(local): @preamble def define_cell_volume(name, restriction): geo = name_cell_geometry(restriction) - valuearg(name, dtype=np.float64) + valuearg(name) return "auto {} = {}.volume();".format(name, geo) @@ -484,7 +486,7 @@ def pymbolic_cell_volume(restriction): @preamble def define_facet_area(name): geo = name_intersection_geometry() - valuearg(name, dtype=np.float64) + valuearg(name) return "auto {} = {}.volume();".format(name, geo) diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py index cb3f0154..63d47bf0 100644 --- a/python/dune/perftool/pdelab/parameter.py +++ b/python/dune/perftool/pdelab/parameter.py @@ -20,6 +20,7 @@ from dune.perftool.cgen.clazz import AccessModifier from dune.perftool.pdelab.localoperator import (class_type_from_cache, localoperator_basename, ) +from dune.perftool.loopy.target import type_floatingpoint from loopy.match import Writes @@ -49,7 +50,8 @@ def name_paramclass(): @class_member(classtag="parameterclass") def define_time(name): initializer_list(name, ["0.0"], classtag="parameterclass") - return "double {};".format(name) + ftype = type_floatingpoint() + return "{} {};".format(ftype, name) def name_time(): @@ -66,11 +68,12 @@ def define_set_time_method(): def define_set_time_method_operator(): time_name = name_time() param = name_paramclass() - # TODO double? + ftype = type_floatingpoint() + result = ["// Set time in instationary case", - "void setTime (double t_)", + "void setTime ({} t_)".format(ftype), "{", - " Dune::PDELab::InstationaryLocalOperatorDefaultMethods<double>::setTime(t_);", + " Dune::PDELab::InstationaryLocalOperatorDefaultMethods<{}>::setTime(t_);".format(ftype), " {}.setTime(t_);".format(param), "}" ] @@ -81,9 +84,10 @@ def define_set_time_method_operator(): @class_member(classtag="parameterclass") def define_set_time_method_parameterclass(): time_name = name_time() - # TODO double? + ftype = type_floatingpoint() + result = ["// Set time in instationary case", - "void setTime (double t_)", + "void setTime ({} t_)".format(ftype), "{", " {} = t_;".format(time_name), "}" @@ -156,7 +160,7 @@ def evaluate_cellwise_constant_parameter_function(name, restriction): from dune.perftool.generation.loopy import valuearg import numpy - valuearg(name, dtype=numpy.float64) + valuearg(name) return 'auto {} = {}.{}({}, {});'.format(name, name_paramclass(), @@ -179,7 +183,7 @@ def evaluate_intersectionwise_constant_parameter_function(name): from dune.perftool.generation.loopy import valuearg import numpy - valuearg(name, dtype=numpy.float64) + valuearg(name) return 'auto {} = {}.{}({}, {});'.format(name, name_paramclass(), diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index 4f28b9c5..02e4a428 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -108,7 +108,7 @@ def name_quadrature_points(): order = quadrature_order() name = "qp_dim{}_order{}".format(dim, order) shape = (name_quadrature_bound(), dim) - globalarg(name, shape=shape, dtype=numpy.float64, managed=False) + globalarg(name, shape=shape, managed=False) define_quadrature_points(name) fill_quadrature_points_cache(name) return name @@ -175,7 +175,7 @@ def name_quadrature_weights(): # Quadrature weighs is a globar argument for loopy shape = name_quadrature_bound() - globalarg(name, shape=(shape,), dtype=numpy.float64) + globalarg(name, shape=(shape,)) return name diff --git a/python/dune/perftool/pdelab/tensors.py b/python/dune/perftool/pdelab/tensors.py index 2996ac5e..5e962b99 100644 --- a/python/dune/perftool/pdelab/tensors.py +++ b/python/dune/perftool/pdelab/tensors.py @@ -35,7 +35,6 @@ def pymbolic_list_tensor(expr, visitor): name = get_counted_variable("listtensor") temporary_variable(name, shape=expr.ufl_shape, - dtype=np.float64, managed=True, ) define_list_tensor(name, expr, visitor) @@ -64,7 +63,6 @@ def pymbolic_identity(expr): temporary_variable(name, shape=expr.ufl_shape, shape_impl=('fm',), - dtype=np.float64, ) define_identity(name, expr) return prim.Variable(name) diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py index 77bb0038..e17b7aac 100644 --- a/python/dune/perftool/sumfact/geometry.py +++ b/python/dune/perftool/sumfact/geometry.py @@ -130,14 +130,14 @@ def define_mesh_width_eval(name): def name_lowerleft_corner(): name = "lowerleft_corner" - globalarg(name, dtype=np.float64, shape=(world_dimension(),)) + globalarg(name, shape=(world_dimension(),)) define_corner(name, True) return name def name_meshwidth(): name = "meshwidth" - globalarg(name, dtype=np.float64, shape=(world_dimension(),)) + globalarg(name, shape=(world_dimension(),)) define_mesh_width(name) return name @@ -226,7 +226,7 @@ def pymbolic_constant_facet_jacobian_determinant(): name = "fdetjac" define_constant_facet_jacobian_determinant(name) - globalarg(name, dtype=np.float64, shape=(world_dimension(),)) + globalarg(name, shape=(world_dimension(),)) return prim.Subscript(prim.Variable(name), (facedir,)) diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py index d42296f6..8209a41f 100644 --- a/python/dune/perftool/sumfact/quadrature.py +++ b/python/dune/perftool/sumfact/quadrature.py @@ -20,6 +20,7 @@ from dune.perftool.pdelab.geometry import (local_dimension, ) from dune.perftool.options import get_option from dune.perftool.sumfact.switch import get_facedir +from dune.perftool.loopy.target import dtype_floatingpoint from loopy import CallMangleInfo from loopy.symbolic import FunctionIdentifier @@ -65,7 +66,7 @@ class BaseWeight(FunctionIdentifier): @function_mangler def base_weight_function_mangler(target, func, dtypes): if isinstance(func, BaseWeight): - return CallMangleInfo(func.name, (NumpyType(np.float64),), ()) + return CallMangleInfo(func.name, (NumpyType(dtype_floatingpoint()),), ()) def pymbolic_base_weight(): @@ -152,7 +153,6 @@ def quadrature_weight(visitor): # Add a class member loopy_class_member(name, - dtype=np.float64, shape=local_qps_per_dir, classtag="operator", dim_tags=",".join(["f"] * dim), @@ -209,7 +209,6 @@ def pymbolic_quadrature_position(index, visitor): name = "quad_points_dim{}_num{}_dir{}".format(dim, local_qps_per_dir_str, index) loopy_class_member(name, - dtype=np.float64, shape=local_qps_per_dir, classtag="operator", dim_tags=",".join(["f"] * dim), diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 891bbe23..703e3e06 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -28,6 +28,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, ) from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import ExplicitVCLCast from ufl import MixedElement @@ -146,14 +147,13 @@ def _realize_sum_factorization_kernel(sf): inp_shape = permute_backward(inp_shape, perm) globalarg(direct_input_arg, - dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags, offset=_dof_offset(sf.input.element, sf.input.element_index), ) alias_data_array(direct_input_arg, direct_input) if matrix.vectorized: - input_summand = prim.Call(ExplicitVCLCast(np.float64, vector_width=sf.vector_width), + input_summand = prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=sf.vector_width), (prim.Subscript(prim.Variable(direct_input_arg), input_inames),)) else: @@ -223,7 +223,6 @@ def _realize_sum_factorization_kernel(sf): direct_output = "{}_access_comp{}".format(sf.accumvar, sf.test_element_index) if ft == 'residual' or ft == 'jacobian_apply': globalarg(direct_output, - dtype=np.float64, shape=output_shape, dim_tags=novec_ftags, offset=_dof_offset(sf.test_element, sf.test_element_index), @@ -244,7 +243,6 @@ def _realize_sum_factorization_kernel(sf): manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length)) dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides)) globalarg(direct_output, - dtype=np.float64, shape=other_shape + output_shape, offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index), dim_tags=dim_tags, diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index ecffe144..babb432b 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -326,8 +326,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): input = product(mat.basis_size for mat in self.matrix_sequence) matrices = sum(mat.memory_traffic for mat in set(matrix_sequence)) - # TODO: this is a hard coded sizeof(double) - return (input + matrices) * 8 + fbytes = get_option("precision_bits") / 8 + return (input + matrices) * fbytes @property def operations(self): @@ -635,8 +635,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) dofs = product(mat.basis_size for mat in self.matrix_sequence) matrices = sum(mat.memory_traffic for mat in set(matrix_sequence)) - # TODO: this is a hard coded sizeof(double) - return (dofs + matrices) * 8 + fbytes = get_option("precision_bits") / 8 + return (dofs + matrices) * fbytes @property def operations(self): diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index 0f052c75..836cfab3 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -21,6 +21,7 @@ from dune.perftool.generation import (class_member, valuearg ) from dune.perftool.loopy.buffer import get_buffer_temporary +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import ExplicitVCLCast from dune.perftool.pdelab.localoperator import (name_domain_field, lop_template_range_field, @@ -195,7 +196,7 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): # Check whether we can realize this by broadcasting the values of a simple tabulation if len(set(self.tabs)) == 1: theta = self.tabs[0].pymbolic(indices[:-1]) - return prim.Call(ExplicitVCLCast(np.float64, vector_width=len(self.tabs)), (theta,)) + return prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=len(self.tabs)), (theta,)) abbrevs = tuple("{}x{}".format("d" if t.derivative else "", "s{}".format(t.slice_index) if t.slice_size is not None else "") @@ -219,7 +220,6 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): member = loopy_class_member(name, classtag="operator", - dtype=np.float64, dim_tags="f,f,vec", shape=(self.rows, self.cols, self.width), potentially_vectorized=True, @@ -301,7 +301,6 @@ def basis_functions_per_direction(): def define_oned_quadrature_weights(name, bound): loopy_class_member(name, - dtype=np.float64, classtag="operator", shape=(bound,), ) @@ -315,7 +314,6 @@ def name_oned_quadrature_weights(bound): def define_oned_quadrature_points(name, bound): loopy_class_member(name, - dtype=np.float64, classtag="operator", shape=(bound,), ) @@ -400,7 +398,8 @@ class PolynomialLookup(FunctionIdentifier): @function_mangler def polynomial_lookup_mangler(target, func, dtypes): if isinstance(func, PolynomialLookup): - return CallMangleInfo(func.name, (NumpyType(np.float64),), (NumpyType(np.int32), NumpyType(np.float64))) + dtype = dtype_floatingpoint() + return CallMangleInfo(func.name, (dtype,), (NumpyType(np.int32), NumpyType(dtype))) def define_theta(name, tabmat, additional_indices=(), width=None): @@ -421,7 +420,6 @@ def define_theta(name, tabmat, additional_indices=(), width=None): shape = shape + (width,) loopy_class_member(name, - dtype=np.float64, shape=shape, classtag="operator", dim_tags=dim_tags, diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 2ba21f52..f3fb96b9 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -2,6 +2,7 @@ import logging +from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import get_vcl_type_size from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel from dune.perftool.generation import (backend, @@ -62,7 +63,7 @@ def costmodel(sf): # Penalize scalar sum factorization kernels scalar_penalty = 1 if isinstance(sf, SumfactKernel): - scalar_penalty = get_vcl_type_size(np.float64) + scalar_penalty = get_vcl_type_size(dtype_floatingpoint()) # Return total operations return sf.operations * position_penalty_factor(sf) * vertical_penalty * scalar_penalty @@ -71,7 +72,7 @@ def costmodel(sf): @backend(interface="vectorization_strategy", name="explicit") def explicit_costfunction(sf): # Read the explicitly set values for horizontal and vertical vectorization - width = get_vcl_type_size(np.float64) + width = get_vcl_type_size(dtype_floatingpoint()) horizontal = get_option("vectorization_horizontal") if horizontal is None: horizontal = width @@ -157,7 +158,7 @@ def decide_vectorization_strategy(): .format(len(active_sumfacts))) # Find the best vectorization strategy by using a costmodel - width = get_vcl_type_size(np.float64) + width = get_vcl_type_size(dtype_floatingpoint()) # # Optimize over all the possible quadrature point tuples diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 65dee899..fcf2ee77 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -164,7 +164,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): param = name_paramclass() time = name_time() name = param + "." + time - valuearg(name, dtype=np.float64) + valuearg(name) return Variable(name) # Check if this is a parameter function -- GitLab