More progress on TritonGPU conversion

This commit is contained in:
Yan Da
2022-05-04 14:54:31 +08:00
parent 3ad7bee35e
commit b9279d2e3b
4 changed files with 48 additions and 26 deletions

View File

@@ -3,11 +3,14 @@
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def TT_BoolTensor : TensorOf<[I1]>;
class TTG_Op<string mnemonic, list<Trait> traits = []> : class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>; Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -28,4 +31,29 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {} // def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {}
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
def TTG_CmpIOp : TTG_Op<"cmpi"> {
let summary = "integer comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntegerTensor:$lhs,
TT_IntegerTensor:$rhs);
let results = (outs TT_BoolTensor:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf"> {
let summary = "floating-point comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
TT_FloatTensor:$lhs,
TT_FloatTensor:$rhs);
let results = (outs TT_BoolTensor:$result);
}
#endif #endif

View File

@@ -26,15 +26,15 @@ public:
} }
}; };
template<class Op> template<class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<Op> { class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public: public:
using OpConversionPattern<Op>::OpConversionPattern; using OpConversionPattern<SrcOp>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType()); Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>( DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs() op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
); );
return success(); return success();
@@ -106,8 +106,10 @@ void populateArithmeticPatternsAndLegality(
ArithBinaryPattern<arith::DivFOp>, ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>, ArithBinaryPattern<arith::RemFOp>,
// Cmp // Cmp
ArithCmpPattern<arith::CmpIOp>, // ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp> // ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
>(typeConverter, context); >(typeConverter, context);
} }

View File

@@ -13,14 +13,14 @@ namespace triton {
static Type getI1SameShape(Type type) { static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1); auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type); return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
return Type(); return Type();
} }
static Type getI32SameShape(Type type) { static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32); auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type); return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
return Type(); return Type();
} }
@@ -29,7 +29,7 @@ static Type getPointerTypeFromTensor(Type type) {
Type elementType = tensorType.getElementType(); Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape(); auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1); PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType); return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
} }
return Type(); return Type();
} }

View File

@@ -53,35 +53,27 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget( TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter) MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) { : ConversionTarget(context), typeConverter(typeConverter) {
addLegalDialect<triton::TritonDialect, addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
StandardOpsDialect,
scf::SCFDialect>();
// Some ops from SCF are illegal // Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>(); scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) { addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op)) { if (typeConverter.isLegal(op))
// llvm::errs() << *op << " is dyanamically legal\n";
return true; return true;
}
// if (typeConverter.isLegal(op->getOperandTypes()))
// llvm::errs() << *op << " is illegal with legal operands\n";
// if (typeConverter.isLegal(op->getResultTypes())) {
// llvm::errs() << *op << " is illegal with legal results\n";
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
// << "operand1: " << op->getOperand(1) << "\n";
// }
return false; return false;
}); });
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) { addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op)) if (typeConverter.isLegal(op))
return true; return true;
// llvm::errs() << *op << " is illegal\n" return false;
// << "inside ...\n" });
// << *op->getParentOp() << "\n";
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false; return false;
}); });