More progress on Triton=>TritonGPU conversion (works for matmul)

This commit is contained in:
Yan Da
2022-05-09 21:19:53 +08:00
parent 0c5319eed9
commit 96876a46d1
3 changed files with 64 additions and 32 deletions

View File

@@ -41,7 +41,11 @@ TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
}
// Textual (assembly-format) printer for the shared-encoding attribute.
// NOTE(review): printing is intentionally unimplemented — llvm_unreachable
// aborts before the streaming code below ever executes, so the
// `printer << ...` chain is dead code kept as a template for the eventual
// implementation (its field accessors are still commented out).
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
printer << "<"
// << "threadTileSize = " << getThreadTileSize()
// << ", blockTileSize = " << getBlockTileSize()
// << ", order = " << getOrder()
<< ">";
}
void TritonGPUDialect::initialize() {

View File

@@ -50,9 +50,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// materializations
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
<< "in: \n";
inputs[0].dyn_cast<BlockArgument>().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -63,8 +60,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
});
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
llvm_unreachable("Not implemented");
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
return llvm::None;
});
}
@@ -75,13 +72,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
triton::gpu::TritonGPUDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
@@ -89,14 +88,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
return false;
});
// // We have requirements for the data layouts
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
// Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
// Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
// if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
// bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
// return true;
// return false;
// });
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// TODO: we should delete this
if (this->typeConverter.isLegal(dotOp))
return true;
return false;
});
}