More conversion patterns
This commit is contained in:
@@ -19,8 +19,23 @@ public:
|
|||||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<Op>(
|
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||||
op.getOperation(), retType, op.getLhs(), op.getRhs()
|
op, retType, adaptor.getOperands()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class Op>
|
||||||
|
class ArithCmpPattern : public OpConversionPattern<Op> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<Op>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
|
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||||
|
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
// non-null encoding
|
// non-null encoding
|
||||||
// --------------
|
// --------------
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
// Legality rule
|
// // Legality rule
|
||||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||||
// TODO: check above rule here
|
// // TODO: check above rule here
|
||||||
[](Operation *op){
|
// [](Operation *op){
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
);
|
// );
|
||||||
// Rewrite rule
|
// Rewrite rule
|
||||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||||
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
||||||
@@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
// Floating point
|
// Floating point
|
||||||
ArithBinaryPattern<arith::MulFOp>,
|
ArithBinaryPattern<arith::MulFOp>,
|
||||||
ArithBinaryPattern<arith::DivFOp>,
|
ArithBinaryPattern<arith::DivFOp>,
|
||||||
ArithBinaryPattern<arith::RemFOp>
|
ArithBinaryPattern<arith::RemFOp>,
|
||||||
|
// Cmp
|
||||||
|
ArithCmpPattern<arith::CmpIOp>,
|
||||||
|
ArithCmpPattern<arith::CmpFOp>
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
|||||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||||
op.getOperation(), retType, op.start(), op.end()
|
op, retType, adaptor.start(), adaptor.end()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp>
|
|||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||||
op.getOperation(), retType, op.src()
|
op, retType, adaptor.src()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
|||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
||||||
op.getOperation(), retType, op.ptr(), op.offset()
|
op, retType, adaptor.getOperands()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
|||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||||
op.getOperation(), retType,
|
op, retType,
|
||||||
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
|
adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||||
|
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||||
|
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||||
|
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -158,7 +188,8 @@ void populateTritonPatterns(
|
|||||||
patterns.add<TritonMakeRangePattern,
|
patterns.add<TritonMakeRangePattern,
|
||||||
TritonBroadcastPattern,
|
TritonBroadcastPattern,
|
||||||
TritonGEPPattern,
|
TritonGEPPattern,
|
||||||
TritonLoadPattern
|
TritonLoadPattern,
|
||||||
|
TritonStorePattern
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -19,7 +19,11 @@ TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
llvm_unreachable("Not implemented");
|
printer << "<"
|
||||||
|
<< "threadTileSize = " << getThreadTileSize()
|
||||||
|
<< ", blockTileSize = " << getBlockTileSize()
|
||||||
|
<< ", order = " << getOrder()
|
||||||
|
<< ">";
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Attribute
|
mlir::Attribute
|
||||||
|
@@ -26,7 +26,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
assert(numElements > numThreads);
|
assert(numElements > numThreads);
|
||||||
assert(numElements % numThreads == 0);
|
assert(numElements % numThreads == 0);
|
||||||
|
|
||||||
// assert no encoding?
|
// or assert no encoding?
|
||||||
|
|
||||||
// Now we assume:
|
// Now we assume:
|
||||||
// contiguous = 1, order = 0, 1, 2, ...,
|
// contiguous = 1, order = 0, 1, 2, ...,
|
||||||
@@ -62,14 +62,26 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|||||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||||
|
|
||||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op))
|
if (typeConverter.isLegal(op)) {
|
||||||
|
// llvm::errs() << *op << " is dyanamically legal\n";
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
// if (typeConverter.isLegal(op->getOperandTypes()))
|
||||||
|
// llvm::errs() << *op << " is illegal with legal operands\n";
|
||||||
|
// if (typeConverter.isLegal(op->getResultTypes())) {
|
||||||
|
// llvm::errs() << *op << " is illegal with legal results\n";
|
||||||
|
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
|
||||||
|
// << "operand1: " << op->getOperand(1) << "\n";
|
||||||
|
// }
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op))
|
if (typeConverter.isLegal(op))
|
||||||
return true;
|
return true;
|
||||||
|
// llvm::errs() << *op << " is illegal\n"
|
||||||
|
// << "inside ...\n"
|
||||||
|
// << *op->getParentOp() << "\n";
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -41,5 +41,7 @@ z = torch.empty_like(x)
|
|||||||
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
||||||
# print(add_kernel.annotations)
|
# print(add_kernel.annotations)
|
||||||
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
|
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
|
||||||
|
assert mod.verify()
|
||||||
|
mod.dump()
|
||||||
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
|
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
|
||||||
mod.dump()
|
mod.dump()
|
||||||
|
Reference in New Issue
Block a user