[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
@@ -1,10 +1,19 @@
 set(LLVM_TARGET_DEFINITIONS TritonOps.td)
 mlir_tablegen(Ops.h.inc -gen-op-decls)
 mlir_tablegen(Ops.cpp.inc -gen-op-defs)
-mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
-mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
+mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
+mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
+
+set(LLVM_TARGET_DEFINITIONS TritonOps.td)
 mlir_tablegen(Types.h.inc -gen-typedef-decls)
 mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
+
+set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
+mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
+
 add_public_tablegen_target(TritonTableGen)
@@ -17,4 +17,24 @@
 #define GET_OP_CLASSES
 #include "triton/Dialect/Triton/IR/Ops.h.inc"
 
+namespace mlir {
+namespace triton {
+
+class DialectInferLayoutInterface
+    : public DialectInterface::Base<DialectInferLayoutInterface> {
+public:
+  DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
+
+  virtual LogicalResult
+  inferReduceOpEncoding(Attribute operandEncoding, int axis,
+                        Attribute &resultEncoding) const = 0;
+
+  virtual LogicalResult
+  inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
+                            Attribute &resultEncoding) const = 0;
+};
+
+} // namespace triton
+} // namespace mlir
+
 #endif // TRITON_IR_DIALECT_H_
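For illustration only (not part of this commit): given a tensor's encoding attribute, a consumer reaches the dialect-supplied implementation of this interface through the dialect that owns the encoding, the same lookup that Ops.cpp performs further down in this diff. The helper name is hypothetical; the interface API is the one declared above.

// Sketch: ask the dialect owning `operandEncoding` how a reduction along
// `axis` changes the layout; returns a null Attribute on failure.
static mlir::Attribute inferReducedEncoding(mlir::Attribute operandEncoding,
                                            int axis) {
  mlir::Dialect &dialect = operandEncoding.getDialect();
  auto iface = dyn_cast<mlir::triton::DialectInferLayoutInterface>(&dialect);
  if (!iface)
    return {}; // the owning dialect does not implement layout inference
  mlir::Attribute resultEncoding;
  if (iface->inferReduceOpEncoding(operandEncoding, axis, resultEncoding)
          .failed())
    return {};
  return resultEncoding;
}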
include/triton/Dialect/Triton/IR/Interfaces.h (new file, +9)
@@ -0,0 +1,9 @@
+#ifndef TRITON_IR_INTERFACES_H_
+#define TRITON_IR_INTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_TYPEDEF_CLASSES
+#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
+
+#endif // TRITON_IR_INTERFACES_H_
@@ -10,50 +10,47 @@
 namespace mlir {
 namespace OpTrait {
 // TODO: should have `namespace triton {}` here
 
+// These functions are out-of-line implementations of the methods in the
+// corresponding trait classes. This avoids them being template
+// instantiated/duplicated.
+namespace impl {
+LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
+LogicalResult verifySameOperandsEncoding(Operation *op);
+// The rationale for this trait is to prevent users from creating programs
+// that would have catastrophic register pressure and cause the compiler to
+// hang.
+// Since H100 has 256KB registers, we should allow users to create tensors
+// of size up to 256K elements. It will spill for datatypes wider than 1B,
+// but we probably should limit number of elements (rather than bytes) to
+// keep specs simple
+int constexpr maxTensorNumElements = 1048576;
+LogicalResult verifyTensorSize(Operation *op);
+} // namespace impl
+
 template <class ConcreteType>
 class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
 public:
-  // TODO: move impl to .cc files
   static LogicalResult verifyTrait(Operation *op) {
-    // The rationale for this number is to prevent users from creating programs
-    // that would have catastrophic register pressure and cause the compiler to
-    // hang.
-    // Since H100 has 256KB registers, we should allow users to create tensors
-    // of size up to 256K elements. It will spill for datatypes wider than 1B,
-    // but we probably should limit number of elements (rather than bytes) to
-    // keep specs simple
-    int constexpr maxElement = 1048576;
-    for (auto opType : op->getOperandTypes()) {
-      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
-        int64_t numElements = 1;
-        for (int64_t s : tensorType.getShape())
-          numElements *= s;
-        if (numElements > maxElement)
-          return op->emitError("Maximum allowed number of elements is ")
-                 << maxElement << ", but " << *op << " has more than that";
-        if ((numElements & (numElements - 1)) != 0)
-          return op->emitError("Number of elements must be power-of-two, but ")
-                 << *op << " doesn't follow the rule";
-      }
-    }
-
-    for (auto opType : op->getResultTypes()) {
-      if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
-        int64_t numElements = 1;
-        for (int64_t s : tensorType.getShape())
-          numElements *= s;
-        if (numElements > maxElement)
-          return op->emitError("Maximum allowed number of elements is ")
-                 << maxElement << ", but " << *op << " has more than that";
-        if ((numElements & (numElements - 1)) != 0)
-          return op->emitError("Number of elements must be power-of-two, but ")
-                 << *op << " doesn't follow the rule";
-      }
-    }
-
-    return success();
+    return impl::verifyTensorSize(op);
   }
 };
 
+template <typename ConcreteType>
+class SameOperandsAndResultEncoding
+    : public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultEncoding(op);
+  }
+};
+
+template <typename ConcreteType>
+class SameOperandsEncoding
+    : public TraitBase<ConcreteType, SameOperandsEncoding> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsEncoding(op);
+  }
+};
+
@@ -41,4 +41,7 @@ def Triton_Dialect : Dialect {
   let hasConstantMaterializer = 1;
 }
 
+include "triton/Dialect/Triton/IR/TritonTypes.td"
+
+
 #endif // TRITON_DIALECT
include/triton/Dialect/Triton/IR/TritonInterfaces.td (new file, +6)
@@ -0,0 +1,6 @@
+#ifndef TRITON_INTERFACES
+#define TRITON_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+#endif // TRITON_INTERFACES
@@ -4,18 +4,23 @@
 include "triton/Dialect/Triton/IR/TritonDialect.td"
 include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
+include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
 include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 
-include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
-
+def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
+def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
+def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
 
 //
 // Op Base
 //
 class TT_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
+    Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
+}
 
 //
 // CastOps
@@ -25,7 +30,9 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
 // fptoui, fptosi, uitofp, sitofp,
 // extf, truncf,
 // extui, extsi, trunci
-def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
+def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
+                                         SameOperandsAndResultEncoding,
+                                         NoSideEffect,
                                          /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
     let summary = "Cast int64 to pointer";
 
@@ -36,7 +43,9 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
     let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }
 
-def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
+def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
+                                         SameOperandsAndResultEncoding,
+                                         NoSideEffect,
                                          /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
     let summary = "Cast pointer to int64";
 
@@ -47,7 +56,9 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffec
     let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
 }
 
-def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
+def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
+                                   SameOperandsAndResultEncoding,
+                                   NoSideEffect,
                                    /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
     let summary = "Floating point casting for custom types";
 
@@ -67,11 +78,33 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
     // TODO: We need a verifier here.
 }
 
 //
+// Pointer Arith Ops
+//
+
+def TT_AddPtrOp : TT_Op<"addptr",
+                        [NoSideEffect,
+                         SameOperandsAndResultShape,
+                         SameOperandsAndResultEncoding,
+                         TypesMatchWith<"result type matches ptr type",
+                                        "result", "ptr", "$_self">,
+                         TypesMatchWith<"result shape matches offset shape",
+                                        "result", "offset",
+                                        "getI32SameShape($_self)">]> {
+    let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
+
+    let results = (outs TT_PtrLike:$result);
+
+    let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
+}
+
+
+//
 // Load/Store Ops
 //
 def TT_LoadOp : TT_Op<"load",
     [SameOperandsAndResultShape,
+     SameOperandsAndResultEncoding,
      SameVariadicOperandSize,
      MemoryEffects<[MemRead]>,
      TypesMatchWith<"infer ptr type from result type",
@@ -110,6 +143,7 @@ def TT_LoadOp : TT_Op<"load",
 
 def TT_StoreOp : TT_Op<"store",
     [SameOperandsShape,
+     SameOperandsEncoding,
      MemoryEffects<[MemWrite]>,
      TypesMatchWith<"infer ptr type from value type",
                     "value", "ptr",
@@ -133,130 +167,11 @@ def TT_StoreOp : TT_Op<"store",
     let hasCanonicalizer = 1;
 }
 
-def TT_AddPtrOp : TT_Op<"addptr",
-                        [NoSideEffect, SameOperandsAndResultShape,
-                         TypesMatchWith<"result type matches ptr type",
-                                        "result", "ptr", "$_self">,
-                         TypesMatchWith<"result shape matches offset shape",
-                                        "result", "offset",
-                                        "getI32SameShape($_self)">]> {
-    let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
-
-    let results = (outs TT_PtrLike:$result);
-
-    let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
-}
-
-
 //
-// Shape Manipulation Ops
+// Atomic Op
 //
-def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
-    let summary = "expand_dims";
-
-    let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
-
-    let results = (outs TT_Tensor:$result);
-
-    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
-}
-
-def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
-    let summary = "view";
-
-    let arguments = (ins TT_Tensor:$src);
-
-    let results = (outs TT_Tensor:$result);
-
-    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
-}
-
-def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
-    let summary = "splat";
-
-    let arguments = (ins TT_Type:$src);
-
-    let results = (outs TT_Tensor:$result);
-
-    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
-
-    let hasFolder = 1;
-}
-
-def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
-    let summary = "broadcast. No left-padding as of now.";
-
-    let arguments = (ins TT_Type:$src);
-
-    let results = (outs TT_Type:$result);
-
-    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
-
-    let hasFolder = 1;
-}
-
-def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> {
-    let summary = "concatenate 2 tensors";
-
-    let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
-
-    let results = (outs TT_Tensor:$result);
-
-    let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
-}
-
-//
-// builtin Ops
-//
-def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
-    let arguments = (ins I32Attr:$axis);
-
-    let results = (outs I32:$result);
-
-    let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
-    let arguments = (ins I32Attr:$axis);
-
-    let results = (outs I32:$result);
-
-    let assemblyFormat = "attr-dict `:` type($result)";
-}
-
-def TT_DotOp : TT_Op<"dot", [NoSideEffect,
-                             TypesMatchWith<"result's type matches accumulator's type",
-                                            "d", "c", "$_self">]> {
-    let summary = "dot";
-
-    let description = [{
-        $d = matrix_multiply($a, $b) + $c
-    }];
-
-    let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
-
-    let results = (outs TT_FpIntTensor:$d);
-
-    let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
-
-    // let hasCanonicalizer = 1;
-}
-
-def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> {
-    let summary = "reduce";
-
-    let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
-
-    let results = (outs TT_Tensor:$result);
-
-    let builders = [
-        OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
-    ];
-
-    let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
-}
-
-def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
+def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
+                                          SameOperandsAndResultEncoding]> {
     let summary = "atomic rmw";
 
     let description = [{
@@ -271,7 +186,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
     let results = (outs TT_Type:$result);
 }
 
-def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
+def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
+                                          SameOperandsAndResultEncoding]> {
     let summary = "atomic cas";
 
     let description = [{
@@ -289,10 +205,133 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
     let results = (outs TT_Type:$result);
 }
 
 
 //
-// External Function Ops
+// Shape Manipulation Ops
 //
+def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
+                                 SameOperandsAndResultElementType]> {
+    let summary = "splat";
+
+    let arguments = (ins TT_Type:$src);
+
+    let results = (outs TT_Tensor:$result);
+
+    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+
+    let hasFolder = 1;
+}
+
+def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
+                                            DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                            SameOperandsAndResultElementType]> {
+    let summary = "expand_dims";
+
+    let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
+
+    let results = (outs TT_Tensor:$result);
+
+    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+}
+
+def TT_ViewOp : TT_Op<"view", [NoSideEffect,
+                               SameOperandsAndResultElementType]> {
+    let summary = "view";
+
+    let arguments = (ins TT_Tensor:$src);
+
+    let results = (outs TT_Tensor:$result);
+
+    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+
+}
+
+def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
+                                         SameOperandsAndResultElementType]> {
+    let summary = "broadcast. No left-padding as of now.";
+
+    let arguments = (ins TT_Type:$src);
+
+    let results = (outs TT_Type:$result);
+
+    let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+
+    let hasFolder = 1;
+}
+
+def TT_CatOp : TT_Op<"cat", [NoSideEffect,
+                             SameOperandsAndResultElementType]> {
+    let summary = "concatenate 2 tensors";
+
+    let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
+
+    let results = (outs TT_Tensor:$result);
+
+    let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
+}
+
+//
+// SPMD Ops
+//
+def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
+    let arguments = (ins I32Attr:$axis);
+
+    let results = (outs I32:$result);
+
+    let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
+    let arguments = (ins I32Attr:$axis);
+
+    let results = (outs I32:$result);
+
+    let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+//
+// Dot Op
+//
+def TT_DotOp : TT_Op<"dot", [NoSideEffect,
+                             DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                             TypesMatchWith<"result's type matches accumulator's type",
+                                            "d", "c", "$_self">]> {
+    let summary = "dot";
+
+    let description = [{
+        $d = matrix_multiply($a, $b) + $c
+    }];
+
+    let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
+
+    let results = (outs TT_FpIntTensor:$d);
+
+    let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
+}
+
+//
+// Reduce Op
+//
+def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
+                                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    let summary = "reduce";
+
+    let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
+
+    let results = (outs TT_Tensor:$result);
+
+    let builders = [
+        OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
+    ];
+
+    let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
+}
+
+//
+// External elementwise op
+//
 def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
+                                              SameOperandsAndResultEncoding,
                                               SameVariadicOperandSize]> {
     let summary = "ext_elemwise";
 
@@ -307,10 +346,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
     let results = (outs TT_Tensor:$result);
 
     let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
+
 }
 
 //
-// Intrinsics
+// Make Range Op
 //
 // TODO: should have ConstantLike as Trait
 def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
@@ -11,6 +11,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
 
 #define GET_ATTRDEF_CLASSES
+#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
 
 #define GET_OP_CLASSES
@@ -2,6 +2,7 @@
 #define TRITONGPU_ATTRDEFS
 
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
+include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // TritonGPU Attribute Definitions
@@ -34,6 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
 
 code extraBaseClassDeclaration = [{
   unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
+  ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
 }];
 }
 
@@ -301,15 +303,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
     TODO: improve docs
 
     A = [x  x  x  x  x  x  x  x]
        [x  x  x  x  x  x  x  x]
-    L_parent = [0  1  2  3 ]
-               [4  5  6  7 ]
-               [8  9  10 11]
-               [12 13 14 15]
+    parent = [0  1  2  3 ]
+             [4  5  6  7 ]
+             [8  9  10 11]
+             [12 13 14 15]
     dim = 0
 
     Then the data of A would be distributed as follows between the 16 CUDA threads:
-    L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ]
+    L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
 
+    This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
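As a concrete restatement of the sentence above (an illustrative sketch, not part of this commit; it assumes only the SliceEncodingAttr C++ API used elsewhere in this diff): slicing a parent layout along a dimension and then expanding that same dimension recovers the parent, which is exactly the inversion that inferExpandDimsOpEncoding performs further down.

// slice(dim, parent) is invertible: expand_dims along `dim` yields `parent`.
mlir::Attribute sliceRoundTrip(mlir::MLIRContext *ctx, mlir::Attribute parent,
                               unsigned dim) {
  auto slice = mlir::triton::gpu::SliceEncodingAttr::get(ctx, dim, parent);
  return slice.getParent(); // == parent
}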
@@ -76,15 +76,4 @@ def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::Modu
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
 }
 
-def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
-  let summary = "verify TritonGPU IR";
-
-  let description = [{}];
-
-  let constructor = "mlir::createTritonGPUVerifier()";
-
-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
-                           "mlir::gpu::GPUDialect"];
-}
-
 #endif
@@ -185,9 +185,16 @@ struct TritonExpandDimsPattern
     // return type
     RankedTensorType retType =
         RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
+    // convert operand to slice of return type
+    Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
+        getContext(), op.axis(), retEncoding);
+    RankedTensorType newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), newArgEncoding);
     // construct new op
-    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
-        op, retType, adaptor.src(), adaptor.axis());
+    auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op.getLoc(), newArgType, adaptor.src());
+    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
+                                                      adaptor.axis());
     return success();
   }
 };
@@ -310,9 +317,8 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   LogicalResult
   matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type retType = this->getTypeConverter()->convertType(op.getType());
     auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
-        op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
+        op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
     return success();
   }
 };
@@ -1,7 +1,9 @@
 add_mlir_dialect_library(TritonIR
+  Interfaces.cpp
   Dialect.cpp
   Ops.cpp
   Types.cpp
+  Traits.cpp
 
   DEPENDS
   TritonTableGen
@@ -1,6 +1,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 
+#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
lib/Dialect/Triton/IR/Interfaces.cpp (new file, empty)
@@ -160,18 +160,88 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
 }
 
+//-- DotOp --
+mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // type is the same as the accumulator
+  auto accTy = operands[2].getType().cast<RankedTensorType>();
+  inferredReturnTypes.push_back(accTy);
+  return mlir::success();
+}
+
+//-- ReduceOp --
+mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // infer shape
+  Value arg = operands[0];
+  auto argTy = arg.getType().cast<RankedTensorType>();
+  auto retShape = argTy.getShape().vec();
+  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
+  retShape.erase(retShape.begin() + axis);
+  // infer encoding
+  Attribute argEncoding = argTy.getEncoding();
+  Attribute retEncoding;
+  if (argEncoding) {
+    Dialect &dialect = argEncoding.getDialect();
+    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
+    if (inferLayoutInterface
+            ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
+            .failed()) {
+      llvm::report_fatal_error("failed to infer layout for ReduceOp");
+      return mlir::failure();
+    }
+  }
+  // create type
+  auto argEltTy = argTy.getElementType();
+  inferredReturnTypes.push_back(
+      RankedTensorType::get(retShape, argEltTy, retEncoding));
+  return mlir::success();
+}
+
 //-- SplatOp --
 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
   if (!constOperand)
     return {};
 
   auto shapedType = getType().cast<ShapedType>();
   auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
   return ret;
 }
 
+//-- ExpandDimsOp --
+mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // infer shape
+  auto arg = operands[0];
+  auto argTy = arg.getType().cast<RankedTensorType>();
+  auto retShape = argTy.getShape().vec();
+  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
+  retShape.insert(retShape.begin() + axis, 1);
+  // infer encoding
+  Attribute argEncoding = argTy.getEncoding();
+  Attribute retEncoding;
+  if (argEncoding) {
+    Dialect &dialect = argEncoding.getDialect();
+    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
+    if (inferLayoutInterface
+            ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding)
+            .failed()) {
+      llvm::report_fatal_error("failed to infer layout for ExpandDimsOp");
+      return mlir::failure();
+    }
+  }
+  // create type
+  auto argEltTy = argTy.getElementType();
+  inferredReturnTypes.push_back(
+      RankedTensorType::get(retShape, argEltTy, retEncoding));
+  return mlir::success();
+}
+
 //-- BroadcastOp --
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
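A minimal usage sketch (not part of this commit) of what these inferReturnTypes hooks buy: together with the OpBuilder declared for tt.reduce in TritonOps.td, callers can create the op without spelling out the result type, since the shape (axis removed) and the encoding (via DialectInferLayoutInterface) are derived from the operand. The helper name is hypothetical; `builder`, `loc`, `redOp`, and `operand` are assumed to be supplied by the caller.

// Result type (shape, element type, encoding) is inferred from `operand`.
mlir::Value makeReduce(mlir::OpBuilder &builder, mlir::Location loc,
                       mlir::triton::RedOp redOp, mlir::Value operand) {
  auto op = builder.create<mlir::triton::ReduceOp>(loc, redOp, operand,
                                                   /*axis=*/0);
  return op.getResult();
}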
lib/Dialect/Triton/IR/Traits.cpp (new file, +69)
@@ -0,0 +1,69 @@
+#include "triton/Dialect/Triton/IR/Traits.h"
+
+static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
+  using namespace mlir;
+  auto encA = tyA.dyn_cast<RankedTensorType>();
+  auto encB = tyB.dyn_cast<RankedTensorType>();
+  if (!encA || !encB)
+    return success();
+  return encA.getEncoding() == encB.getEncoding() ? success() : failure();
+}
+
+mlir::LogicalResult
+mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
+  if (failed(verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyAtLeastNResults(op, 1)))
+    return failure();
+
+  auto type = op->getOperand(0).getType();
+  for (auto resultType : op->getResultTypes())
+    if (failed(verifySameEncoding(resultType, type)))
+      return op->emitOpError()
+             << "requires the same encoding for all operands and results";
+  return verifySameOperandsEncoding(op);
+}
+
+mlir::LogicalResult
+mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
+  if (failed(verifyAtLeastNOperands(op, 1)))
+    return failure();
+
+  auto type = op->getOperand(0).getType();
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
+    if (failed(verifySameEncoding(opType, type)))
+      return op->emitOpError() << "requires the same encoding for all operands";
+
+  return success();
+}
+
+mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
+  for (auto opType : op->getOperandTypes()) {
+    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+      int64_t numElements = 1;
+      for (int64_t s : tensorType.getShape())
+        numElements *= s;
+      if (numElements > maxTensorNumElements)
+        return op->emitError("Maximum allowed number of elements is ")
+               << maxTensorNumElements << ", but " << *op
+               << " has more than that";
+      if ((numElements & (numElements - 1)) != 0)
+        return op->emitError("Number of elements must be power-of-two, but ")
+               << *op << " doesn't follow the rule";
+    }
+  }
+  for (auto opType : op->getResultTypes()) {
+    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+      int64_t numElements = 1;
+      for (int64_t s : tensorType.getShape())
+        numElements *= s;
+      if (numElements > maxTensorNumElements)
+        return op->emitError("Maximum allowed number of elements is ")
+               << maxTensorNumElements << ", but " << *op
+               << " has more than that";
+      if ((numElements & (numElements - 1)) != 0)
+        return op->emitError("Number of elements must be power-of-two, but ")
+               << *op << " doesn't follow the rule";
+    }
+  }
+  return success();
+}
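A standalone restatement (illustrative, not part of this commit) of the size rule that verifyTensorSize enforces: at most 2^20 elements per tensor, and the element count must be a power of two, checked with the usual bit trick.

#include <cstdint>

constexpr int64_t kMaxTensorNumElements = 1048576; // 2^20, as above

constexpr bool tensorSizeOk(int64_t numElements) {
  return numElements <= kMaxTensorNumElements &&
         (numElements & (numElements - 1)) == 0;
}

static_assert(tensorSizeOk(16 * 16), "256 elements: accepted");
static_assert(!tensorSizeOk(16 * 17), "272 is not a power of two: rejected");
static_assert(!tensorSizeOk(1 << 21), "2M elements exceed the cap: rejected");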
@@ -3,6 +3,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
     }
   }
 
-  return parser.getChecked<BlockedEncodingAttr>(
+  auto ret = parser.getChecked<BlockedEncodingAttr>(
       parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
+  return ret;
 }
 
 void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
@@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const {
 Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
-  // Parse the data as a dictionary
-  DictionaryAttr dict;
-  if (parser.parseAttribute(dict).failed())
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs).failed())
     return {};
   if (parser.parseGreater().failed())
    return {};
 
-  unsigned dim = 0;
-  Attribute parent;
-
-  for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "dim") {
-      if (parseUInt(parser, attr, dim, "dim").failed())
-        return {};
-    }
-    if (attr.getName() == "parent") {
-      if (parser.parseAttribute(parent).failed())
-        return {};
-    }
-  }
-
+  unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
+  Attribute parent = attrs.get("parent");
   return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
 }
 
@@ -522,6 +510,35 @@ public:
   }
 };
 
+struct TritonGPUInferLayoutInterface
+    : public triton::DialectInferLayoutInterface {
+  using DialectInferLayoutInterface::DialectInferLayoutInterface;
+
+  LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
+                                      Attribute &resultEncoding) const {
+    resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
+                                            operandEncoding);
+    return success();
+  }
+
+  LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
+                                          Attribute &resultEncoding) const {
+    auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
+    if (!sliceEncoding) {
+      llvm::report_fatal_error(
+          "ExpandDimsOp operand encoding must be SliceEncodingAttr");
+      return failure();
+    }
+    if (sliceEncoding.getDim() != axis) {
+      llvm::report_fatal_error(
+          "Incompatible slice dimension for ExpandDimsOp operand");
+      return failure();
+    }
+    resultEncoding = sliceEncoding.getParent();
+    return success();
+  }
+};
+
 void TritonGPUDialect::initialize() {
   addAttributes<
 #define GET_ATTRDEF_LIST
@@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() {
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
       >();
   addInterfaces<TritonGPUOpAsmInterface>();
+  addInterfaces<TritonGPUInferLayoutInterface>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -568,4 +586,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                          NamedAttribute attr) {
   // TODO: fill this.
   return success();
-}
+}
@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
   Combine.cpp
   Pipeline.cpp
   Swizzle.cpp
-  Verifier.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
@@ -1,106 +0,0 @@
-#include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
-
-#include <memory>
-
-using namespace mlir;
-
-#define GEN_PASS_CLASSES
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
-
-class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
-public:
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    ModuleOp m = getOperation();
-
-    // The idea is similar to mlir/lib/IR/Verifier.cpp
-    verifyImpl(m.getOperation());
-  }
-
-private:
-  LogicalResult verifySingleOp(Operation *op) {
-    if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
-      Type aType = dotOp.a().getType();
-      Type bType = dotOp.b().getType();
-      Type cType = dotOp.c().getType();
-      Type dType = dotOp.d().getType();
-      for (auto it : llvm::zip(llvm::SmallVector<Type>{aType, bType},
-                               llvm::SmallVector<char>{'a', 'b'})) {
-        Type type = std::get<0>(it);
-        char name = std::get<1>(it);
-        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
-          Attribute encoding = tensorType.getEncoding();
-          if (!encoding)
-            return dotOp.emitError() << name << " should have encoding";
-          if (!encoding.isa<triton::gpu::SharedEncodingAttr>())
-            return dotOp.emitError() << name << " should be of shared layout";
-        } else
-          return dotOp.emitError()
-                 << name << "'s type should be of RankedTensorType";
-      }
-
-      Attribute cLayout;
-      for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
-                               llvm::SmallVector<char>{'c', 'd'})) {
-        Type type = std::get<0>(it);
-        char name = std::get<1>(it);
-        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
-          Attribute encoding = tensorType.getEncoding();
-          if (!encoding)
-            return dotOp.emitError() << name << " should have encoding";
-          if (!encoding.isa<triton::gpu::MmaEncodingAttr>() &&
-              !encoding.isa<triton::gpu::BlockedEncodingAttr>())
-            return dotOp.emitError()
-                   << name << " should be of distributed layout";
-          if (name == 'c')
-            cLayout = encoding;
-          else if (encoding != cLayout)
-            return dotOp.emitError() << "d & c should have the same layout";
-        } else
-          return dotOp.emitError()
-                 << name << "'s type should be of RankedTensorType";
-      }
-
-      // signalPassFailure();
-    }
-    if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
-      // TODO: fill this
-    }
-    if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
-      // TODO: fill this
-    }
-    if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
-      // TODO: fill this
-    }
-    // Triton builtin Ops
-    if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp,
-                  triton::MakeRangeOp>(op)) {
-      // TODO: fill this
-    }
-    if (auto atomicRmw = llvm::dyn_cast<triton::AtomicRMWOp>(op)) {
-      // TODO: fill this
-    }
-    if (auto atomicCas = llvm::dyn_cast<triton::AtomicCASOp>(op)) {
-      // TODO: fill this
-    }
-
-    // TODO: Arithmetic, SCF, TritonGPU ops
-    return success();
-  }
-
-  void verifyImpl(Operation *op) {
-    if (verifySingleOp(op).failed())
-      signalPassFailure();
-
-    // verify that all child regions are ok
-    for (Region &region : op->getRegions())
-      for (Block &block : region)
-        for (Operation &childOp : block)
-          verifyImpl(&childOp);
-  }
-};
-
-std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
-  return std::make_unique<TritonGPUVerifier>();
-}
@@ -1182,10 +1182,6 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUCombineOpsPass());
            })
-      .def("add_triton_gpu_verifier_pass",
-           [](mlir::PassManager &self) {
-             self.addPass(mlir::createTritonGPUVerifier());
-           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
@@ -861,7 +861,6 @@ def optimize_tritongpu_ir(mod, num_stages):
     pm.add_cse_pass()
     pm.add_coalesce_pass()
     pm.add_triton_gpu_combine_pass()
-    pm.add_triton_gpu_verifier_pass()
     pm.run(mod)
     return mod
 
@@ -1,6 +1,7 @@
 // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
 
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
 #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
@@ -164,7 +165,7 @@ func @alloc(%A : !tt.ptr<f16>) {
 func @scratch() {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
   // CHECK: scratch offset = 0, size = 512
-  %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #AL>
+  %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
   return
   // CHECK-NEXT: size = 512
 }
@@ -1,6 +1,7 @@
 // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s
 
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
 #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
@@ -69,7 +70,8 @@ func @scratch() {
   // CHECK: Membar 1
   %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
   // CHECK-NEXT: Membar 3
-  %b = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A>
+  %aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL>
+  %b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
   return
 }
 
@@ -348,15 +348,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: basic_insert_slice_async_v4
 func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-  %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
-  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1>
-  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
-  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
+  %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
+  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
+  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
+  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
   %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
   %cst_scalar = arith.constant 64 : i32
   %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
@@ -387,15 +389,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: basic_insert_slice_async_v1
 func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-  %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
-  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
-  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
-  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
+  %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
+  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
+  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
+  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
   %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
   %cst_scalar = arith.constant 32 : i32
   %cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
@@ -429,15 +433,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK-LABEL: basic_insert_slice_async_v1_multictas
 func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-  %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
-  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
-  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2>
-  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3>
+  %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
+  %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
+  %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
+  %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
   %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
   %cst_scalar = arith.constant 32 : i32
   %cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
@@ -1,8 +1,10 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}>
 
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
@@ -23,13 +25,14 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                 %arg3: i32 {tt.divisibility = 16 : i32}) {
   %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
   %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
-  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
-  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
+  %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
+  %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
   %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
   %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
   %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
   %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
-  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
   %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
   %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
@@ -53,7 +53,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}>
 #blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
 
@@ -90,13 +92,14 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
-  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
-  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
+  %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
+  %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
   %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
   %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
   %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
   %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
-  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
   %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
   %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
@@ -138,13 +141,14 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
-  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
-  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1>
+  %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0>
+  %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1>
   %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
   %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
   %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
   %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
-  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2>
   %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
   %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
   %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s
 
 // 4 warps
 // matmul: 128x32 @ 32x128 -> 128x128