More progress on TritonGPU conversion
This commit is contained in:
@@ -3,11 +3,14 @@
|
|||||||
|
|
||||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||||
|
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||||
|
|
||||||
|
def TT_BoolTensor : TensorOf<[I1]>;
|
||||||
|
|
||||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
@@ -28,4 +31,29 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
|||||||
|
|
||||||
// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {}
|
// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {}
|
||||||
|
|
||||||
|
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||||
|
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||||
|
let summary = "integer comparison operation";
|
||||||
|
|
||||||
|
let description = [{}];
|
||||||
|
|
||||||
|
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||||
|
TT_IntegerTensor:$lhs,
|
||||||
|
TT_IntegerTensor:$rhs);
|
||||||
|
|
||||||
|
let results = (outs TT_BoolTensor:$result);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TTG_CmpFOp : TTG_Op<"cmpf"> {
|
||||||
|
let summary = "floating-point comparison operation";
|
||||||
|
|
||||||
|
let description = [{}];
|
||||||
|
|
||||||
|
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
||||||
|
TT_FloatTensor:$lhs,
|
||||||
|
TT_FloatTensor:$rhs);
|
||||||
|
|
||||||
|
let results = (outs TT_BoolTensor:$result);
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -26,15 +26,15 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class Op>
|
template<class SrcOp, class DstOp>
|
||||||
class ArithCmpPattern : public OpConversionPattern<Op> {
|
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<Op>::OpConversionPattern;
|
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
|
||||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
@@ -106,8 +106,10 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
ArithBinaryPattern<arith::DivFOp>,
|
ArithBinaryPattern<arith::DivFOp>,
|
||||||
ArithBinaryPattern<arith::RemFOp>,
|
ArithBinaryPattern<arith::RemFOp>,
|
||||||
// Cmp
|
// Cmp
|
||||||
ArithCmpPattern<arith::CmpIOp>,
|
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||||
ArithCmpPattern<arith::CmpFOp>
|
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||||
|
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||||
|
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -13,14 +13,14 @@ namespace triton {
|
|||||||
static Type getI1SameShape(Type type) {
|
static Type getI1SameShape(Type type) {
|
||||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type getI32SameShape(Type type) {
|
static Type getI32SameShape(Type type) {
|
||||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
return RankedTensorType::get(tensorType.getShape(), i32Type);
|
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ static Type getPointerTypeFromTensor(Type type) {
|
|||||||
Type elementType = tensorType.getElementType();
|
Type elementType = tensorType.getElementType();
|
||||||
auto shape = tensorType.getShape();
|
auto shape = tensorType.getShape();
|
||||||
PointerType ptrType = PointerType::get(elementType, 1);
|
PointerType ptrType = PointerType::get(elementType, 1);
|
||||||
return RankedTensorType::get(shape, ptrType);
|
return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
|
||||||
}
|
}
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
@@ -53,35 +53,27 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||||
addLegalDialect<triton::TritonDialect,
|
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
|
||||||
StandardOpsDialect,
|
|
||||||
scf::SCFDialect>();
|
|
||||||
|
|
||||||
// Some ops from SCF are illegal
|
// Some ops from SCF are illegal
|
||||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||||
|
|
||||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op)) {
|
if (typeConverter.isLegal(op))
|
||||||
// llvm::errs() << *op << " is dyanamically legal\n";
|
|
||||||
return true;
|
return true;
|
||||||
}
|
|
||||||
// if (typeConverter.isLegal(op->getOperandTypes()))
|
|
||||||
// llvm::errs() << *op << " is illegal with legal operands\n";
|
|
||||||
// if (typeConverter.isLegal(op->getResultTypes())) {
|
|
||||||
// llvm::errs() << *op << " is illegal with legal results\n";
|
|
||||||
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
|
|
||||||
// << "operand1: " << op->getOperand(1) << "\n";
|
|
||||||
// }
|
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op))
|
if (typeConverter.isLegal(op))
|
||||||
return true;
|
return true;
|
||||||
// llvm::errs() << *op << " is illegal\n"
|
return false;
|
||||||
// << "inside ...\n"
|
});
|
||||||
// << *op->getParentOp() << "\n";
|
|
||||||
|
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
|
||||||
|
if (typeConverter.isLegal(op))
|
||||||
|
return true;
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user