More progress on Triton=>TritonGPU conversion (works for matmul)
@@ -41,7 +41,11 @@ TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
}

void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  llvm_unreachable("Not implemented");
  printer << "<"
          // << "threadTileSize = " << getThreadTileSize()
          // << ", blockTileSize = " << getBlockTileSize()
          // << ", order = " << getOrder()
          << ">";
}

void TritonGPUDialect::initialize() {
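Note on the stubbed printer above: once the commented-out fields are wired up, the method would presumably just stream each parameter through the AsmPrinter. A minimal sketch, assuming the getThreadTileSize/getBlockTileSize/getOrder getters hinted at in the comments exist and their results are printable via operator<< (neither is confirmed by this diff):

// Hypothetical completed printer; getters are assumed from the commented-out lines.
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<"
          << "threadTileSize = " << getThreadTileSize()
          << ", blockTileSize = " << getBlockTileSize()
          << ", order = " << getOrder()
          << ">";
}
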
@@ -50,9 +50,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
  // materializations
  addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                 ValueRange inputs, Location loc) {
    llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
                 << "in: \n";
    inputs[0].dyn_cast<BlockArgument>().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
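Both materialization hooks in this hunk are still llvm_unreachable stubs. For reference, a common way to complete such a hook in MLIR is to bridge the two types with an unrealized_conversion_cast that the conversion driver can reconcile later; a sketch under that assumption, not what this commit actually does:

addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) -> llvm::Optional<Value> {
  // Placeholder cast between the unconverted and converted tensor types.
  return builder
      .create<mlir::UnrealizedConversionCastOp>(loc, TypeRange{tensorType}, inputs)
      .getResult(0);
});
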
@@ -63,8 +60,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
  });
  addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) {
    assert(inputs.size() == 1);
    llvm_unreachable("Not implemented");
    // llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
    return llvm::None;
  });
}
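The hunks above only show the materialization hooks; the heart of a type converter like this is normally an addConversion callback that rewrites the tensor types themselves. A hedged sketch of the usual shape, where getDefaultEncoding is a hypothetical helper standing in for whatever default TritonGPU layout the converter assigns (not shown in this diff):

addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
  if (type.getEncoding())  // already carries a TritonGPU layout: keep it
    return type;
  Attribute encoding = getDefaultEncoding(type);  // hypothetical helper
  return RankedTensorType::get(type.getShape(), type.getElementType(), encoding);
});
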
@@ -75,13 +72,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context), typeConverter(typeConverter) {
  // TODO: we should also verify ops of TritonGPUDialect
  addLegalDialect<triton::gpu::TritonGPUDialect>();

  // Some ops from SCF are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
               scf::ReduceOp, scf::ReduceReturnOp>();

  addDynamicallyLegalDialect<arith::ArithmeticDialect,
                             triton::TritonDialect,
                             triton::gpu::TritonGPUDialect,
                             StandardOpsDialect,
                             scf::SCFDialect>([&](Operation *op) {
    if (typeConverter.isLegal(op))
@@ -89,14 +88,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
    return false;
  });

  // // We have requirements for the data layouts
  // addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
  //     return true;
  //   return false;
  // });

  // We have requirements for the data layouts
  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
      return true;
    // TODO: we should delete this
    if (this->typeConverter.isLegal(dotOp))
      return true;
    return false;
  });
}
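For context on how this target/converter pair gets exercised: a conversion pass typically populates a pattern set and hands everything to MLIR's applyPartialConversion. A generic sketch, with the pass wiring and pattern population as placeholders rather than code from this commit:

#include "mlir/Transforms/DialectConversion.h"

mlir::LogicalResult runTritonToTritonGPU(mlir::ModuleOp module,
                                         TritonGPUTypeConverter &typeConverter,
                                         TritonGPUConversionTarget &target) {
  mlir::RewritePatternSet patterns(module.getContext());
  // ... populate patterns rewriting Triton ops into TritonGPU ops,
  //     typically parameterized by typeConverter (elided) ...
  return mlir::applyPartialConversion(module.getOperation(), target,
                                      std::move(patterns));
}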