[TritonIR] Make mask operand optional (#74)
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
//
def TT_LoadOp : TT_Op<"load",
    [SameOperandsAndResultShape,
     SameVariadicOperandSize,
     MemoryEffects<[MemRead]>,
     TypesMatchWith<"infer ptr type from result type",
                    "result", "ptr",
                    "getPointerTypeFromTensor($_self)">,
     TypesMatchWith<"infer mask type from result type",
                    "result", "mask",
                    "getI1SameShape($_self)">,
     TypesMatchWith<"infer mask type from result type or none",
                    "result", "mask", "getI1SameShape($_self)",
                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
     TypesMatchWith<"infer other type from result type or none",
                    "result", "other",
                    "$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
                    "result", "other", "$_self",
                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
  let summary = "load";

  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                       TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                       BoolAttr:$isVolatile);

  let results = (outs TT_Type:$result);

  let builders = [
      // for args with default values
      OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
      OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
      OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
  ];

  let assemblyFormat = "operands attr-dict `:` type($result)";
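With the mask operand optional, callers no longer have to materialize an all-true mask just to express an unmasked load. A minimal usage sketch of the three builder overloads declared above; the helper name and the values `ptr`, `mask`, and `other` are illustrative, not part of the patch, and the usual Triton/MLIR headers are assumed:

// Hypothetical helper showing the three load forms enabled by the builders above.
static void buildLoadVariants(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::Value ptr, mlir::Value mask,
                              mlir::Value other) {
  using namespace mlir::triton;
  // Unmasked load: neither mask nor other is added to the operand list.
  mlir::Value plain = b.create<LoadOp>(
      loc, ptr, CacheModifier::NONE, EvictionPolicy::NORMAL,
      /*isVolatile=*/false);
  // Masked load without a fall-back value.
  mlir::Value masked = b.create<LoadOp>(
      loc, ptr, mask, CacheModifier::NONE, EvictionPolicy::NORMAL,
      /*isVolatile=*/false);
  // Masked load with an explicit `other` value for masked-off elements.
  mlir::Value maskedOther = b.create<LoadOp>(
      loc, ptr, mask, other, CacheModifier::NONE, EvictionPolicy::NORMAL,
      /*isVolatile=*/false);
  (void)plain; (void)masked; (void)maskedOther;
}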
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
                    "value", "ptr",
                    "getPointerTypeFromTensor($_self)">,
     TypesMatchWith<"infer mask type from value type",
                    "value", "mask",
                    "getI1SameShape($_self)">]> {
                    "value", "mask", "getI1SameShape($_self)",
                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
  let summary = "store";

  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);

  let builders = [
      OpBuilder<(ins "Value":$ptr, "Value":$value)>,
  ];

  let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
  let assemblyFormat = "operands attr-dict `:` type($value)";
}
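The store side mirrors this: the declared builder covers the unmasked case, and the builder generated from the full argument list covers the masked one. A short sketch under the same assumptions (illustrative names, not from the patch):

// Hypothetical sketch of the two store forms.
static void buildStoreVariants(mlir::OpBuilder &b, mlir::Location loc,
                               mlir::Value ptr, mlir::Value value,
                               mlir::Value mask) {
  // Unmasked store via the declared (ptr, value) builder; it now forwards a
  // null mask instead of creating a dense all-true tensor.
  b.create<mlir::triton::StoreOp>(loc, ptr, value);
  // Masked store via the ODS-generated builder over (ptr, value, mask).
  b.create<mlir::triton::StoreOp>(loc, ptr, value, mask);
}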

def TT_GEPOp : TT_Op<"getelementptr",
@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {

def TTG_CopyAsyncOp : TTG_Op<"copy_async",
    [MemoryEffects<[MemRead, MemWrite]>,
     SameVariadicOperandSize,
     TypesMatchWith<"infer mask type from ptr type",
                    "ptr", "mask",
                    "getI1SameShape($_self)">,
                    "ptr", "mask", "getI1SameShape($_self)",
                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
     TypesMatchWith<"infer other type from ptr type",
                    "ptr", "other",
                    "getPointeeType($_self)">]> {
                    "ptr", "other", "getPointeeType($_self)",
                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
  let summary = "copy async";

  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                       TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                       BoolAttr:$isVolatile);

  let builders = [
      OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
      OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
  ];

  let results = (outs TT_Type:$result);

  let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
  let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";

  // result needs to be of shared layout
  let verifier = [{ return ::verify(*this); }];
@@ -226,10 +226,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
        adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
        adaptor.evict(), adaptor.isVolatile());
    return success();
  }
};
@@ -49,58 +49,43 @@ namespace mlir {
namespace triton {

//-- StoreOp --
// Default mask
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value) {
  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
  auto shape = ptrType.getShape();
  ::mlir::Value mask = builder.create<arith::ConstantOp>(
      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
      DenseIntElementsAttr::get(
          RankedTensorType::get(shape, builder.getI1Type()), true));
  state.addOperands(ptr);
  state.addOperands(value);
  state.addOperands(mask);
  StoreOp::build(builder, state, ptr, value, mlir::Value());
}

//-- LoadOp --
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
  Type elementType =
      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
  auto shape = ptrType.getShape();
  // mask
  ::mlir::Value mask = builder.create<arith::ConstantOp>(
      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
      DenseIntElementsAttr::get(
          RankedTensorType::get(shape, builder.getI1Type()), true));
  Type resultType = RankedTensorType::get(shape, elementType);
  state.addOperands(ptr);
  state.addOperands(mask);
  state.addAttribute(
      cacheAttrName(state.name),
      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
  state.addAttribute(
      evictAttrName(state.name),
      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
  state.addAttribute(isVolatileAttrName(state.name),
                     builder.getBoolAttr(isVolatile));
  state.addTypes({resultType});
  LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
                isVolatile);
}

void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
                isVolatile);
}

void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
  Type elementType =
      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
  auto shape = ptrType.getShape();
  Type resultType = RankedTensorType::get(shape, elementType);
  state.addOperands(ptr);
  state.addOperands(mask);
  if (mask) {
    state.addOperands(mask);
    if (other) {
      state.addOperands(other);
    }
  }
  state.addAttribute(
      cacheAttrName(state.name),
      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
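Since mask and other are now appended only when non-null, any pass that reads them back has to tolerate a null Value, which is exactly what the combiner and pipeliner changes below do. A minimal illustration of that guard (hypothetical `loadOp` variable):

// Sketch: consuming the optional mask after this change. mask() returns a
// null Value when the operand was omitted.
if (mlir::Value mask = loadOp.mask()) {
  // masked path: mask operand is present
} else {
  // unmasked path
}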
@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,

} // anonymous namespace

// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
public:
  CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
      : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
                             {triton::LoadOp::getOperationName()}) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
    if (!selectOp)
      return mlir::failure();

    mlir::Value trueValue = selectOp.getTrueValue();
    mlir::Value falseValue = selectOp.getFalseValue();

    auto *loadOpCandidate = trueValue.getDefiningOp();
    auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
    if (!loadOp)
      return mlir::failure();

    mlir::Value mask = loadOp.mask();
    if (!mask)
      return mlir::failure();

    auto *broadcastOpCandidate = mask.getDefiningOp();
    auto broadcastOp =
        llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
    if (!broadcastOp)
      return mlir::failure();

    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
        loadOp.evict(), loadOp.isVolatile());
    return mlir::success();
  }
};

#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
@@ -37,12 +37,6 @@ def CombineGEPPattern : Pat<
  (TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
  (TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;

// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
def CombineSelectMaskedLoadPattern : Pat<
  (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
  (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;

// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
def CombineBroadcastConstantPattern : Pat<
@@ -391,14 +391,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
      if (loads.contains(op->getResult(0))) {
        auto loadOp = llvm::cast<triton::LoadOp>(op);
        Value mask = loadOp.mask();
        Value splatCond = builder.create<triton::SplatOp>(
            mask.getLoc(), mask.getType(), nextLoopCond);
        Value newMask = builder.create<arith::AndIOp>(
            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
        // if mask is defined outside the loop, don't update the map more than
        // once
        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
          nextMapping.map(mask, newMask);
        if (mask) {
          Value splatCond = builder.create<triton::SplatOp>(
              mask.getLoc(), mask.getType(), nextLoopCond);
          Value newMask = builder.create<arith::AndIOp>(
              mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
          // if mask is defined outside the loop, don't update the map more than
          // once
          if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
            nextMapping.map(mask, newMask);
        }
        // TODO: more elegant way to do this?
        nextOp = builder.create<triton::gpu::CopyAsyncOp>(
            op->getLoc(), loadsMapping[op->getResult(0)].getType(),
@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
              mlir::triton::EvictionPolicy evictionPolicy,
              bool isVolatile) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = ptrType.getShape();
             mlir::Type elementType = ptrType.getElementType()
                                          .dyn_cast<mlir::triton::PointerType>()
                                          .getPointeeType();
             mlir::Type resultType =
                 mlir::RankedTensorType::get(shape, elementType);
             if (other.has_value()) {
               return self.create<mlir::triton::LoadOp>(
                   loc, resultType, ptrs, mask, other.value(), cacheModifier,
                   evictionPolicy, isVolatile);
             } else {
               return self.create<mlir::triton::LoadOp>(
                   loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
             }
             return self.create<mlir::triton::LoadOp>(
                 loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
                 evictionPolicy, isVolatile);
           })
      .def("create_masked_store",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
@@ -47,6 +47,6 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
  %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
  %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
  tt.store %19, %20, %cst, : tensor<128x128xf32>
  tt.store %19, %20, %cst : tensor<128x128xf32>
  return
}
@@ -31,7 +31,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {

  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
  // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
  tt.store %ptrs, %vs, %mask : tensor<128xf32>

  return
}
@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
    // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
    // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
    tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
    return
  }
}
@@ -38,7 +38,7 @@ module {
    }
    %16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
    %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
    tt.store %17, %15#0, %6, : tensor<256xf32>
    tt.store %17, %15#0, %6 : tensor<256xf32>
    return
  }
}
@@ -124,7 +124,7 @@ module {
// }
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
// return
// }
// }
@@ -41,6 +41,6 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
  %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
  %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
  %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
  tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
  return
}
@@ -86,7 +86,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
  // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
  // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
  // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
  // CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
  // CHECK: return
  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
  %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -117,7 +117,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
  %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
  %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
  %26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
  tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
  tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
  return
}
@@ -170,6 +170,6 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
  %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
  %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
  %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
  tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
  tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
  return
}