[TritonMLIR] Disallow 0D tensor (#788)
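In user-facing terms: constructing a 0D tensor (e.g. tl.zeros([], ...)) is now rejected by the frontend, and reducing away a tensor's last remaining axis yields a plain scalar instead of a 0D tensor. A minimal sketch of the new behavior (hypothetical kernel, mirroring the tests added in this commit):

import triton
import triton.language as tl


@triton.jit
def sketch(ptr):
    # zero_0d = tl.zeros([], dtype=tl.float32)  # now raises TypeError: 0d block_type is forbidden
    v = tl.load(ptr + tl.arange(0, 32))  # 1D block of 32 floats
    s = tl.min(v, axis=0)                # reducing the only axis yields a scalar, not a 0D tensor
    tl.store(ptr, s)                     # scalar store through a scalar pointer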
@@ -108,8 +108,7 @@ def TT_LoadOp : TT_Op<"load",
                      AttrSizedOperandSegments,
                      MemoryEffects<[MemRead]>,
                      TypesMatchWith<"infer ptr type from result type",
-                                    "result", "ptr",
-                                    "getPointerTypeFromTensor($_self)">,
+                                    "result", "ptr", "getPointerTypeSameShape($_self)">,
                      TypesMatchWith<"infer mask type from result type or none",
                                     "result", "mask", "getI1SameShape($_self)",
                                     "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
@@ -118,7 +117,7 @@ def TT_LoadOp : TT_Op<"load",
                                     "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "load";

-  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);

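Note: relaxing $ptr from TT_PtrTensor to TT_PtrLike (and $mask from I1Tensor to TT_BoolLike) lets tt.load accept a bare scalar pointer and scalar mask in addition to tensors of them. A hedged Python-level sketch, assuming a *fp32 kernel argument:

import triton
import triton.language as tl


@triton.jit
def load_scalar(ptr):
    x = tl.load(ptr)        # scalar load; lowers to: tt.load %ptr ... : f32
    tl.store(ptr, x + 1.0)  # scalar store back through the same pointer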
@@ -147,13 +146,13 @@ def TT_StoreOp : TT_Op<"store",
                      MemoryEffects<[MemWrite]>,
                      TypesMatchWith<"infer ptr type from value type",
                                     "value", "ptr",
-                                    "getPointerTypeFromTensor($_self)">,
+                                    "getPointerTypeSameShape($_self)">,
                      TypesMatchWith<"infer mask type from value type",
                                     "value", "mask", "getI1SameShape($_self)",
                                     "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";

-  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
+  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);

   let builders = [
       OpBuilder<(ins "Value":$ptr, "Value":$value)>,
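The same relaxation applies to tt.store: a scalar value can be written through a scalar pointer, optionally under a scalar i1 mask (exercised by the load_store_ops_scalar test added below). A minimal sketch:

import triton
import triton.language as tl


@triton.jit
def store_reduced(ptr):
    v = tl.load(ptr + tl.arange(0, 64))
    tl.store(ptr, tl.min(v, axis=0))  # scalar result stored via a scalar pointer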
@@ -318,7 +317,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,

   let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);

-  let results = (outs TT_Tensor:$result);
+  let results = (outs TT_Type:$result);

   let builders = [
       OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
@@ -26,14 +26,15 @@ static Type getI32SameShape(Type type) {
   return i32Type;
 }

-static Type getPointerTypeFromTensor(Type type) {
+static Type getPointerTypeSameShape(Type type) {
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     Type elementType = tensorType.getElementType();
     auto shape = tensorType.getShape();
     PointerType ptrType = PointerType::get(elementType, 1);
     return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
+  } else {
+    return PointerType::get(type, 1);
   }
-  return Type();
 }

 // Parser & printer for assembly forms
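The renamed helper now handles both cases: a ranked tensor maps to a same-shape tensor of pointers, while any other (scalar) type maps to a plain pointer instead of a null Type. A toy Python analogue of the dispatch (hypothetical stand-in classes, not Triton API):

from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class PtrTy:        # stand-in for triton::PointerType, address space 1
    pointee: Any


@dataclass
class TensorTy:     # stand-in for mlir::RankedTensorType
    shape: Tuple[int, ...]
    elem: Any
    encoding: Optional[Any] = None


def pointer_type_same_shape(ty):
    if isinstance(ty, TensorTy):
        # tensor<N x T> -> tensor<N x !tt.ptr<T>>, keeping shape and encoding
        return TensorTy(ty.shape, PtrTy(ty.elem), ty.encoding)
    # scalar T -> !tt.ptr<T>
    return PtrTy(ty)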
@@ -49,7 +50,7 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
   result.addTypes(resultTypes);

   SmallVector<Type> operandTypes;
-  operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
+  operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
   int hasMask = 0, hasOther = 0;
   if (allOperands.size() >= 2) {
     operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
@@ -92,8 +93,8 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
     return failure();

   SmallVector<Type> operandTypes;
-  operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
+  operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
   operandTypes.push_back(valueType); // value
   if (allOperands.size() >= 3)
     operandTypes.push_back(getI1SameShape(valueType)); // mask

@@ -194,26 +195,33 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   // infer shape
   Value arg = operands[0];
   auto argTy = arg.getType().cast<RankedTensorType>();
+  auto argEltTy = argTy.getElementType();
   auto retShape = argTy.getShape().vec();
   int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   retShape.erase(retShape.begin() + axis);
-  // infer encoding
-  Attribute argEncoding = argTy.getEncoding();
-  Attribute retEncoding;
-  if (argEncoding) {
-    Dialect &dialect = argEncoding.getDialect();
-    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
-    if (inferLayoutInterface
-            ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
-            .failed()) {
-      llvm::report_fatal_error("failed to infer layout for ReduceOp");
-      return mlir::failure();
+  if (retShape.empty()) {
+    // 0d-tensor -> scalar
+    inferredReturnTypes.push_back(argEltTy);
+  } else {
+    // nd-tensor where n >= 1
+    // infer encoding
+    Attribute argEncoding = argTy.getEncoding();
+    Attribute retEncoding;
+    if (argEncoding) {
+      Dialect &dialect = argEncoding.getDialect();
+      auto inferLayoutInterface =
+          dyn_cast<DialectInferLayoutInterface>(&dialect);
+      if (inferLayoutInterface
+              ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
+              .failed()) {
+        llvm::report_fatal_error("failed to infer layout for ReduceOp");
+        return mlir::failure();
+      }
     }
+    // create type
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(retShape, argEltTy, retEncoding));
   }
-  // create type
-  auto argEltTy = argTy.getElementType();
-  inferredReturnTypes.push_back(
-      RankedTensorType::get(retShape, argEltTy, retEncoding));
   return mlir::success();
 }

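With the retShape.empty() branch, reducing away the last axis infers a scalar result type rather than a 0D tensor: tt.reduce on tensor<4xf32> with axis = 0 now yields f32 (see reduce_ops_infer below). At the Python level the effect looks like:

import triton
import triton.language as tl


@triton.jit
def reduce_to_scalar(ptr):
    v = tl.load(ptr + tl.arange(0, 4))
    s = tl.max(v, axis=0)  # tt.reduce ... : tensor<4xf32> -> f32
    tl.store(ptr, s)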
@@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) {
                 operand.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = inputTensorType.getShape();
             shape.erase(shape.begin() + axis);
-            auto resType = mlir::RankedTensorType::get(
-                shape, inputTensorType.getElementType());
+            mlir::Type resType = inputTensorType.getElementType();
+            if (!shape.empty()) {
+              resType = mlir::RankedTensorType::get(
+                  shape, inputTensorType.getElementType());
+            }
             return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                        operand, axis);
           })
@@ -1,57 +0,0 @@
-import triton
-import triton.language as tl
-
-
-# TODO: function with no arguments don't work
-@triton.jit
-def cast_check(X):
-    zero_0d = tl.zeros([], dtype=tl.float32)
-    zero_1d = tl.zeros([2], dtype=tl.float32)
-    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
-    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
-
-    # scalar + scalar -> scalar
-    a0 = 0.0 + 0.0
-    # scalar + 0D -> 0D
-    a1 = 0.0 + zero_0d
-    a2 = zero_0d + 0.0
-    # scalar + 1D -> 1D
-    a3 = 0.0 + zero_1d
-    a4 = zero_1d + 0.0
-    # scalar + 2D -> 2D
-    a5 = 0.0 + zero_2d_22
-    a6 = zero_2d_22 + 0.0
-
-    # 0D + 0D -> 0D
-    b1 = zero_0d + zero_0d
-    # 0D + 1D -> 1D
-    b2 = zero_0d + zero_1d
-    b3 = zero_1d + zero_0d
-    # 0D + 2D -> 2D
-    b4 = zero_0d + zero_2d_22
-    b5 = zero_2d_22 + zero_0d
-
-    # 1D + 1D -> 1D
-    c1 = zero_1d + zero_1d
-    # 1D + 2D -> 2D
-    c2 = zero_1d + zero_2d_21
-    c3 = zero_1d + zero_2d_22
-    c4 = zero_2d_21 + zero_1d
-    c5 = zero_2d_22 + zero_1d
-
-    # 2D + 2D -> 2D
-    d1 = zero_2d_21 + zero_2d_21
-    d2 = zero_2d_22 + zero_2d_22
-    d3 = zero_2d_21 + zero_2d_22
-    d4 = zero_2d_22 + zero_2d_21
-
-    return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
-
-
-def test_cast_check():
-    kernel = triton.compiler._compile(cast_check,
-                                      signature="*fp32",
-                                      device=0,
-                                      output="ttgir")
-    assert (kernel)
-    # TODO: Check types of the results
python/tests/test_type.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import triton
+import triton.language as tl
+
+
+# TODO: functions with no arguments don't work
+@triton.jit
+def binop_type_check(X):
+    # 0d-tensor is not allowed.
+    # zero_0d = tl.zeros([], dtype=tl.float32)
+    zero_1d = tl.zeros([2], dtype=tl.float32)
+    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
+    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
+
+    # scalar + scalar -> scalar
+    a0 = 0.0 + 0.0
+    # # scalar + 0D -> 0D
+    # a1 = 0.0 + zero_0d
+    # a2 = zero_0d + 0.0
+    # scalar + 1D -> 1D
+    a3 = 0.0 + zero_1d
+    a4 = zero_1d + 0.0
+    # scalar + 2D -> 2D
+    a5 = 0.0 + zero_2d_22
+    a6 = zero_2d_22 + 0.0
+
+    # # 0D + 0D -> 0D
+    # b1 = zero_0d + zero_0d
+    # # 0D + 1D -> 1D
+    # b2 = zero_0d + zero_1d
+    # b3 = zero_1d + zero_0d
+    # # 0D + 2D -> 2D
+    # b4 = zero_0d + zero_2d_22
+    # b5 = zero_2d_22 + zero_0d
+
+    # 1D + 1D -> 1D
+    c1 = zero_1d + zero_1d
+    # 1D + 2D -> 2D
+    c2 = zero_1d + zero_2d_21
+    c3 = zero_1d + zero_2d_22
+    c4 = zero_2d_21 + zero_1d
+    c5 = zero_2d_22 + zero_1d
+
+    # 2D + 2D -> 2D
+    d1 = zero_2d_21 + zero_2d_21
+    d2 = zero_2d_22 + zero_2d_22
+    d3 = zero_2d_21 + zero_2d_22
+    d4 = zero_2d_22 + zero_2d_21
+
+    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
+    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
+
+
+def test_binop_type_check():
+    kernel = triton.compiler._compile(binop_type_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
+    assert (kernel)
+    # TODO: Check types of the results
+
+
+@triton.jit
+def reduce_type_check(ptr):
+    v_32 = tl.load(ptr + tl.arange(0, 32))
+    v_scalar = tl.min(v_32, axis=0)
+    tl.store(ptr, v_scalar)
+    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
+    v_64 = tl.max(v_64x128, axis=1)
+    tl.store(ptr + tl.arange(0, 64), v_64)
+    v_128 = tl.max(v_64x128, axis=0)
+    tl.store(ptr + tl.arange(0, 128), v_128)
+
+
+def test_reduce_type_check():
+    kernel = triton.compiler._compile(reduce_type_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
+    assert (kernel)
+    # TODO: Check types of the results
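These tests only compile the kernels down to TTGIR and assert that compilation succeeds; assuming a CUDA-capable environment, they can be run with, e.g., pytest python/tests/test_type.py -v.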
@@ -241,7 +241,9 @@ class block_type(dtype):
         # while tensor's shape is a list of constexpr.

         # shape can be empty ([]) when an input is a 0D tensor.
-        if shape and isinstance(shape[0], constexpr):
+        if not shape:
+            raise TypeError('0d block_type is forbidden')
+        if isinstance(shape[0], constexpr):
             shape = [s.value for s in shape]

         self.shape = shape
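Constructing an empty-shape block type now fails fast at the frontend. A quick sketch, assuming block_type and float32 are importable from triton.language:

import triton.language as tl

ok = tl.block_type(tl.float32, [2, 2])  # fine: 2D shape
try:
    tl.block_type(tl.float32, [])       # empty shape = 0D
except TypeError as e:
    print(e)                            # -> 0d block_type is forbidden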
@@ -991,7 +991,11 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     for i, s in enumerate(shape):
         if i != axis:
             ret_shape.append(s)
-    res_ty = tl.block_type(scalar_ty, ret_shape)
+    if ret_shape:
+        res_ty = tl.block_type(scalar_ty, ret_shape)
+    else:
+        # 0d-tensor -> scalar
+        res_ty = scalar_ty

     if scalar_ty.is_floating():
         return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
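The frontend thus mirrors the IR-level inference: an empty ret_shape means the reduction returns the scalar dtype directly, and no block_type is ever built with an empty shape. A standalone sketch of the rule (plain Python, not Triton API):

def reduced_type(shape, axis, scalar_ty):
    ret_shape = [s for i, s in enumerate(shape) if i != axis]
    # 0d-tensor -> scalar
    return ("block", scalar_ty, ret_shape) if ret_shape else scalar_ty


assert reduced_type([64, 128], 1, "f32") == ("block", "f32", [64])
assert reduced_type([32], 0, "f32") == "f32"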
@@ -53,3 +53,80 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
   %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
   return
 }
+
+func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
+  // Test if Load/Store ops can handle scalar values
+  %other = arith.constant 0.0e+0 : f32
+
+  // load scalar
+  // CHECK: %[[L0:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+  %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
+
+  // store scalar
+  // CHECK: tt.store %{{.*}}, %[[L0]] : f32
+  tt.store %ptr, %a : f32
+  // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : f32
+  tt.store %ptr, %b, %mask : f32
+  // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : f32
+  tt.store %ptr, %c, %mask : f32
+  return
+}
+
+func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
+  // Test if reduce ops infer types correctly
+
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
+  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
+  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
+  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
+  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
+  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
+  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
+
+  // Avoid optimizations for c, e, and g
+  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
+  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
+  tt.store %ptr1x2, %c : tensor<1x2xf32>
+  tt.store %ptr1, %e : tensor<1xf32>
+  tt.store %ptr, %g : f32
+  return
+}
+
+func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
+  // Test if dot ops infer types correctly
+  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
+  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
+  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
+  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>
+
+  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
+  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>
+
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
+  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
+  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
+
+  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
+  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
+  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
+  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
+  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
+  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
+  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
+  return
+}