assembly format
This commit is contained in:
@@ -212,6 +212,8 @@ def TT_ReduceOp : TT_Op<"reduce"> {
|
|||||||
// let builders = [
|
// let builders = [
|
||||||
// OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
|
// OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
|
||||||
// ];
|
// ];
|
||||||
|
|
||||||
|
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||||
|
@@ -35,7 +35,13 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||||
[MemoryEffects<[MemRead, MemWrite]>]> {
|
[MemoryEffects<[MemRead, MemWrite]>,
|
||||||
|
TypesMatchWith<"infer mask type from ptr type",
|
||||||
|
"ptr", "mask",
|
||||||
|
"getI1SameShape($_self)">,
|
||||||
|
TypesMatchWith<"infer other type from ptr type",
|
||||||
|
"ptr", "other",
|
||||||
|
"getPointeeType($_self)">]> {
|
||||||
let summary = "copy async";
|
let summary = "copy async";
|
||||||
|
|
||||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
||||||
@@ -44,7 +50,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
|||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
// let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
|
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||||
|
@@ -58,6 +58,34 @@ void TritonGPUDialect::initialize() {
|
|||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
// Type inference
|
||||||
|
static Type getI1SameShape(Type type) {
|
||||||
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
|
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Type getPointeeType(Type type) {
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
// Tensor of pointers
|
||||||
|
auto shape = tensorType.getShape();
|
||||||
|
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
||||||
|
Type pointeeType = ptrType.getPointeeType();
|
||||||
|
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
||||||
|
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||||
|
// scalar pointer
|
||||||
|
Type pointeeType = ptrType.getPointeeType();
|
||||||
|
return pointeeType;
|
||||||
|
}
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
|
Reference in New Issue
Block a user