// triton/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
using namespace mlir;
using namespace mlir::triton::gpu;
//
// TypeConverter
//
/// Builds the type conversion used when lowering Triton IR to TritonGPU.
///
/// Any RankedTensorType that does not yet carry a layout encoding is
/// rewritten to carry a default "blocked" encoding derived from `numWarps`;
/// all other types pass through unchanged.
///
/// \param context  MLIR context used to build encoding attributes.
/// \param numWarps number of warps the default blocked layout is sized for.
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps)
    : context(context), numWarps(numWarps) {
  // Identity conversion: types we don't rewrite pass through unchanged.
  // TODO: how does MLIR pick the right conversion?
  addConversion([](Type type) { return type; });
  addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
    // Types with an encoding are already in the right format.
    // TODO: check for layout encodings specifically
    if (tensorType.getEncoding())
      return tensorType;
    // Pessimistic values for the default layout attributes:
    //   - 1 element per thread
    //   - order = arange(rank)
    ArrayRef<int64_t> shape = tensorType.getShape();
    int rank = shape.size();
    llvm::SmallVector<unsigned> order(rank);
    std::iota(order.begin(), order.end(), 0);
    llvm::SmallVector<unsigned> sizePerThread(rank, 1);
    Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
        this->context, shape, sizePerThread, order, this->numWarps);
    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
  });

  //
  // Materializations
  //
  // Called when (newArgType != origArgType): would create newArg and
  // map(origArg, newArg). Not expected to be reached in this pipeline.
  addArgumentMaterialization([&](OpBuilder &builder,
                                 RankedTensorType tensorType, ValueRange inputs,
                                 Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
  // If the original value still has live user(s), this would convert
  // origValue to newValue. Not expected to be reached in this pipeline.
  addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
  // Called when (desiredType != newOperandType), where
  //   desiredType = typeConverter->convertType(origType).
  // NOTE: only for remapped values. Bridges the gap by inserting an explicit
  // layout-conversion op.
  addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                               ValueRange inputs, Location loc) {
    auto cast =
        builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
    return Optional<Value>(cast.getResult());
  });
}
//
// TritonGPUConversion
//
/// Declares op legality for the Triton -> TritonGPU conversion:
///   - the TritonGPU dialect is always legal;
///   - a few SCF region ops are always illegal;
///   - arith/Triton/Standard/SCF ops are legal only once the type converter
///     has nothing left to rewrite on them;
///   - tt.dot additionally requires both operands to use a shared-memory
///     layout encoding.
///
/// \param context       MLIR context the target operates in.
/// \param typeConverter converter whose `isLegal` drives dynamic legality.
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context), typeConverter(typeConverter) {
  // TODO: we should also verify ops of TritonGPUDialect
  addLegalDialect<triton::gpu::TritonGPUDialect>();

  // Some ops from SCF are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();

  // Ops from these dialects stay legal only while their types need no
  // further conversion.
  addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
                             StandardOpsDialect, scf::SCFDialect>(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  // We have requirements for the data layouts: a dot op is legal only once
  // both of its operands carry a shared-memory encoding.
  addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
    Attribute aEncoding =
        dotOp.a().getType().cast<RankedTensorType>().getEncoding();
    Attribute bEncoding =
        dotOp.b().getType().cast<RankedTensorType>().getEncoding();
    return aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
           bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>();
  });
}