ConstantOp conversion pattern
This commit is contained in:
@@ -11,6 +11,10 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
|||||||
|
|
||||||
def TT_BoolTensor : TensorOf<[I1]>;
|
def TT_BoolTensor : TensorOf<[I1]>;
|
||||||
|
|
||||||
|
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||||
|
def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
|
||||||
|
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||||
|
|
||||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
@@ -38,10 +42,10 @@ def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
|||||||
let description = [{}];
|
let description = [{}];
|
||||||
|
|
||||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||||
TT_IntegerTensor:$lhs,
|
TT_IntegerLike:$lhs,
|
||||||
TT_IntegerTensor:$rhs);
|
TT_IntegerLike:$rhs);
|
||||||
|
|
||||||
let results = (outs TT_BoolTensor:$result);
|
let results = (outs TT_BoolLike:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TTG_CmpFOp : TTG_Op<"cmpf"> {
|
def TTG_CmpFOp : TTG_Op<"cmpf"> {
|
||||||
@@ -50,10 +54,10 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
|
|||||||
let description = [{}];
|
let description = [{}];
|
||||||
|
|
||||||
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
||||||
TT_FloatTensor:$lhs,
|
TT_FloatLike:$lhs,
|
||||||
TT_FloatTensor:$rhs);
|
TT_FloatLike:$rhs);
|
||||||
|
|
||||||
let results = (outs TT_BoolTensor:$result);
|
let results = (outs TT_BoolLike:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -41,6 +41,20 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||||
|
op, retType, adaptor.getValue()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ConvertArithmeticOp: public ConversionPattern {
|
class ConvertArithmeticOp: public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||||
@@ -75,7 +89,8 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
// );
|
// );
|
||||||
// Rewrite rule
|
// Rewrite rule
|
||||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||||
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
patterns.add<ArithConstantPattern,
|
||||||
|
ArithBinaryPattern<arith::AddIOp>,
|
||||||
ArithBinaryPattern<arith::SubIOp>,
|
ArithBinaryPattern<arith::SubIOp>,
|
||||||
ArithBinaryPattern<arith::MulIOp>,
|
ArithBinaryPattern<arith::MulIOp>,
|
||||||
ArithBinaryPattern<arith::DivUIOp>,
|
ArithBinaryPattern<arith::DivUIOp>,
|
||||||
@@ -106,10 +121,9 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
ArithBinaryPattern<arith::DivFOp>,
|
ArithBinaryPattern<arith::DivFOp>,
|
||||||
ArithBinaryPattern<arith::RemFOp>,
|
ArithBinaryPattern<arith::RemFOp>,
|
||||||
// Cmp
|
// Cmp
|
||||||
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
|
||||||
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
|
||||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||||
|
// Cast Ops
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,7 +219,7 @@ public:
|
|||||||
ModuleOp mod = getOperation();
|
ModuleOp mod = getOperation();
|
||||||
// int numThreads = mod.getAttr();
|
// int numThreads = mod.getAttr();
|
||||||
// type converter
|
// type converter
|
||||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
|
TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
|
||||||
TritonGPUConversionTarget target(*context, typeConverter);
|
TritonGPUConversionTarget target(*context, typeConverter);
|
||||||
// rewrite patterns
|
// rewrite patterns
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
@@ -90,19 +90,14 @@ mod, ctx = matmul_kernel.compile_to_ttir(
|
|||||||
a.stride(0), a.stride(1),
|
a.stride(0), a.stride(1),
|
||||||
b.stride(0), b.stride(1),
|
b.stride(0), b.stride(1),
|
||||||
c.stride(0), c.stride(1),
|
c.stride(0), c.stride(1),
|
||||||
64, 64, 32,
|
128, 128, 128,
|
||||||
8, grid=(2,)
|
8, grid=(2,)
|
||||||
)
|
)
|
||||||
|
|
||||||
# assert mod.verify()
|
assert mod.verify()
|
||||||
# mod.dump()
|
mod.dump()
|
||||||
|
|
||||||
pm = _triton.ir.pass_manager(ctx)
|
mod = matmul_kernel.compile_ttir_to_llir(mod, ctx)
|
||||||
pm.add_inliner_pass()
|
|
||||||
pm.add_triton_combine_pass()
|
|
||||||
pm.add_canonicalizer_pass()
|
|
||||||
pm.add_convert_triton_to_tritongpu_pass()
|
|
||||||
pm.run(mod)
|
|
||||||
|
|
||||||
# assert mod.verify()
|
assert mod.verify()
|
||||||
# mod.dump()
|
mod.dump()
|
||||||
|
Reference in New Issue
Block a user