2022-05-01 22:06:54 +08:00
|
|
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
2022-07-26 17:25:03 -07:00
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
#include <algorithm>
|
2022-07-31 13:59:44 -07:00
|
|
|
#include <numeric>
|
2022-05-01 22:06:54 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
2022-06-18 21:16:45 +08:00
|
|
|
using namespace mlir::triton::gpu;
|
2022-05-01 22:06:54 +08:00
|
|
|
|
|
|
|
//
|
|
|
|
// TypeConverter
|
|
|
|
//
|
2022-07-26 17:25:03 -07:00
|
|
|
/// Sets up the type conversions and value materializations used when
/// rewriting IR towards the TritonGPU dialect.
///
/// \param context  MLIR context passed to the encoding attribute builder.
/// \param numWarps warp count forwarded to the default blocked encoding.
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps)
    : context(context), numWarps(numWarps) {
  // Identity rule: any type is returned unchanged.
  // TODO: how does MLIR pick the right conversion?
  addConversion([](Type type) { return type; });

  // Ranked tensors without an encoding receive a default blocked encoding.
  addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
    // A tensor that already carries an encoding is left untouched.
    // TODO: check for layout encodings specifically
    const Attribute existingEncoding = tensorType.getEncoding();
    if (existingEncoding)
      return tensorType;

    // Pessimistic defaults for the layout attributes:
    //   - every thread owns a single element along each dimension
    //   - order = arange(rank)
    const ArrayRef<int64_t> dims = tensorType.getShape();
    const int numDims = dims.size();

    llvm::SmallVector<unsigned> elemsPerThread(numDims, 1);
    llvm::SmallVector<unsigned> dimOrder(numDims);
    std::iota(dimOrder.begin(), dimOrder.end(), 0);

    Attribute blockedEncoding = triton::gpu::BlockedEncodingAttr::get(
        this->context, dims, elemsPerThread, dimOrder, this->numWarps);
    return RankedTensorType::get(dims, tensorType.getElementType(),
                                 blockedEncoding);
  });

  //
  // Materializations
  //

  // Invoked when (newArgType != origArgType); expected to create newArg and
  // map(origArg, newArg). Not supported: reaching this callback aborts.
  addArgumentMaterialization([](OpBuilder &builder,
                                RankedTensorType tensorType, ValueRange inputs,
                                Location loc) -> Optional<Value> {
    llvm_unreachable("Argument rematerialization not implemented");
    return llvm::None;
  });

  // Invoked to convert origValue to newValue while origValue still has live
  // user(s). Not supported: reaching this callback aborts.
  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
                              ValueRange inputs,
                              Location loc) -> Optional<Value> {
    llvm_unreachable("Source rematerialization not implemented");
    return llvm::None;
  });

  // Invoked when (desiredType != newOperandType), where
  // desiredType = typeConverter->convertType(origType).
  // NOTE: only for remapped values.
  // The mismatch is bridged by inserting a ConvertLayoutOp that yields the
  // requested tensor type.
  addTargetMaterialization([](OpBuilder &builder, RankedTensorType tensorType,
                              ValueRange inputs,
                              Location loc) -> Optional<Value> {
    auto layoutConversion =
        builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
    return layoutConversion.getResult();
  });
}
|
|
|
|
|
|
|
|
//
|
|
|
|
// TritonGPUConversion
|
|
|
|
//
|
2022-06-18 14:57:41 +08:00
|
|
|
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
2022-07-26 17:25:03 -07:00
|
|
|
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
2022-10-28 12:36:09 -07:00
|
|
|
: ConversionTarget(context) {
|
2022-05-09 21:19:53 +08:00
|
|
|
// TODO: we should also verify ops of TritonGPUDialect
|
|
|
|
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
// Some ops from SCF are illegal
|
2022-07-26 17:25:03 -07:00
|
|
|
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
|
|
|
|
scf::ReduceReturnOp>();
|
2022-05-01 22:06:54 +08:00
|
|
|
|
[Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR)
- Currently till TritonGPU. It cannot be lowered to PTX now.
- No special optimizations (e.g., constant folding etc) are applied.
- 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
- No constant folding etc for `libdevice` ops.
```py
import triton
import triton.language as tl
import sys
@triton.jit
def add_kernel(
x_ptr,
y_ptr,
BLOCK_SIZE: tl.constexpr,
):
offsets = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offsets)
x = tl.sin(x)
output = tl.libdevice.sin(x)
output = tl.libdevice.fdiv_rn(output, output)
output = tl.libdevice.fmaf_rd(output, output, output)
tl.store(y_ptr + offsets, output)
if __name__ == "__main__" and len(sys.argv) >= 2:
signature = "*fp32,*fp32"
constants = {'BLOCK_SIZE': 1024}
output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
%2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%4 = math.sin %3 : tensor<1024xf32, #blocked>
%5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
%9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.store %9, %7 : tensor<1024xf32, #blocked>
return
}
}
```
2022-09-01 16:34:27 -07:00
|
|
|
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
|
|
|
|
triton::TritonDialect, StandardOpsDialect,
|
|
|
|
scf::SCFDialect>([&](Operation *op) {
|
|
|
|
if (typeConverter.isLegal(op))
|
|
|
|
return true;
|
|
|
|
return false;
|
|
|
|
});
|
2022-05-09 21:19:53 +08:00
|
|
|
|
2022-06-07 19:34:59 +08:00
|
|
|
// We have requirements for the data layouts
|
2022-10-28 12:36:09 -07:00
|
|
|
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
2022-07-26 17:25:03 -07:00
|
|
|
Attribute aEncoding =
|
|
|
|
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
|
|
|
Attribute bEncoding =
|
|
|
|
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
2022-11-10 13:57:27 +08:00
|
|
|
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
|
|
|
|
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
2022-06-07 19:34:59 +08:00
|
|
|
return true;
|
|
|
|
return false;
|
|
|
|
});
|
[Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR)
- Currently till TritonGPU. It cannot be lowered to PTX now.
- No special optimizations (e.g., constant folding etc) are applied.
- 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
- No constant folding etc for `libdevice` ops.
```py
import triton
import triton.language as tl
import sys
@triton.jit
def add_kernel(
x_ptr,
y_ptr,
BLOCK_SIZE: tl.constexpr,
):
offsets = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offsets)
x = tl.sin(x)
output = tl.libdevice.sin(x)
output = tl.libdevice.fdiv_rn(output, output)
output = tl.libdevice.fmaf_rd(output, output, output)
tl.store(y_ptr + offsets, output)
if __name__ == "__main__" and len(sys.argv) >= 2:
signature = "*fp32,*fp32"
constants = {'BLOCK_SIZE': 1024}
output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
%2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%4 = math.sin %3 : tensor<1024xf32, #blocked>
%5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
%8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
%9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.store %9, %7 : tensor<1024xf32, #blocked>
return
}
}
```
2022-09-01 16:34:27 -07:00
|
|
|
}
|