[TritonIR] Make mask operand optional (#74)

Shintaro Iwasaki committed 2022-08-22 22:00:17 -07:00 (committed by GitHub)
parent de2dd04c8a
commit 0ebef11c77
14 changed files with 113 additions and 102 deletions
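In terms of the printed IR, the change means `tt.load` and `tt.store` no longer require an all-true mask operand. A minimal sketch, with illustrative operand names and shapes; the unmasked forms follow from the new `operands attr-dict `:` type(...)` assembly format introduced below:

  %a = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>                   // no mask, no other
  %b = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>    // masked, with fallback value
  tt.store %ptrs, %b : tensor<128xf32>         // unmasked store
  tt.store %ptrs, %b, %mask : tensor<128xf32>  // masked store (no trailing comma anymore)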

View File

@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
 //
 def TT_LoadOp : TT_Op<"load",
     [SameOperandsAndResultShape,
+     SameVariadicOperandSize,
      MemoryEffects<[MemRead]>,
      TypesMatchWith<"infer ptr type from result type",
                     "result", "ptr",
                     "getPointerTypeFromTensor($_self)">,
-     TypesMatchWith<"infer mask type from result type",
-                    "result", "mask",
-                    "getI1SameShape($_self)">,
+     TypesMatchWith<"infer mask type from result type or none",
+                    "result", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
      TypesMatchWith<"infer other type from result type or none",
-                    "result", "other",
-                    "$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
+                    "result", "other", "$_self",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "load";
-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);
   let results = (outs TT_Type:$result);
   let builders = [
+    // for args with default values
     OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                    "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
     OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];
   let assemblyFormat = "operands attr-dict `:` type($result)";
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
                     "value", "ptr",
                     "getPointerTypeFromTensor($_self)">,
      TypesMatchWith<"infer mask type from value type",
-                    "value", "mask",
-                    "getI1SameShape($_self)">]> {
+                    "value", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";
-  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
+  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);
   let builders = [
     OpBuilder<(ins "Value":$ptr, "Value":$value)>,
   ];
-  let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
+  let assemblyFormat = "operands attr-dict `:` type($value)";
 }

 def TT_GEPOp : TT_Op<"getelementptr",

View File

@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
 def TTG_CopyAsyncOp : TTG_Op<"copy_async",
     [MemoryEffects<[MemRead, MemWrite]>,
+     SameVariadicOperandSize,
      TypesMatchWith<"infer mask type from ptr type",
-                    "ptr", "mask",
-                    "getI1SameShape($_self)">,
+                    "ptr", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
      TypesMatchWith<"infer other type from ptr type",
-                    "ptr", "other",
-                    "getPointeeType($_self)">]> {
+                    "ptr", "other", "getPointeeType($_self)",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "copy async";
-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);
   let builders = [
-    OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+    OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];
   let results = (outs TT_Type:$result);
-  let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
+  let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
   // result needs to be of shared layout
   let verifier = [{ return ::verify(*this); }];

View File

@@ -226,10 +226,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type retType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
-        adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
+        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
+        adaptor.evict(), adaptor.isVolatile());
     return success();
   }
 };

View File

@@ -49,58 +49,43 @@ namespace mlir {
 namespace triton {

 //-- StoreOp --
-// Default mask
 void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                     ::mlir::Value ptr, ::mlir::Value value) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  auto shape = ptrType.getShape();
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  state.addOperands(ptr);
-  state.addOperands(value);
-  state.addOperands(mask);
+  StoreOp::build(builder, state, ptr, value, mlir::Value());
 }

 //-- LoadOp --
 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  Type elementType =
-      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
-  auto shape = ptrType.getShape();
-  // mask
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  Type resultType = RankedTensorType::get(shape, elementType);
-  state.addOperands(ptr);
-  state.addOperands(mask);
-  state.addAttribute(
-      cacheAttrName(state.name),
-      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
-  state.addAttribute(
-      evictAttrName(state.name),
-      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
-  state.addAttribute(isVolatileAttrName(state.name),
-                     builder.getBoolAttr(isVolatile));
-  state.addTypes({resultType});
+  LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
+                isVolatile);
 }

 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value mask,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
+  LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
+                isVolatile);
+}
+
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
+                   ::mlir::triton::CacheModifier cache,
+                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   Type elementType =
       ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
   Type resultType = RankedTensorType::get(shape, elementType);
   state.addOperands(ptr);
+  if (mask) {
     state.addOperands(mask);
+    if (other) {
+      state.addOperands(other);
+    }
+  }
   state.addAttribute(
       cacheAttrName(state.name),
       ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));

View File

@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
 } // anonymous namespace

+// select(cond, load(ptrs, broadcast(cond), ???), other)
+//   => load(ptrs, broadcast(cond), other)
+class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
+public:
+  CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
+      : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
+                             {triton::LoadOp::getOperationName()}) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
+    if (!selectOp)
+      return mlir::failure();
+
+    mlir::Value trueValue = selectOp.getTrueValue();
+    mlir::Value falseValue = selectOp.getFalseValue();
+
+    auto *loadOpCandidate = trueValue.getDefiningOp();
+    auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
+    if (!loadOp)
+      return mlir::failure();
+
+    mlir::Value mask = loadOp.mask();
+    if (!mask)
+      return mlir::failure();
+
+    auto *broadcastOpCandidate = mask.getDefiningOp();
+    auto broadcastOp =
+        llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
+    if (!broadcastOp)
+      return mlir::failure();
+
+    rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
+        loadOp.evict(), loadOp.isVolatile());
+    return mlir::success();
+  }
+};
+
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"
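For reference, the rewrite that CombineSelectMaskedLoadPattern performs, sketched as IR with illustrative names (a scalar i1 condition is assumed here):

  // before: the load is masked by broadcast(%cond), and select picks %other where %cond is false
  %mask = tt.broadcast %cond : (i1) -> tensor<128xi1>
  %x = tt.load %ptrs, %mask, %d {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
  %y = select %cond, %x, %other : tensor<128xf32>
  // after: the select is folded into the load's `other` operand
  %y = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>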

View File

@@ -37,12 +37,6 @@ def CombineGEPPattern : Pat<
   (TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
   (TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;

-// select(cond, load(ptrs, broadcast(cond), ???), other)
-//   => load(ptrs, broadcast(cond), other)
-def CombineSelectMaskedLoadPattern : Pat<
-  (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
-  (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;

 // broadcast(cst) => cst
 def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
 def CombineBroadcastConstantPattern : Pat<

View File

@@ -391,6 +391,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       if (loads.contains(op->getResult(0))) {
         auto loadOp = llvm::cast<triton::LoadOp>(op);
         Value mask = loadOp.mask();
+        if (mask) {
           Value splatCond = builder.create<triton::SplatOp>(
               mask.getLoc(), mask.getType(), nextLoopCond);
           Value newMask = builder.create<arith::AndIOp>(
@@ -399,6 +400,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
           // once
           if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
             nextMapping.map(mask, newMask);
+        }
         // TODO: more elegant way to do this?
         nextOp = builder.create<triton::gpu::CopyAsyncOp>(
             op->getLoc(), loadsMapping[op->getResult(0)].getType(),

View File

@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
              mlir::triton::EvictionPolicy evictionPolicy,
              bool isVolatile) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
-            std::vector<int64_t> shape = ptrType.getShape();
-            mlir::Type elementType = ptrType.getElementType()
-                                         .dyn_cast<mlir::triton::PointerType>()
-                                         .getPointeeType();
-            mlir::Type resultType =
-                mlir::RankedTensorType::get(shape, elementType);
-            if (other.has_value()) {
-              return self.create<mlir::triton::LoadOp>(
-                  loc, resultType, ptrs, mask, other.value(), cacheModifier,
-                  evictionPolicy, isVolatile);
-            } else {
-              return self.create<mlir::triton::LoadOp>(
-                  loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
-            }
+            return self.create<mlir::triton::LoadOp>(
+                loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
+                evictionPolicy, isVolatile);
           })
       .def("create_masked_store",
            [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,

View File

@@ -47,6 +47,6 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
   %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
-  tt.store %19, %20, %cst, : tensor<128x128xf32>
+  tt.store %19, %20, %cst : tensor<128x128xf32>
   return
 }

View File

@@ -31,7 +31,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
   // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
   // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
+  tt.store %ptrs, %vs, %mask : tensor<128xf32>
   return
 }

View File

@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
+    tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
     return
   }
 }

View File

@@ -38,7 +38,7 @@ module {
     }
     %16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
     %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
-    tt.store %17, %15#0, %6, : tensor<256xf32>
+    tt.store %17, %15#0, %6 : tensor<256xf32>
     return
   }
 }
@@ -124,7 +124,7 @@ module {
 // }
 // %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
 // %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
-// tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
+// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
 // return
 // }
 // }

View File

@@ -41,6 +41,6 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
   %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
   %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
   %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
-  tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
+  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
   return
 }

View File

@@ -86,7 +86,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
   // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -117,7 +117,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
   %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
   %26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
-  tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
+  tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
   return
 }
@@ -170,6 +170,6 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
   %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
-  tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
+  tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
   return
 }