Skip to content
Snippets Groups Projects
Commit d7c558ab authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add VCL Vector types for floating point dtypes to DuneTarget

parent 0e658790
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ from dune.perftool.generation import post_include ...@@ -2,6 +2,7 @@ from dune.perftool.generation import post_include
from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.temporary import DuneTemporaryVariable
from dune.perftool.pdelab.spaces import LFSLocalIndex from dune.perftool.pdelab.spaces import LFSLocalIndex
from dune.perftool.loopy.types import VCLTypeRegistry
from loopy.target import (TargetBase, from loopy.target import (TargetBase,
ASTBuilderBase, ASTBuilderBase,
...@@ -10,6 +11,7 @@ from loopy.target import (TargetBase, ...@@ -10,6 +11,7 @@ from loopy.target import (TargetBase,
from loopy.target.c import CASTBuilder from loopy.target.c import CASTBuilder
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper, CExpressionToCodeMapper from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper, CExpressionToCodeMapper
from loopy.symbolic import FunctionIdentifier from loopy.symbolic import FunctionIdentifier
from loopy.types import NumpyType
from pymbolic.primitives import Call, Subscript, Variable from pymbolic.primitives import Call, Subscript, Variable
...@@ -72,8 +74,14 @@ class DuneTarget(TargetBase): ...@@ -72,8 +74,14 @@ class DuneTarget(TargetBase):
return DuneASTBuilder(self) return DuneASTBuilder(self)
def dtype_to_typename(self, dtype): def dtype_to_typename(self, dtype):
# For now, we do this the simplest possible way if dtype.dtype.kind == "V":
return _registry[dtype.dtype.name] return VCLTypeRegistry.names[dtype.dtype]
else:
return _registry[dtype.dtype.name]
def is_vector_dtype(self, dtype): def is_vector_dtype(self, dtype):
return False return False
def vector_dtype(self, base, count):
return NumpyType(VCLTypeRegistry.types[base.numpy_dtype, count],
target=self)
"""
Our extensions to the loopy type system
"""
import numpy as np
class VCLTypeRegistry:
pass
def _populate_vcl_type_registry():
VCLTypeRegistry.types = {}
VCLTypeRegistry.names = {}
# The base types that we are working with!
for base_name, base_type, abbrev in [('float', np.float32, 'f'),
('double', np.float64, 'd'),
]:
# The vector width in bits we are considering!
for vector_bits in [128, 256, 512]:
# Calculate the vector lane width
count = vector_bits // (np.dtype(base_type).itemsize * 8)
# Define the name of this vector type
name = "Vec{}{}".format(count, abbrev)
# Construct the numpy dtype!
fieldnames = tuple("x" + str(i) for i in range(count))
dtype = np.dtype(dict(names=fieldnames,
formats=[base_type] * count,
)
)
VCLTypeRegistry.types[np.dtype(base_type), count] = dtype
VCLTypeRegistry.names[dtype] = name
_populate_vcl_type_registry()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment