[TritonIR] Make mask operand optional (#74)
@@ -66,30 +66,32 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
 //
 def TT_LoadOp : TT_Op<"load",
     [SameOperandsAndResultShape,
+     SameVariadicOperandSize,
      MemoryEffects<[MemRead]>,
      TypesMatchWith<"infer ptr type from result type",
                     "result", "ptr",
                     "getPointerTypeFromTensor($_self)">,
-     TypesMatchWith<"infer mask type from result type",
-                    "result", "mask",
-                    "getI1SameShape($_self)">,
+     TypesMatchWith<"infer mask type from result type or none",
+                    "result", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
      TypesMatchWith<"infer other type from result type or none",
-                    "result", "other",
-                    "$_self", "($_op.getOperands().size() == 2) || std::equal_to<>()">]> {
+                    "result", "other", "$_self",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "load";

-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);

   let results = (outs TT_Type:$result);

   let builders = [
-    // for args with default values
     OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
                    "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
     OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];

   let assemblyFormat = "operands attr-dict `:` type($result)";
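
With $mask and $other now declared Optional, the generated builders accept one, two, or three SSA operands. A minimal caller-side sketch, assuming the usual Triton dialect headers and an OpBuilder with a valid insertion point; the helper names are illustrative and not part of this commit:

// Sketch only: assumes these header paths and that `b` has an insertion point.
#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Unmasked load: only the pointer tensor is required; the builder infers the
// result type from `ptr` (see the LoadOp::build overloads later in this diff).
Value makeUnmaskedLoad(OpBuilder &b, Location loc, Value ptr,
                       triton::CacheModifier cache,
                       triton::EvictionPolicy evict, bool isVolatile) {
  return b.create<triton::LoadOp>(loc, ptr, cache, evict, isVolatile);
}

// Masked load with an explicit `other` value for the masked-off elements.
Value makeMaskedLoad(OpBuilder &b, Location loc, Value ptr, Value mask,
                     Value other, triton::CacheModifier cache,
                     triton::EvictionPolicy evict, bool isVolatile) {
  return b.create<triton::LoadOp>(loc, ptr, mask, other, cache, evict,
                                  isVolatile);
}
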
@@ -102,17 +104,17 @@ def TT_StoreOp : TT_Op<"store",
                     "value", "ptr",
                     "getPointerTypeFromTensor($_self)">,
      TypesMatchWith<"infer mask type from value type",
-                    "value", "mask",
-                    "getI1SameShape($_self)">]> {
+                    "value", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";

-  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
+  let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, Optional<I1Tensor>:$mask);

   let builders = [
     OpBuilder<(ins "Value":$ptr, "Value":$value)>,
   ];

-  let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
+  let assemblyFormat = "operands attr-dict `:` type($value)";
 }

 def TT_GEPOp : TT_Op<"getelementptr",
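
tt.store gets the same treatment: the mask becomes Optional<I1Tensor> and the hand-written assembly format is replaced by the generic operands form, which is why the trailing comma disappears from the textual IR in the test updates at the end of the diff. A caller-side sketch under the same assumptions as above; the helper names are illustrative:

#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Unmasked store: the (ptr, value) builder forwards a null mask, so no
// all-true mask constant has to be materialized anymore.
void makeUnmaskedStore(OpBuilder &b, Location loc, Value ptr, Value value) {
  b.create<triton::StoreOp>(loc, ptr, value);
}

// Masked store: pass the i1 tensor explicitly (default generated builder).
void makeMaskedStore(OpBuilder &b, Location loc, Value ptr, Value value,
                     Value mask) {
  b.create<triton::StoreOp>(loc, ptr, value, mask);
}
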
@@ -40,26 +40,27 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {

 def TTG_CopyAsyncOp : TTG_Op<"copy_async",
     [MemoryEffects<[MemRead, MemWrite]>,
+     SameVariadicOperandSize,
      TypesMatchWith<"infer mask type from ptr type",
-                    "ptr", "mask",
-                    "getI1SameShape($_self)">,
+                    "ptr", "mask", "getI1SameShape($_self)",
+                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
      TypesMatchWith<"infer other type from ptr type",
-                    "ptr", "other",
-                    "getPointeeType($_self)">]> {
+                    "ptr", "other", "getPointeeType($_self)",
+                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "copy async";

-  let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, Optional<TT_Type>:$other,
+  let arguments = (ins TT_PtrTensor:$ptr, Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
                        TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                        BoolAttr:$isVolatile);

   let builders = [
-    OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
-                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>
+    OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];

   let results = (outs TT_Type:$result);

-  let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
+  let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";

   // result needs to be of shared layout
   let verifier = [{ return ::verify(*this); }];
@@ -226,10 +226,9 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type retType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
-        adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
+        op, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(),
+        adaptor.evict(), adaptor.isVolatile());
     return success();
   }
 };
@@ -49,58 +49,43 @@ namespace mlir {
 namespace triton {

 //-- StoreOp --
-// Default mask
 void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                     ::mlir::Value ptr, ::mlir::Value value) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  auto shape = ptrType.getShape();
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  state.addOperands(ptr);
-  state.addOperands(value);
-  state.addOperands(mask);
+  StoreOp::build(builder, state, ptr, value, mlir::Value());
 }

 //-- LoadOp --
 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  Type elementType =
-      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
-  auto shape = ptrType.getShape();
-  // mask
-  ::mlir::Value mask = builder.create<arith::ConstantOp>(
-      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
-      DenseIntElementsAttr::get(
-          RankedTensorType::get(shape, builder.getI1Type()), true));
-  Type resultType = RankedTensorType::get(shape, elementType);
-  state.addOperands(ptr);
-  state.addOperands(mask);
-  state.addAttribute(
-      cacheAttrName(state.name),
-      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
-  state.addAttribute(
-      evictAttrName(state.name),
-      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
-  state.addAttribute(isVolatileAttrName(state.name),
-                     builder.getBoolAttr(isVolatile));
-  state.addTypes({resultType});
+  LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
+                isVolatile);
 }

 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value mask,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
+  LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
+                isVolatile);
+}
+
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
+                   ::mlir::triton::CacheModifier cache,
+                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   Type elementType =
       ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
   Type resultType = RankedTensorType::get(shape, elementType);
   state.addOperands(ptr);
+  if (mask) {
     state.addOperands(mask);
+    if (other) {
+      state.addOperands(other);
+    }
+  }
   state.addAttribute(
       cacheAttrName(state.name),
       ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
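
Because the delegating builders forward mlir::Value() for an absent operand, code that reads loadOp.mask() or loadOp.other() must treat a null result as "operand not present", exactly as the combiner and pipeliner changes below do. A minimal sketch of that consumer-side check, with a hypothetical helper name:

#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Returns true if the load is guarded by a mask. For an Optional operand the
// ODS-generated accessor returns a null Value when the operand was omitted.
bool isMaskedLoad(triton::LoadOp loadOp) {
  Value mask = loadOp.mask();
  return static_cast<bool>(mask);
}
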
@@ -52,6 +52,46 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,

 } // anonymous namespace

+// select(cond, load(ptrs, broadcast(cond), ???), other)
+//   => load(ptrs, broadcast(cond), other)
+class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
+public:
+  CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
+      : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
+                             {triton::LoadOp::getOperationName()}) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
+    if (!selectOp)
+      return mlir::failure();
+
+    mlir::Value trueValue = selectOp.getTrueValue();
+    mlir::Value falseValue = selectOp.getFalseValue();
+
+    auto *loadOpCandidate = trueValue.getDefiningOp();
+    auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
+    if (!loadOp)
+      return mlir::failure();
+
+    mlir::Value mask = loadOp.mask();
+    if (!mask)
+      return mlir::failure();
+
+    auto *broadcastOpCandidate = mask.getDefiningOp();
+    auto broadcastOp =
+        llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
+    if (!broadcastOp)
+      return mlir::failure();
+
+    rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
+        loadOp.evict(), loadOp.isVolatile());
+    return mlir::success();
+  }
+};
+
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
|||||||
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
||||||
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||||
|
|
||||||
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
|
||||||
// => load(ptrs, broadcast(cond), other)
|
|
||||||
def CombineSelectMaskedLoadPattern : Pat<
|
|
||||||
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
|
|
||||||
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
|
|
||||||
|
|
||||||
// broadcast(cst) => cst
|
// broadcast(cst) => cst
|
||||||
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||||
def CombineBroadcastConstantPattern : Pat<
|
def CombineBroadcastConstantPattern : Pat<
|
||||||
|
@@ -391,6 +391,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       if (loads.contains(op->getResult(0))) {
         auto loadOp = llvm::cast<triton::LoadOp>(op);
         Value mask = loadOp.mask();
+        if (mask) {
           Value splatCond = builder.create<triton::SplatOp>(
               mask.getLoc(), mask.getType(), nextLoopCond);
           Value newMask = builder.create<arith::AndIOp>(
@@ -399,6 +400,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
           // once
           if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
             nextMapping.map(mask, newMask);
+        }
         // TODO: more elegant way to do this?
         nextOp = builder.create<triton::gpu::CopyAsyncOp>(
             op->getLoc(), loadsMapping[op->getResult(0)].getType(),
@@ -1442,21 +1442,9 @@ void init_triton_ir(py::module &&m) {
              mlir::triton::EvictionPolicy evictionPolicy,
              bool isVolatile) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
-            std::vector<int64_t> shape = ptrType.getShape();
-            mlir::Type elementType = ptrType.getElementType()
-                                         .dyn_cast<mlir::triton::PointerType>()
-                                         .getPointeeType();
-            mlir::Type resultType =
-                mlir::RankedTensorType::get(shape, elementType);
-            if (other.has_value()) {
             return self.create<mlir::triton::LoadOp>(
-                loc, resultType, ptrs, mask, other.value(), cacheModifier,
+                loc, ptrs, mask, other.value_or(mlir::Value()), cacheModifier,
                 evictionPolicy, isVolatile);
-            } else {
-              return self.create<mlir::triton::LoadOp>(
-                  loc, ptrs, mask, cacheModifier, evictionPolicy, isVolatile);
-            }
           })
      .def("create_masked_store",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
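
On the Python-binding side, value_or(mlir::Value()) collapses the old two-branch call into one: an empty std::optional becomes a null Value, which the new builders treat as an omitted operand, and the result type is now inferred from the pointer tensor rather than computed in the binding. A tiny sketch of that unwrapping step, with a hypothetical helper name:

#include <optional>

#include "mlir/IR/Value.h"

// An empty std::optional becomes a default-constructed (null) mlir::Value,
// which the optional-operand builders interpret as "operand omitted".
mlir::Value unwrapOther(const std::optional<mlir::Value> &other) {
  return other.value_or(mlir::Value());
}
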
@@ -47,6 +47,6 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
   %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
-  tt.store %19, %20, %cst, : tensor<128x128xf32>
+  tt.store %19, %20, %cst : tensor<128x128xf32>
   return
 }

@@ -31,7 +31,7 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {

   // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.b32 [ $1 + 0 ], { $2 };",
   // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
+  tt.store %ptrs, %vs, %mask : tensor<128xf32>

   return
 }

@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
-    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
+    tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
     return
   }
 }

@@ -38,7 +38,7 @@ module {
     }
     %16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
     %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
-    tt.store %17, %15#0, %6, : tensor<256xf32>
+    tt.store %17, %15#0, %6 : tensor<256xf32>
     return
   }
 }

@@ -124,7 +124,7 @@ module {
 // }
 // %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
 // %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
-// tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
+// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
 // return
 // }
 // }

@@ -41,6 +41,6 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
   %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
   %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
   %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
-  tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
+  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
   return
 }

@@ -86,7 +86,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
   // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>

@@ -117,7 +117,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
   %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
   %26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
-  tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
+  tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
   return
 }

@@ -170,6 +170,6 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
   %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
-  tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
+  tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
   return
 }