[TritonIR] make other optional and remove isOtherUnspecified (#67)
[Triton] make other optional and remove isOtherUnspecified
This commit is contained in:
@@ -73,14 +73,14 @@ def TT_LoadOp : TT_Op<"load",
|
|||||||
TypesMatchWith<"infer mask type from result type",
|
TypesMatchWith<"infer mask type from result type",
|
||||||
"result", "mask",
|
"result", "mask",
|
||||||
"getI1SameShape($_self)">,
|
"getI1SameShape($_self)">,
|
||||||
TypesMatchWith<"infer other type from result type",
|
TypesMatchWith<"infer other type from result type or none",
|
||||||
"result", "other",
|
"result", "other",
|
||||||
"$_self">]> {
|
"$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
|
||||||
let summary = "load";
|
let summary = "load";
|
||||||
|
|
||||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
|
||||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||||
BoolAttr:$isVolatile, BoolAttr:$isOtherUnspecified);
|
BoolAttr:$isVolatile);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
@@ -89,11 +89,10 @@ def TT_LoadOp : TT_Op<"load",
|
|||||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile,
|
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||||
"bool":$isOtherUnspecified)>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
|
let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_StoreOp : TT_Op<"store",
|
def TT_StoreOp : TT_Op<"store",
|
||||||
|
@@ -48,10 +48,14 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
|||||||
"getPointeeType($_self)">]> {
|
"getPointeeType($_self)">]> {
|
||||||
let summary = "copy async";
|
let summary = "copy async";
|
||||||
|
|
||||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
|
||||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||||
BoolAttr:$isVolatile,
|
BoolAttr:$isVolatile);
|
||||||
BoolAttr:$isOtherUnspecified);
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||||
|
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||||
|
];
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
|
@@ -229,8 +229,7 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
|||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||||
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile(),
|
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
|
||||||
adaptor.isOtherUnspecified());
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -76,14 +76,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|||||||
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
|
||||||
DenseIntElementsAttr::get(
|
DenseIntElementsAttr::get(
|
||||||
RankedTensorType::get(shape, builder.getI1Type()), true));
|
RankedTensorType::get(shape, builder.getI1Type()), true));
|
||||||
// other
|
|
||||||
Type resultType = RankedTensorType::get(shape, elementType);
|
Type resultType = RankedTensorType::get(shape, elementType);
|
||||||
::mlir::Value other = builder.create<arith::ConstantOp>(
|
|
||||||
ptr.getLoc(), resultType,
|
|
||||||
DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
|
|
||||||
state.addOperands(ptr);
|
state.addOperands(ptr);
|
||||||
state.addOperands(mask);
|
state.addOperands(mask);
|
||||||
state.addOperands(other);
|
|
||||||
state.addAttribute(
|
state.addAttribute(
|
||||||
cacheAttrName(state.name),
|
cacheAttrName(state.name),
|
||||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||||
@@ -92,8 +87,28 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|||||||
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||||
state.addAttribute(isVolatileAttrName(state.name),
|
state.addAttribute(isVolatileAttrName(state.name),
|
||||||
builder.getBoolAttr(isVolatile));
|
builder.getBoolAttr(isVolatile));
|
||||||
state.addAttribute(isOtherUnspecifiedAttrName(state.name),
|
state.addTypes({resultType});
|
||||||
builder.getBoolAttr(false));
|
}
|
||||||
|
|
||||||
|
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||||
|
::mlir::Value ptr, ::mlir::Value mask,
|
||||||
|
::mlir::triton::CacheModifier cache,
|
||||||
|
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||||
|
TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
|
||||||
|
Type elementType =
|
||||||
|
ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
|
||||||
|
auto shape = ptrType.getShape();
|
||||||
|
Type resultType = RankedTensorType::get(shape, elementType);
|
||||||
|
state.addOperands(ptr);
|
||||||
|
state.addOperands(mask);
|
||||||
|
state.addAttribute(
|
||||||
|
cacheAttrName(state.name),
|
||||||
|
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||||
|
state.addAttribute(
|
||||||
|
evictAttrName(state.name),
|
||||||
|
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
|
||||||
|
state.addAttribute(isVolatileAttrName(state.name),
|
||||||
|
builder.getBoolAttr(isVolatile));
|
||||||
state.addTypes({resultType});
|
state.addTypes({resultType});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -40,8 +40,8 @@ def CombineGEPPattern : Pat<
|
|||||||
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
||||||
// => load(ptrs, broadcast(cond), other)
|
// => load(ptrs, broadcast(cond), other)
|
||||||
def CombineSelectMaskedLoadPattern : Pat<
|
def CombineSelectMaskedLoadPattern : Pat<
|
||||||
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile, $isOtherUnspecified), $falseValue),
|
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
|
||||||
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile, $isOtherUnspecified)>;
|
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
|
||||||
|
|
||||||
// broadcast(cst) => cst
|
// broadcast(cst) => cst
|
||||||
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||||
|
@@ -240,7 +240,7 @@ void LoopPipeliner::emitPrologue() {
|
|||||||
newOp = builder.create<triton::gpu::CopyAsyncOp>(
|
newOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||||
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
|
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
|
||||||
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
|
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
|
||||||
loadOp.isVolatile(), loadOp.isOtherUnspecified());
|
loadOp.isVolatile());
|
||||||
} else
|
} else
|
||||||
llvm_unreachable("This should be LoadOp");
|
llvm_unreachable("This should be LoadOp");
|
||||||
} else
|
} else
|
||||||
@@ -404,7 +404,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||||
loadOp.evict(), loadOp.isVolatile(), loadOp.isOtherUnspecified());
|
loadOp.evict(), loadOp.isVolatile());
|
||||||
} else
|
} else
|
||||||
nextOp = builder.clone(*op, nextMapping);
|
nextOp = builder.clone(*op, nextMapping);
|
||||||
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
|
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
|
||||||
|
@@ -1321,16 +1321,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
if (other.has_value()) {
|
if (other.has_value()) {
|
||||||
return self.create<mlir::triton::LoadOp>(
|
return self.create<mlir::triton::LoadOp>(
|
||||||
loc, resultType, ptrs, mask, other.value(), cacheModifier,
|
loc, resultType, ptrs, mask, other.value(), cacheModifier,
|
||||||
evictionPolicy, isVolatile, false);
|
evictionPolicy, isVolatile);
|
||||||
} else {
|
} else {
|
||||||
mlir::Value dummy_other = self.create<mlir::arith::ConstantOp>(
|
|
||||||
loc, resultType,
|
|
||||||
mlir::DenseElementsAttr::get(resultType,
|
|
||||||
self.getZeroAttr(elementType)));
|
|
||||||
return self.create<mlir::triton::LoadOp>(
|
return self.create<mlir::triton::LoadOp>(
|
||||||
loc, mlir::RankedTensorType::get(shape, elementType), ptrs,
|
loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
|
||||||
mask, dummy_other, cacheModifier, evictionPolicy, isVolatile,
|
|
||||||
true);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.def("create_masked_store",
|
.def("create_masked_store",
|
||||||
|
@@ -46,7 +46,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
|||||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||||
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x128xf32>
|
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||||
tt.store %19, %20, %cst, : tensor<128x128xf32>
|
tt.store %19, %20, %cst, : tensor<128x128xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
@@ -21,10 +21,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
|
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
// CHECK: offset = 0, size = 8192
|
// CHECK: offset = 0, size = 8192
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
// CHECK-NEXT: offset = 8192, size = 8192
|
// CHECK-NEXT: offset = 8192, size = 8192
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
|
|
||||||
@@ -49,17 +49,17 @@ func @reusable(%A : !tt.ptr<f16>) {
|
|||||||
|
|
||||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||||
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
|
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
|
||||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
// CHECK: offset = 0, size = 8192
|
// CHECK: offset = 0, size = 8192
|
||||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL>
|
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||||
// CHECK-NEXT: offset = 8192, size = 8192
|
// CHECK-NEXT: offset = 8192, size = 8192
|
||||||
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||||
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
// CHECK-NEXT: offset = 16384, size = 8192
|
// CHECK-NEXT: offset = 16384, size = 8192
|
||||||
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL>
|
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||||
// CHECK-NEXT: offset = 0, size = 8192
|
// CHECK-NEXT: offset = 0, size = 8192
|
||||||
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||||
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
@@ -210,5 +210,5 @@ func @multi_blocks_noreuse(%i1 : i1) {
|
|||||||
// CHECK-NEXT: offset = 1024, size = 1024
|
// CHECK-NEXT: offset = 1024, size = 1024
|
||||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
|
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
|
||||||
return
|
return
|
||||||
// CHECK-NEXT: size = 3072
|
// CHECK-NEXT: size = 3072
|
||||||
}
|
}
|
||||||
|
@@ -21,7 +21,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-SAME: ld.global.v4.b32
|
// CHECK-SAME: ld.global.v4.b32
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: ld.global.v4.b32
|
// CHECK-SAME: ld.global.v4.b32
|
||||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-SAME: ld.global.v2.b32
|
// CHECK-SAME: ld.global.v2.b32
|
||||||
// CHECK: llvm.inline_asm
|
// CHECK: llvm.inline_asm
|
||||||
// CHECK-SAME: ld.global.v2.b32
|
// CHECK-SAME: ld.global.v2.b32
|
||||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf16, #blocked0>
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -64,7 +64,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
// CHECK-LABEL: masked_load_const_other
|
// CHECK-LABEL: masked_load_const_other
|
||||||
func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
|
||||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
|
||||||
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32, #blocked0>
|
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,4 +189,4 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
|
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -49,12 +49,12 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
|
|||||||
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||||
|
|
||||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
|
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
|
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
%0 = select %cond, %x, %false_val : tensor<8xf32>
|
%0 = select %cond, %x, %false_val : tensor<8xf32>
|
||||||
|
|
||||||
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
|
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
|
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
%1 = select %cond, %y, %false_val : tensor<8xf32>
|
%1 = select %cond, %y, %false_val : tensor<8xf32>
|
||||||
|
|
||||||
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
|
||||||
|
@@ -24,10 +24,10 @@ module {
|
|||||||
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
|
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
|
||||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||||
%18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
|
%18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
|
||||||
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
|
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
|
||||||
%cst_1 = arith.constant 0.000000e+00 : f32
|
%cst_1 = arith.constant 0.000000e+00 : f32
|
||||||
%20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
|
%20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
|
||||||
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
|
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
|
||||||
%22 = arith.addf %19, %21 : tensor<256xf32>
|
%22 = arith.addf %19, %21 : tensor<256xf32>
|
||||||
%23 = arith.addf %arg7, %22 : tensor<256xf32>
|
%23 = arith.addf %arg7, %22 : tensor<256xf32>
|
||||||
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||||
|
@@ -40,7 +40,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|||||||
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||||
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||||
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||||
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked1>
|
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
|
||||||
tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
|
tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
|
||||||
return
|
return
|
||||||
}
|
}
|
@@ -33,9 +33,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
|
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
|
|
||||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
@@ -75,9 +75,9 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
|
|||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
|
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
|
|
||||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
@@ -106,7 +106,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
|
|||||||
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||||
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
|
||||||
|
|
||||||
%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL>
|
%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||||
|
|
||||||
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
|
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
|
||||||
@@ -116,7 +116,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
|
|||||||
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
|
||||||
|
|
||||||
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||||
|
Reference in New Issue
Block a user