[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
This commit is contained in:
@@ -1,10 +1,19 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
|
||||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
|
||||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
|
||||||
|
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||||
|
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
|
||||||
|
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
|
||||||
|
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
|
||||||
|
|
||||||
add_public_tablegen_target(TritonTableGen)
|
add_public_tablegen_target(TritonTableGen)
|
||||||
|
@@ -17,4 +17,24 @@
|
|||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
class DialectInferLayoutInterface
|
||||||
|
: public DialectInterface::Base<DialectInferLayoutInterface> {
|
||||||
|
public:
|
||||||
|
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
|
||||||
|
|
||||||
|
virtual LogicalResult
|
||||||
|
inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||||
|
Attribute &resultEncoding) const = 0;
|
||||||
|
|
||||||
|
virtual LogicalResult
|
||||||
|
inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||||
|
Attribute &resultEncoding) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TRITON_IR_DIALECT_H_
|
#endif // TRITON_IR_DIALECT_H_
|
||||||
|
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#ifndef TRITON_IR_INTERFACES_H_
|
||||||
|
#define TRITON_IR_INTERFACES_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
#define GET_TYPEDEF_CLASSES
|
||||||
|
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||||
|
|
||||||
|
#endif // TRITON_IR_TYPES_H_
|
@@ -10,50 +10,47 @@
|
|||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace OpTrait {
|
namespace OpTrait {
|
||||||
// TODO: should have `namespace triton {}` here
|
|
||||||
|
// These functions are out-of-line implementations of the methods in the
|
||||||
|
// corresponding trait classes. This avoids them being template
|
||||||
|
// instantiated/duplicated.
|
||||||
|
namespace impl {
|
||||||
|
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
|
||||||
|
LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||||
|
// The rationale for this trait is to prevent users from creating programs
|
||||||
|
// that would have catastrophic register pressure and cause the compiler to
|
||||||
|
// hang.
|
||||||
|
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||||
|
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||||
|
// but we probably should limit number of elements (rather than bytes) to
|
||||||
|
// keep specs simple
|
||||||
|
int constexpr maxTensorNumElements = 1048576;
|
||||||
|
LogicalResult verifyTensorSize(Operation *op);
|
||||||
|
} // namespace impl
|
||||||
|
|
||||||
template <class ConcreteType>
|
template <class ConcreteType>
|
||||||
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||||
public:
|
public:
|
||||||
// TODO: move impl to .cc files
|
|
||||||
static LogicalResult verifyTrait(Operation *op) {
|
static LogicalResult verifyTrait(Operation *op) {
|
||||||
// The rationale for this number is to prevent users from creating programs
|
return impl::verifyTensorSize(op);
|
||||||
// that would have catastrophic register pressure and cause the compiler to
|
|
||||||
// hang.
|
|
||||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
|
||||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
|
||||||
// but we probably should limit number of elements (rather than bytes) to
|
|
||||||
// keep specs simple
|
|
||||||
int constexpr maxElement = 1048576;
|
|
||||||
for (auto opType : op->getOperandTypes()) {
|
|
||||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
|
||||||
int64_t numElements = 1;
|
|
||||||
for (int64_t s : tensorType.getShape())
|
|
||||||
numElements *= s;
|
|
||||||
if (numElements > maxElement)
|
|
||||||
return op->emitError("Maximum allowed number of elements is ")
|
|
||||||
<< maxElement << ", but " << *op << " has more than that";
|
|
||||||
if ((numElements & (numElements - 1)) != 0)
|
|
||||||
return op->emitError("Number of elements must be power-of-two, but ")
|
|
||||||
<< *op << " doesn't follow the rule";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
for (auto opType : op->getResultTypes()) {
|
template <typename ConcreteType>
|
||||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
class SameOperandsAndResultEncoding
|
||||||
int64_t numElements = 1;
|
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
|
||||||
for (int64_t s : tensorType.getShape())
|
public:
|
||||||
numElements *= s;
|
static LogicalResult verifyTrait(Operation *op) {
|
||||||
if (numElements > maxElement)
|
return impl::verifySameOperandsAndResultEncoding(op);
|
||||||
return op->emitError("Maximum allowed number of elements is ")
|
|
||||||
<< maxElement << ", but " << *op << " has more than that";
|
|
||||||
if ((numElements & (numElements - 1)) != 0)
|
|
||||||
return op->emitError("Number of elements must be power-of-two, but ")
|
|
||||||
<< *op << " doesn't follow the rule";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return success();
|
template <typename ConcreteType>
|
||||||
|
class SameOperandsEncoding
|
||||||
|
: public TraitBase<ConcreteType, SameOperandsEncoding> {
|
||||||
|
public:
|
||||||
|
static LogicalResult verifyTrait(Operation *op) {
|
||||||
|
return impl::verifySameOperandsEncoding(op);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -41,4 +41,7 @@ def Triton_Dialect : Dialect {
|
|||||||
let hasConstantMaterializer = 1;
|
let hasConstantMaterializer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||||
|
|
||||||
|
|
||||||
#endif // TRITON_DIALECT
|
#endif // TRITON_DIALECT
|
||||||
|
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#ifndef TRITON_INTERFACES
|
||||||
|
#define TRITON_INTERFACES
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
#endif // TRITON_INTERFACES
|
@@ -4,18 +4,23 @@
|
|||||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||||
|
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||||
|
|
||||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||||
|
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||||
|
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Op Base
|
// Op Base
|
||||||
//
|
//
|
||||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
|
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CastOps
|
// CastOps
|
||||||
@@ -25,7 +30,9 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
|
|||||||
// fptoui, fptosi, uitofp, sitofp,
|
// fptoui, fptosi, uitofp, sitofp,
|
||||||
// extf, tructf,
|
// extf, tructf,
|
||||||
// extui, extsi, tructi
|
// extui, extsi, tructi
|
||||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
|
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
|
NoSideEffect,
|
||||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||||
let summary = "Cast int64 to pointer";
|
let summary = "Cast int64 to pointer";
|
||||||
|
|
||||||
@@ -36,7 +43,9 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
|
|||||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
|
NoSideEffect,
|
||||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||||
let summary = "Cast pointer to int64";
|
let summary = "Cast pointer to int64";
|
||||||
|
|
||||||
@@ -47,7 +56,9 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffec
|
|||||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
|
NoSideEffect,
|
||||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||||
let summary = "Floating point casting for custom types";
|
let summary = "Floating point casting for custom types";
|
||||||
|
|
||||||
@@ -67,11 +78,33 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
|||||||
// TODO: We need a verifier here.
|
// TODO: We need a verifier here.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Pointer Arith Ops
|
||||||
|
//
|
||||||
|
|
||||||
|
def TT_AddPtrOp : TT_Op<"addptr",
|
||||||
|
[NoSideEffect,
|
||||||
|
SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
|
TypesMatchWith<"result type matches ptr type",
|
||||||
|
"result", "ptr", "$_self">,
|
||||||
|
TypesMatchWith<"result shape matches offset shape",
|
||||||
|
"result", "offset",
|
||||||
|
"getI32SameShape($_self)">]> {
|
||||||
|
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||||
|
|
||||||
|
let results = (outs TT_PtrLike:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Load/Store Ops
|
// Load/Store Ops
|
||||||
//
|
//
|
||||||
def TT_LoadOp : TT_Op<"load",
|
def TT_LoadOp : TT_Op<"load",
|
||||||
[SameOperandsAndResultShape,
|
[SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
SameVariadicOperandSize,
|
SameVariadicOperandSize,
|
||||||
MemoryEffects<[MemRead]>,
|
MemoryEffects<[MemRead]>,
|
||||||
TypesMatchWith<"infer ptr type from result type",
|
TypesMatchWith<"infer ptr type from result type",
|
||||||
@@ -110,6 +143,7 @@ def TT_LoadOp : TT_Op<"load",
|
|||||||
|
|
||||||
def TT_StoreOp : TT_Op<"store",
|
def TT_StoreOp : TT_Op<"store",
|
||||||
[SameOperandsShape,
|
[SameOperandsShape,
|
||||||
|
SameOperandsEncoding,
|
||||||
MemoryEffects<[MemWrite]>,
|
MemoryEffects<[MemWrite]>,
|
||||||
TypesMatchWith<"infer ptr type from value type",
|
TypesMatchWith<"infer ptr type from value type",
|
||||||
"value", "ptr",
|
"value", "ptr",
|
||||||
@@ -133,130 +167,11 @@ def TT_StoreOp : TT_Op<"store",
|
|||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_AddPtrOp : TT_Op<"addptr",
|
|
||||||
[NoSideEffect, SameOperandsAndResultShape,
|
|
||||||
TypesMatchWith<"result type matches ptr type",
|
|
||||||
"result", "ptr", "$_self">,
|
|
||||||
TypesMatchWith<"result shape matches offset shape",
|
|
||||||
"result", "offset",
|
|
||||||
"getI32SameShape($_self)">]> {
|
|
||||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
|
||||||
|
|
||||||
let results = (outs TT_PtrLike:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Shape Manipulation Ops
|
// Atomic Op
|
||||||
//
|
//
|
||||||
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
|
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||||
let summary = "expand_dims";
|
SameOperandsAndResultEncoding]> {
|
||||||
|
|
||||||
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
|
|
||||||
let summary = "view";
|
|
||||||
|
|
||||||
let arguments = (ins TT_Tensor:$src);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
|
||||||
let summary = "splat";
|
|
||||||
|
|
||||||
let arguments = (ins TT_Type:$src);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
|
||||||
|
|
||||||
let hasFolder = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
|
||||||
let summary = "broadcast. No left-padding as of now.";
|
|
||||||
|
|
||||||
let arguments = (ins TT_Type:$src);
|
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
|
||||||
|
|
||||||
let hasFolder = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
|
||||||
let summary = "concatenate 2 tensors";
|
|
||||||
|
|
||||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// builtin Ops
|
|
||||||
//
|
|
||||||
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
|
||||||
let arguments = (ins I32Attr:$axis);
|
|
||||||
|
|
||||||
let results = (outs I32:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "attr-dict `:` type($result)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
|
||||||
let arguments = (ins I32Attr:$axis);
|
|
||||||
|
|
||||||
let results = (outs I32:$result);
|
|
||||||
|
|
||||||
let assemblyFormat = "attr-dict `:` type($result)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
|
||||||
TypesMatchWith<"result's type matches accumulator's type",
|
|
||||||
"d", "c", "$_self">]> {
|
|
||||||
let summary = "dot";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
$d = matrix_multiply($a, $b) + $c
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
|
||||||
|
|
||||||
let results = (outs TT_FpIntTensor:$d);
|
|
||||||
|
|
||||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
|
||||||
|
|
||||||
// let hasCanonicalizer = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> {
|
|
||||||
let summary = "reduce";
|
|
||||||
|
|
||||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
|
||||||
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
|
||||||
];
|
|
||||||
|
|
||||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
|
||||||
let summary = "atomic rmw";
|
let summary = "atomic rmw";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -271,7 +186,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
|||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding]> {
|
||||||
let summary = "atomic cas";
|
let summary = "atomic cas";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -289,10 +205,133 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
|||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// External Function Ops
|
// Shape Manipulation Ops
|
||||||
|
//
|
||||||
|
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
|
||||||
|
SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "splat";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Type:$src);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||||
|
SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "expand_dims";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
}
|
||||||
|
|
||||||
|
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
|
||||||
|
SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "view";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Tensor:$src);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
|
||||||
|
SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "broadcast. No left-padding as of now.";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Type:$src);
|
||||||
|
|
||||||
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
|
||||||
|
SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "concatenate 2 tensors";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// SPMD Ops
|
||||||
|
//
|
||||||
|
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||||
|
let arguments = (ins I32Attr:$axis);
|
||||||
|
|
||||||
|
let results = (outs I32:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||||
|
let arguments = (ins I32Attr:$axis);
|
||||||
|
|
||||||
|
let results = (outs I32:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Dot Op
|
||||||
|
//
|
||||||
|
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||||
|
TypesMatchWith<"result's type matches accumulator's type",
|
||||||
|
"d", "c", "$_self">]> {
|
||||||
|
let summary = "dot";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
$d = matrix_multiply($a, $b) + $c
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||||
|
|
||||||
|
let results = (outs TT_FpIntTensor:$d);
|
||||||
|
|
||||||
|
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Reduce Op
|
||||||
|
//
|
||||||
|
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
|
let summary = "reduce";
|
||||||
|
|
||||||
|
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||||
|
];
|
||||||
|
|
||||||
|
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// External elementwise op
|
||||||
//
|
//
|
||||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
|
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
|
||||||
|
SameOperandsAndResultEncoding,
|
||||||
SameVariadicOperandSize]> {
|
SameVariadicOperandSize]> {
|
||||||
let summary = "ext_elemwise";
|
let summary = "ext_elemwise";
|
||||||
|
|
||||||
@@ -307,10 +346,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
|
|||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Intrinsics
|
// Make Range Op
|
||||||
//
|
//
|
||||||
// TODO: should have ConstantLike as Trait
|
// TODO: should have ConstantLike as Trait
|
||||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||||
|
@@ -11,6 +11,7 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||||
|
|
||||||
#define GET_ATTRDEF_CLASSES
|
#define GET_ATTRDEF_CLASSES
|
||||||
|
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TRITONGPU_ATTRDEFS
|
#define TRITONGPU_ATTRDEFS
|
||||||
|
|
||||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||||
|
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TritonGPU Attribute Definitions
|
// TritonGPU Attribute Definitions
|
||||||
@@ -34,6 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
|||||||
|
|
||||||
code extraBaseClassDeclaration = [{
|
code extraBaseClassDeclaration = [{
|
||||||
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
|
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
|
||||||
|
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,15 +303,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
|||||||
TODO: improve docs
|
TODO: improve docs
|
||||||
|
|
||||||
A = [x x x x x x x x]
|
A = [x x x x x x x x]
|
||||||
[x x x x x x x x]
|
|
||||||
L_parent = [0 1 2 3 ]
|
parent = [0 1 2 3 ]
|
||||||
[4 5 6 7 ]
|
[4 5 6 7 ]
|
||||||
[8 9 10 11]
|
[8 9 10 11]
|
||||||
[12 13 14 15]
|
[12 13 14 15]
|
||||||
dim = 0
|
dim = 0
|
||||||
|
|
||||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ]
|
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
|
||||||
|
|
||||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||||
|
|
||||||
|
@@ -76,15 +76,4 @@ def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::Modu
|
|||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
|
|
||||||
let summary = "verify TritonGPU IR";
|
|
||||||
|
|
||||||
let description = [{}];
|
|
||||||
|
|
||||||
let constructor = "mlir::createTritonGPUVerifier()";
|
|
||||||
|
|
||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
|
||||||
"mlir::gpu::GPUDialect"];
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -185,9 +185,16 @@ struct TritonExpandDimsPattern
|
|||||||
// return type
|
// return type
|
||||||
RankedTensorType retType =
|
RankedTensorType retType =
|
||||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||||
|
// convert operand to slice of return type
|
||||||
|
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
||||||
|
getContext(), op.axis(), retEncoding);
|
||||||
|
RankedTensorType newArgType = RankedTensorType::get(
|
||||||
|
argType.getShape(), argType.getElementType(), newArgEncoding);
|
||||||
// construct new op
|
// construct new op
|
||||||
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
|
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||||
op, retType, adaptor.src(), adaptor.axis());
|
op.getLoc(), newArgType, adaptor.src());
|
||||||
|
rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
|
||||||
|
adaptor.axis());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -310,9 +317,8 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
|||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
|
||||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
add_mlir_dialect_library(TritonIR
|
add_mlir_dialect_library(TritonIR
|
||||||
|
Interfaces.cpp
|
||||||
Dialect.cpp
|
Dialect.cpp
|
||||||
Ops.cpp
|
Ops.cpp
|
||||||
Types.cpp
|
Types.cpp
|
||||||
|
Traits.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TritonTableGen
|
TritonTableGen
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/IR/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
|
|
||||||
|
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
0
lib/Dialect/Triton/IR/Interfaces.cpp
Normal file
0
lib/Dialect/Triton/IR/Interfaces.cpp
Normal file
@@ -160,18 +160,88 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
//-- DotOp --
|
//-- DotOp --
|
||||||
|
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
// type is the same as the accumulator
|
||||||
|
auto accTy = operands[2].getType().cast<RankedTensorType>();
|
||||||
|
inferredReturnTypes.push_back(accTy);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//-- ReduceOp --
|
||||||
|
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
// infer shape
|
||||||
|
Value arg = operands[0];
|
||||||
|
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||||
|
auto retShape = argTy.getShape().vec();
|
||||||
|
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||||
|
retShape.erase(retShape.begin() + axis);
|
||||||
|
// infer encoding
|
||||||
|
Attribute argEncoding = argTy.getEncoding();
|
||||||
|
Attribute retEncoding;
|
||||||
|
if (argEncoding) {
|
||||||
|
Dialect &dialect = argEncoding.getDialect();
|
||||||
|
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||||
|
if (inferLayoutInterface
|
||||||
|
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
|
||||||
|
.failed()) {
|
||||||
|
llvm::report_fatal_error("failed to infer layout for ReduceOp");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// create type
|
||||||
|
auto argEltTy = argTy.getElementType();
|
||||||
|
inferredReturnTypes.push_back(
|
||||||
|
RankedTensorType::get(retShape, argEltTy, retEncoding));
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
//-- SplatOp --
|
//-- SplatOp --
|
||||||
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
if (!constOperand)
|
if (!constOperand)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto shapedType = getType().cast<ShapedType>();
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//-- ExpandDimsOp --
|
||||||
|
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
// infer shape
|
||||||
|
auto arg = operands[0];
|
||||||
|
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||||
|
auto retShape = argTy.getShape().vec();
|
||||||
|
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||||
|
retShape.insert(retShape.begin() + axis, 1);
|
||||||
|
// infer encoding
|
||||||
|
Attribute argEncoding = argTy.getEncoding();
|
||||||
|
Attribute retEncoding;
|
||||||
|
if (argEncoding) {
|
||||||
|
Dialect &dialect = argEncoding.getDialect();
|
||||||
|
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||||
|
if (inferLayoutInterface
|
||||||
|
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding)
|
||||||
|
.failed()) {
|
||||||
|
llvm::report_fatal_error("failed to infer layout for ExpandDimsOp");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// create type
|
||||||
|
auto argEltTy = argTy.getElementType();
|
||||||
|
inferredReturnTypes.push_back(
|
||||||
|
RankedTensorType::get(retShape, argEltTy, retEncoding));
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
//-- BroadcastOp --
|
//-- BroadcastOp --
|
||||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
|
69
lib/Dialect/Triton/IR/Traits.cpp
Normal file
69
lib/Dialect/Triton/IR/Traits.cpp
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||||
|
|
||||||
|
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
|
||||||
|
using namespace mlir;
|
||||||
|
auto encA = tyA.dyn_cast<RankedTensorType>();
|
||||||
|
auto encB = tyA.dyn_cast<RankedTensorType>();
|
||||||
|
if (!encA || !encB)
|
||||||
|
return success();
|
||||||
|
return encA.getEncoding() == encB.getEncoding() ? success() : failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
|
||||||
|
if (failed(verifyAtLeastNOperands(op, 1)) ||
|
||||||
|
failed(verifyAtLeastNResults(op, 1)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto type = op->getOperand(0).getType();
|
||||||
|
for (auto resultType : op->getResultTypes())
|
||||||
|
if (failed(verifySameEncoding(resultType, type)))
|
||||||
|
return op->emitOpError()
|
||||||
|
<< "requires the same shape for all operands and results";
|
||||||
|
return verifySameOperandsEncoding(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
|
||||||
|
if (failed(verifyAtLeastNOperands(op, 1)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto type = op->getOperand(0).getType();
|
||||||
|
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
|
||||||
|
if (failed(verifySameEncoding(opType, type)))
|
||||||
|
return op->emitOpError() << "requires the same encoding for all operands";
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||||
|
for (auto opType : op->getOperandTypes()) {
|
||||||
|
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t s : tensorType.getShape())
|
||||||
|
numElements *= s;
|
||||||
|
if (numElements > maxTensorNumElements)
|
||||||
|
return op->emitError("Maximum allowed number of elements is ")
|
||||||
|
<< maxTensorNumElements << ", but " << *op
|
||||||
|
<< " has more than that";
|
||||||
|
if ((numElements & (numElements - 1)) != 0)
|
||||||
|
return op->emitError("Number of elements must be power-of-two, but ")
|
||||||
|
<< *op << " doesn't follow the rule";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto opType : op->getResultTypes()) {
|
||||||
|
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t s : tensorType.getShape())
|
||||||
|
numElements *= s;
|
||||||
|
if (numElements > maxTensorNumElements)
|
||||||
|
return op->emitError("Maximum allowed number of elements is ")
|
||||||
|
<< maxTensorNumElements << ", but " << *op
|
||||||
|
<< " has more than that";
|
||||||
|
if ((numElements & (numElements - 1)) != 0)
|
||||||
|
return op->emitError("Number of elements must be power-of-two, but ")
|
||||||
|
<< *op << " doesn't follow the rule";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "triton/Analysis/Utility.h"
|
#include "triton/Analysis/Utility.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
@@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return parser.getChecked<BlockedEncodingAttr>(
|
auto ret = parser.getChecked<BlockedEncodingAttr>(
|
||||||
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
@@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const {
|
|||||||
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
|
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||||
if (parser.parseLess().failed())
|
if (parser.parseLess().failed())
|
||||||
return {};
|
return {};
|
||||||
// Parse the data as a dictionary
|
NamedAttrList attrs;
|
||||||
DictionaryAttr dict;
|
if (parser.parseOptionalAttrDict(attrs).failed())
|
||||||
if (parser.parseAttribute(dict).failed())
|
|
||||||
return {};
|
return {};
|
||||||
if (parser.parseGreater().failed())
|
if (parser.parseGreater().failed())
|
||||||
return {};
|
return {};
|
||||||
|
unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
|
||||||
unsigned dim = 0;
|
Attribute parent = attrs.get("parent");
|
||||||
Attribute parent;
|
|
||||||
|
|
||||||
for (const NamedAttribute &attr : dict) {
|
|
||||||
if (attr.getName() == "dim") {
|
|
||||||
if (parseUInt(parser, attr, dim, "dim").failed())
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
if (attr.getName() == "parent") {
|
|
||||||
if (parser.parseAttribute(parent).failed())
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
|
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,6 +510,35 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TritonGPUInferLayoutInterface
|
||||||
|
: public triton::DialectInferLayoutInterface {
|
||||||
|
using DialectInferLayoutInterface::DialectInferLayoutInterface;
|
||||||
|
|
||||||
|
LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||||
|
Attribute &resultEncoding) const {
|
||||||
|
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
|
||||||
|
operandEncoding);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||||
|
Attribute &resultEncoding) const {
|
||||||
|
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
|
||||||
|
if (!sliceEncoding) {
|
||||||
|
llvm::report_fatal_error(
|
||||||
|
"ExpandDimsOp operand encoding must be SliceEncodingAttr");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (sliceEncoding.getDim() != axis) {
|
||||||
|
llvm::report_fatal_error(
|
||||||
|
"Incompatible slice dimension for ExpandDimsOp operand");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
resultEncoding = sliceEncoding.getParent();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void TritonGPUDialect::initialize() {
|
void TritonGPUDialect::initialize() {
|
||||||
addAttributes<
|
addAttributes<
|
||||||
#define GET_ATTRDEF_LIST
|
#define GET_ATTRDEF_LIST
|
||||||
@@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() {
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addInterfaces<TritonGPUOpAsmInterface>();
|
addInterfaces<TritonGPUOpAsmInterface>();
|
||||||
|
addInterfaces<TritonGPUInferLayoutInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
|
|||||||
Combine.cpp
|
Combine.cpp
|
||||||
Pipeline.cpp
|
Pipeline.cpp
|
||||||
Swizzle.cpp
|
Swizzle.cpp
|
||||||
Verifier.cpp
|
|
||||||
TritonGPUConversion.cpp
|
TritonGPUConversion.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
|
@@ -1,106 +0,0 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
#define GEN_PASS_CLASSES
|
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
||||||
|
|
||||||
class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
|
|
||||||
public:
|
|
||||||
void runOnOperation() override {
|
|
||||||
MLIRContext *context = &getContext();
|
|
||||||
ModuleOp m = getOperation();
|
|
||||||
|
|
||||||
// The idea is similar to mlir/lib/IR/Verifier.cpp
|
|
||||||
verifyImpl(m.getOperation());
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
LogicalResult verifySingleOp(Operation *op) {
|
|
||||||
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
|
|
||||||
Type aType = dotOp.a().getType();
|
|
||||||
Type bType = dotOp.b().getType();
|
|
||||||
Type cType = dotOp.c().getType();
|
|
||||||
Type dType = dotOp.d().getType();
|
|
||||||
for (auto it : llvm::zip(llvm::SmallVector<Type>{aType, bType},
|
|
||||||
llvm::SmallVector<char>{'a', 'b'})) {
|
|
||||||
Type type = std::get<0>(it);
|
|
||||||
char name = std::get<1>(it);
|
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
||||||
Attribute encoding = tensorType.getEncoding();
|
|
||||||
if (!encoding)
|
|
||||||
return dotOp.emitError() << name << " should have encoding";
|
|
||||||
if (!encoding.isa<triton::gpu::SharedEncodingAttr>())
|
|
||||||
return dotOp.emitError() << name << " should be of shared layout";
|
|
||||||
} else
|
|
||||||
return dotOp.emitError()
|
|
||||||
<< name << "'s type should be of RankedTensorType";
|
|
||||||
}
|
|
||||||
|
|
||||||
Attribute cLayout;
|
|
||||||
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
|
|
||||||
llvm::SmallVector<char>{'c', 'd'})) {
|
|
||||||
Type type = std::get<0>(it);
|
|
||||||
char name = std::get<1>(it);
|
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
||||||
Attribute encoding = tensorType.getEncoding();
|
|
||||||
if (!encoding)
|
|
||||||
return dotOp.emitError() << name << " should have encoding";
|
|
||||||
if (!encoding.isa<triton::gpu::MmaEncodingAttr>() &&
|
|
||||||
!encoding.isa<triton::gpu::BlockedEncodingAttr>())
|
|
||||||
return dotOp.emitError()
|
|
||||||
<< name << " should be of distributed layout";
|
|
||||||
if (name == 'c')
|
|
||||||
cLayout = encoding;
|
|
||||||
else if (encoding != cLayout)
|
|
||||||
return dotOp.emitError() << "d & c should have the same layout";
|
|
||||||
} else
|
|
||||||
return dotOp.emitError()
|
|
||||||
<< name << "'s type should be of RankedTensorType";
|
|
||||||
}
|
|
||||||
|
|
||||||
// signalPassFailure();
|
|
||||||
}
|
|
||||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
// Triton builtin Ops
|
|
||||||
if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp,
|
|
||||||
triton::MakeRangeOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
if (auto atomicRmw = llvm::dyn_cast<triton::AtomicRMWOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
if (auto atomicCas = llvm::dyn_cast<triton::AtomicCASOp>(op)) {
|
|
||||||
// TODO: fill this
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Arithmetic, SCF, TritonGPU ops
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void verifyImpl(Operation *op) {
|
|
||||||
if (verifySingleOp(op).failed())
|
|
||||||
signalPassFailure();
|
|
||||||
|
|
||||||
// verify that all child regions are ok
|
|
||||||
for (Region ®ion : op->getRegions())
|
|
||||||
for (Block &block : region)
|
|
||||||
for (Operation &childOp : block)
|
|
||||||
verifyImpl(&childOp);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
|
|
||||||
return std::make_unique<TritonGPUVerifier>();
|
|
||||||
}
|
|
@@ -1182,10 +1182,6 @@ void init_triton_ir(py::module &&m) {
|
|||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_verifier_pass",
|
|
||||||
[](mlir::PassManager &self) {
|
|
||||||
self.addPass(mlir::createTritonGPUVerifier());
|
|
||||||
})
|
|
||||||
.def("add_triton_gpu_to_llvm",
|
.def("add_triton_gpu_to_llvm",
|
||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||||
|
@@ -861,7 +861,6 @@ def optimize_tritongpu_ir(mod, num_stages):
|
|||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
pm.add_coalesce_pass()
|
pm.add_coalesce_pass()
|
||||||
pm.add_triton_gpu_combine_pass()
|
pm.add_triton_gpu_combine_pass()
|
||||||
pm.add_triton_gpu_verifier_pass()
|
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
|
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
|
||||||
|
|
||||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
|
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
|
||||||
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
@@ -164,7 +165,7 @@ func @alloc(%A : !tt.ptr<f16>) {
|
|||||||
func @scratch() {
|
func @scratch() {
|
||||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||||
// CHECK: scratch offset = 0, size = 512
|
// CHECK: scratch offset = 0, size = 512
|
||||||
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #AL>
|
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||||
return
|
return
|
||||||
// CHECK-NEXT: size = 512
|
// CHECK-NEXT: size = 512
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s
|
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s
|
||||||
|
|
||||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
|
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
|
||||||
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
@@ -69,7 +70,8 @@ func @scratch() {
|
|||||||
// CHECK: Membar 1
|
// CHECK: Membar 1
|
||||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
|
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
|
||||||
// CHECK-NEXT: Membar 3
|
// CHECK-NEXT: Membar 3
|
||||||
%b = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A>
|
%aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL>
|
||||||
|
%b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -348,15 +348,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
|
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
|
||||||
|
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
|
||||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
// CHECK-LABEL: basic_insert_slice_async_v4
|
// CHECK-LABEL: basic_insert_slice_async_v4
|
||||||
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
|
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1>
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||||
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
|
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
|
||||||
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
|
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
|
||||||
%cst_scalar = arith.constant 64 : i32
|
%cst_scalar = arith.constant 64 : i32
|
||||||
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
|
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
|
||||||
@@ -387,15 +389,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
|
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
|
||||||
|
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
|
||||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
// CHECK-LABEL: basic_insert_slice_async_v1
|
// CHECK-LABEL: basic_insert_slice_async_v1
|
||||||
func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
|
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||||
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
|
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
|
||||||
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
|
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
|
||||||
%cst_scalar = arith.constant 32 : i32
|
%cst_scalar = arith.constant 32 : i32
|
||||||
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
|
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
|
||||||
@@ -429,15 +433,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
|
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
|
||||||
|
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
|
||||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
|
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
|
||||||
func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
|
||||||
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
|
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
|
||||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
|
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
|
||||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2>
|
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
|
||||||
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3>
|
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
|
||||||
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
|
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
|
||||||
%cst_scalar = arith.constant 32 : i32
|
%cst_scalar = arith.constant 32 : i32
|
||||||
%cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
|
%cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
|
||||||
|
@@ -1,8 +1,10 @@
|
|||||||
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s
|
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
|
||||||
|
|
||||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
|
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
|
||||||
|
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
|
||||||
|
#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}>
|
||||||
|
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
|
||||||
@@ -23,13 +25,14 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
|||||||
%arg3: i32 {tt.divisibility = 16 : i32}) {
|
%arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||||
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
|
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
|
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||||
|
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||||
|
@@ -53,7 +53,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
|||||||
|
|
||||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||||
|
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
|
||||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
|
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
|
||||||
|
#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}>
|
||||||
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||||
|
|
||||||
@@ -90,13 +92,14 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
|||||||
// CHECK: return
|
// CHECK: return
|
||||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||||
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
|
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
|
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||||
|
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||||
@@ -138,13 +141,14 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
|||||||
%c32 = arith.constant 32 : index
|
%c32 = arith.constant 32 : index
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
|
||||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
|
%00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
|
||||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
|
%01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
|
||||||
|
%1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
|
||||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
%6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
|
||||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s
|
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s
|
||||||
|
|
||||||
// 4 warps
|
// 4 warps
|
||||||
// matmul: 128x32 @ 32x128 -> 128x128
|
// matmul: 128x32 @ 32x128 -> 128x128
|
||||||
|
Reference in New Issue
Block a user