This branch defines three new `triton_gpu` operations to partially solve #87. Below is an overview:

```mlir
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
43 lines
1.3 KiB
C++
43 lines
1.3 KiB
C++
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
|
|
#include "triton/Conversion/Passes.h"
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/InitAllPasses.h"
|
|
#include "mlir/Support/MlirOptMain.h"
|
|
|
|
// Registration hooks for the test-only passes exposed by this driver.
// Only declarations appear in this file; the definitions are presumably
// provided by a test library linked into the tool (confirm in the build).
namespace mlir {
namespace test {

void registerTestMembarPass();
void registerTestAllocationPass();
void registerTestAlignmentPass();
void registerTestAliasPass();

} // namespace test
} // namespace mlir
|
|
|
|
/// Entry point of the Triton (GPU) optimizer driver.
///
/// Registers every pass the driver can run, then registers the dialects that
/// input IR may contain, and finally delegates argument parsing and pass
/// execution to MlirOptMain. The passes are:
///   * upstream MLIR passes,
///   * Triton and TritonGPU transformation passes,
///   * the test-only analysis passes in mlir::test, and
///   * the Triton -> TritonGPU and TritonGPU -> LLVM conversion passes.
///
/// @param argc  argument count forwarded unchanged to MlirOptMain.
/// @param argv  argument vector forwarded unchanged to MlirOptMain.
/// @return      the process exit code derived from MlirOptMain's result.
int main(int argc, char **argv) {
  // Upstream MLIR passes.
  mlir::registerAllPasses();

  // Triton and TritonGPU transformation passes.
  mlir::registerTritonPasses();
  mlir::registerTritonGPUPasses();

  // Test-only passes (alias, alignment, allocation, membar).
  mlir::test::registerTestAliasPass();
  mlir::test::registerTestAlignmentPass();
  mlir::test::registerTestAllocationPass();
  mlir::test::registerTestMembarPass();

  // Conversion passes: Triton -> TritonGPU, then TritonGPU -> LLVM.
  mlir::triton::registerConvertTritonToTritonGPUPass();
  mlir::triton::registerConvertTritonGPUToLLVMPass();

  // Dialects the parser must recognise in input IR.
  mlir::DialectRegistry registry;
  registry
      .insert<mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
              mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
              mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();

  return mlir::asMainReturnCode(mlir::MlirOptMain(
      argc, argv, "Triton (GPU) optimizer driver\n", registry));
}
|