From 441fd7c3ccad06566e84c4b1db3350979c1ce81a Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 25 May 2022 17:53:24 +0800 Subject: [PATCH] assembly format --- include/triton/Dialect/Triton/IR/TritonOps.td | 2 ++ .../Dialect/TritonGPU/IR/TritonGPUOps.td | 10 +++++-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 445d2f4f5..24a9ea40c 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -212,6 +212,8 @@ def TT_ReduceOp : TT_Op<"reduce"> { // let builders = [ // OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>, // ]; + + let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; } def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index ed1a78b85..62e2684a9 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -35,7 +35,13 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { } def TTG_CopyAsyncOp : TTG_Op<"copy_async", - [MemoryEffects<[MemRead, MemWrite]>]> { + [MemoryEffects<[MemRead, MemWrite]>, + TypesMatchWith<"infer mask type from ptr type", + "ptr", "mask", + "getI1SameShape($_self)">, + TypesMatchWith<"infer other type from ptr type", + "ptr", "other", + "getPointeeType($_self)">]> { let summary = "copy async"; let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, @@ -44,7 +50,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async", let results = (outs TT_Type:$result); - // let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)"; + let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)"; } // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU. diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 140bce5ce..bc2b968c3 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -58,6 +58,34 @@ void TritonGPUDialect::initialize() { >(); } +namespace mlir { +namespace triton { + +// Type inference +static Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = type.dyn_cast()) + return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); + return Type(); +} + +static Type getPointeeType(Type type) { + if (auto tensorType = type.dyn_cast()) { + // Tensor of pointers + auto shape = tensorType.getShape(); + auto ptrType = tensorType.getElementType().dyn_cast(); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding()); + } else if (auto ptrType = type.dyn_cast()) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return Type(); +} + +} +} #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"