[FRONTEND] Expose end-to-end compile to python frontend (#58)
.github/workflows/integration-tests.yml (vendored, 7 changed lines)

@@ -49,7 +49,7 @@ jobs:
           cd python
           pip3 install -e '.[tests]'

-      - name: Run tests
+      - name: Run lit tests
         run: |
           cd python
           LIT_TEST_DIR="build/$(ls build)/test"
@@ -57,3 +57,8 @@ jobs:
             echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
           fi
           lit -v "$LIT_TEST_DIR"
+
+      - name: Run python tests
+        run: |
+          cd python/tests
+          # pytest

@@ -69,7 +69,7 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
   # sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
 else()
   set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
   set(LLVM_LIBRARIES
     libLLVMNVPTXCodeGen.a
     libLLVMNVPTXDesc.a
     libLLVMNVPTXInfo.a
@@ -185,11 +185,18 @@ target_link_libraries(triton
   TritonTransforms
   TritonGPUTransforms
   TritonDriver
+  TritonLLVMIR
+  TritonPTX
   ${dialect_libs}
   ${conversion_libs}
   # optimizations
   MLIRPass
   MLIRTransforms
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  MLIRExecutionEngine
 )

 target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

@@ -100,7 +100,6 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   llvm::InitLLVM y(argc, argv);

   registerAsmPrinterCLOptions();
-
   registerMLIRContextCLOptions();
   llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

@@ -118,7 +117,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   }

   llvm::LLVMContext llvmContext;
-  auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
+  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
   }

@@ -14,9 +14,14 @@ class ModuleOp;
 namespace mlir {
 namespace triton {

+// Translate TritonGPU dialect to LLVMIR, return null if failed.
+std::unique_ptr<llvm::Module>
+translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
+                           mlir::ModuleOp module);
+
 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);

 } // namespace triton
 } // namespace mlir

include/triton/Target/PTX/PTXTranslation.h (new file, 35 lines)

@@ -0,0 +1,35 @@
+#ifndef TRITON_TARGET_PTXTRANSLATION_H
+#define TRITON_TARGET_PTXTRANSLATION_H
+
+#include "triton/driver/dispatch.h"
+
+#include <string>
+
+namespace mlir {
+
+class ModuleOp;
+
+} // namespace mlir
+
+namespace triton {
+
+template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
+  int res;
+  driver::dispatch::cuDeviceGetAttribute(&res, attr, device);
+  return res;
+}
+
+void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
+                                 std::string *ptxasPath);
+
+// Translate TritonGPU IR to PTX code.
+std::tuple<std::string, // ptx code
+           size_t,      // PTX cc
+           int,         // PTX version
+           std::string  // ptxas path
+           >
+translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device);
+
+} // namespace triton
+
+#endif

@@ -17,6 +17,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
   MLIRGPUOps
   MLIRGPUToNVVMTransforms
   MLIRGPUTransforms
+  TritonAnalysis
   TritonIR
   TritonGPUIR
   TritonGPUTransforms

@@ -1 +1,2 @@
 add_subdirectory(LLVMIR)
+add_subdirectory(PTX)

@@ -4,6 +4,8 @@
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
@@ -77,7 +79,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
 }

 std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
   auto context = module->getContext();
   DialectRegistry registry;
   registerLLVMDialectTranslation(registry);
@@ -114,5 +116,26 @@ TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
   return llvmModule;
 }

+std::unique_ptr<llvm::Module>
+translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
+                           mlir::ModuleOp module) {
+  mlir::PassManager pm(module->getContext());
+  applyPassManagerCLOptions(pm);
+
+  pm.addPass(createConvertTritonGPUToLLVMPass());
+
+  if (failed(pm.run(module))) {
+    llvm::errs() << "Pass execution failed";
+    return nullptr;
+  }
+
+  auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
+  if (!llvmir) {
+    llvm::errs() << "Translate to LLVM IR failed";
+  }
+
+  return llvmir;
+}
+
 } // namespace triton
 } // namespace mlir

lib/Target/PTX/CMakeLists.txt (new file, 9 lines)

@@ -0,0 +1,9 @@
+add_mlir_translation_library(TritonPTX
+  PTXTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  TritonLLVMIR
+)

lib/Target/PTX/PTXTranslation.cpp (new file, 41 lines)

@@ -0,0 +1,41 @@
+#include "triton/Target/PTX/PTXTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "triton/driver/dispatch.h"
+#include "triton/driver/llvm.h"
+
+namespace triton {
+
+void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
+                                 std::string *ptxasPath) {
+  CUdevice dev = (CUdevice)device;
+  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+  *cc = major * 10 + minor;
+  *ptxasPath = driver::path_to_ptxas(*version); // assign version
+}
+
+std::tuple<std::string, size_t, int, std::string>
+translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
+  int cc;
+  int version;
+  std::string ptxasPath;
+  getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
+
+  llvm::LLVMContext ctx;
+  auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
+  auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
+  return std::make_tuple(ptxCode, cc, version, ptxasPath);
+}
+
+} // namespace triton

@@ -150,6 +150,7 @@ setup(
         "License :: OSI Approved :: MIT License",
         "Programming Language :: Python :: 3.6",
     ],
+    test_suite="tests",
     extras_require={
         "tests": [
             "autopep8",

@@ -1,6 +1,4 @@
-// #include "triton/codegen/pass.h"
-// #include "triton/codegen/target.h"
 #include "triton/driver/error.h"
 #include "triton/driver/llvm.h"

 #include "mlir/IR/Builders.h"
@@ -17,14 +15,15 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "triton/Target/PTX/PTXTranslation.h"
+
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-
 #include "llvm/Support/raw_ostream.h"
-
-#include "Python.h"
+#include <Python.h>
 #include <optional>
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
@@ -40,21 +39,7 @@ namespace py = pybind11;
 // namespace ir = triton::ir;
 namespace drv = triton::driver;

-/*****************************************************************************/
-/* Python bindings for triton::driver                                        */
-/*****************************************************************************/
-// information query
-template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
-  int res;
-  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
-  return res;
-}
-
-template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
-  int res;
-  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
-  return res;
-}
+using triton::cuGetInfo;

 enum backend_t {
   HOST,
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
                 (CUstream)stream, nullptr, config);
 }

-void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
-                 uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
-                 uint64_t block_1, uint64_t block_2, void *args_ptr,
-                 size_t args_size, int64_t shared_mem) {
-  void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
-                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
-                    HIP_LAUNCH_PARAM_END};
-  drv::dispatch::hipModuleLaunchKernel(
-      (hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2,
-      shared_mem, (hipStream_t)stream, nullptr, config);
-}
-
 long pow2_divisor(long N) {
   if (N % 16 == 0)
     return 16;
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
     if (backend == CUDA)
       return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
          device);
-    if (backend == ROCM)
-      return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
     return -1;
   });

@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
     if (backend == CUDA)
       cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
                  block_2, args_ptr, args_size, shared_mem);
-    if (backend == ROCM)
-      hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
-                  block_1, block_2, args_ptr, args_size, shared_mem);
   });
 }

@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }

-// ROCM
-std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
-                                               asm_map_t &asm_map,
-                                               size_t n_shared_bytes,
-                                               uint64_t dev) {
-  py::bytes _assembly = asm_map["hsaco"];
-  std::string assembly = py::cast<std::string>(_assembly);
-  // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
-  // Handle to the kernel
-  hipFunction_t fun;
-  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
-  // record asm
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
-
-// ---------------------------------------
-// Compile Triton-IR to assembly
-// ---------------------------------------
-
-// // CUDA
-// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string&
-// name, ir::module &ir,
-//                                                         uint64_t
-//                                                         device, int
-//                                                         num_warps, int
-//                                                         num_stages,
-//                                                         asm_map_t
-//                                                         &asm_map){
-
-//   int n_shared_bytes;
-//   py::gil_scoped_release allow_threads;
-//   llvm::LLVMContext ctx;
-//   // device properties
-//   CUdevice dev = (CUdevice)device;
-//   size_t major =
-//   cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev); size_t minor
-//   = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev); size_t cc =
-//   major*10 + minor; int version; std::string ptxas_path =
-//   drv::path_to_ptxas(version);
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::nvidia_cu_target target(cc);
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc,
-//   num_warps, num_stages, n_shared_bytes); std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> PTX
-//   std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
-//   asm_map["ptx"] = py::cast(ptx);
-//   // PTX -> Binary
-//   std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
-//   if(!cubin.empty()){
-//     py::bytes bytes(cubin);
-//     asm_map["cubin"] = bytes;
-//   }
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-
-// // HIP
-// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string&
-// name, ir::module &ir,
-//                                                          uint64_t
-//                                                          device, int
-//                                                          num_warps,
-//                                                          int
-//                                                          num_stages,
-//                                                          asm_map_t
-//                                                          &asm_map){
-//   llvm::LLVMContext ctx;
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::amd_cl_target target;
-//   int n_shared_bytes;
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70,
-//   num_warps, num_stages, n_shared_bytes); std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> HSA-CO
-//   std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
-//   asm_map["hsaco"] = py::cast(path);
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-
-// void init_triton_codegen(py::module &&m) {
-//   m.def(
-//       "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device,
-//       int num_warps, int num_stages) {
-//         std::string name = ir.get_function_list()[0]->get_name();
-//         // record asm as we generate
-//         asm_map_t asm_map;
-//         std::ostringstream ttir;
-//         ir.print(ttir);
-//         asm_map["ttir"] = py::cast(ttir.str());
-//         llvm::LLVMContext ctx;
-//         if(backend == CUDA)
-//           return cu_compile_ttir(name, ir, device, num_warps, num_stages,
-//           asm_map);
-//         if(backend == ROCM)
-//           return hip_compile_ttir(name, ir, device, num_warps, num_stages,
-//           asm_map);
-//       }, py::return_value_policy::take_ownership);
-//   m.def("load_binary", [](backend_t backend, const std::string& name,
-//   asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
-//     py::gil_scoped_release allow_threads;
-//     if(backend == CUDA)
-//       return cu_load_binary(name, asm_map, n_shared_bytes, dev);
-//     if(backend == ROCM)
-//       return hip_load_binary(name, asm_map, n_shared_bytes, dev);
-//   }, py::return_value_policy::take_ownership);
-// }
-
 /*****************************************************************************/
 /* Python bindings for triton::ir                                            */
 /*****************************************************************************/
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
   });
 }

+void init_translation(py::module &m) {
+  m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
+    llvm::LLVMContext llvmContext;
+    auto llvmModule =
+        ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    llvmModule->print(os, nullptr);
+    os.flush();
+    return str;
+  });
+
+  m.def("translate_triton_gpu_to_ptx",
+        [](mlir::ModuleOp module, uint64_t device) -> std::string {
+          auto [ptxCode, cc, version, ptxasPath] =
+              triton::translateTritonGPUToPTX(module, device);
+          return ptxCode;
+        });
+
+  m.def("compile_ptx_to_cubin",
+        [](const std::string &ptxCode, uint64_t device) -> py::object {
+          py::gil_scoped_release allow_threads;
+          int version;
+          int cc;
+          std::string ptxasPath;
+          triton::getCuCCAndVersionFromDevice(device, &cc, &version,
+                                              &ptxasPath);
+
+          std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
+          py::bytes bytes(cubin);
+          return bytes;
+        });
+}
+
 void init_triton(py::module &m) {
   py::module subm = m.def_submodule("triton");
   // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
   init_triton_runtime(std::move(subm.def_submodule("runtime")));
   init_triton_ir(std::move(subm.def_submodule("ir")));
+  init_translation(subm);
 }
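
For orientation, a minimal Python-side sketch of how the three new bindings chain together. This is a sketch only: `_triton` stands for the compiled extension submodule registered by init_triton, and `mod` for a TritonGPU-dialect module; both are assumptions about the calling context.

    # Sketch: `_triton` is the pybind11 submodule set up by init_translation above,
    # `mod` is a TritonGPU dialect ModuleOp, `device` is a CUDA device ordinal.
    llvmir = _triton.translate_triton_gpu_to_llvmir(mod)    # LLVM IR, as str
    ptx = _triton.translate_triton_gpu_to_ptx(mod, device)  # PTX, as str
    cubin = _triton.compile_ptx_to_cubin(ptx, device)       # cubin, as bytes

The make_ptx/make_cubin wrappers in the Python compiler (further down in this diff) call the last two bindings directly.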

python/tests/__init__.py (new file, empty)

python/tests/test_compiler.py (new file, 23 lines)

@@ -0,0 +1,23 @@
+import torch
+
+import triton
+import triton.language as tl
+
+# trigger the torch.device implicitly to ensure cuda context initialization
+torch.zeros([10], device=torch.device('cuda'))
+
+
+def test_empty_kernel_cubin_compile():
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
+        pass
+
+    device = torch.cuda.current_device()
+    cubin = triton.compile(kernel,
+                           "*fp32,i32,i32",
+                           device=device,
+                           constants={"BLOCK": 256},
+                           output="cubin")
+
+    print('cubin size:', len(cubin))
+    assert len(cubin) > 0

@@ -791,13 +791,28 @@ def optimize_tritongpu_ir(mod, num_stages):
     return mod


-def make_ptx(mod):
-    # TODO
-    return mod
+def make_ptx(mod, device):
+    '''
+    Translate TritonGPU module to PTX code.
+    :param mod: a TritonGPU dialect module
+    :return: str
+    '''
+    return _triton.translate_triton_gpu_to_ptx(mod, device)


-def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
-    assert output in ["ttir", "ttgir", "ptx"]
+def make_cubin(ptx, device):
+    '''
+    Compile TritonGPU module to cubin.
+    :param ptx: ptx code
+    :param device: CUDA device
+    :return: str
+    '''
+    return _triton.compile_ptx_to_cubin(ptx, device)
+
+
+def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
+    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
+    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
     # triton-ir
     module = make_triton_ir(fn, signature, constants, attributes)
     if output == "ttir":
@@ -807,7 +822,15 @@ def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
-    # ptx
+
+    assert device >= 0, "device should be provided."
+
+    ptx = make_ptx(module, device)
     if output == "ptx":
-        return make_ptx(module)
+        return ptx
+
+    cubin = make_cubin(ptx, device)
+    if output == "cubin":
+        return cubin
+
     assert False
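
With these changes compile() walks the whole pipeline (triton-ir -> triton-gpu-ir -> ptx -> cubin) and can stop at any stage listed in valid_outputs. A usage sketch, modeled on test_compiler.py above; the kernel, signature, and device ordinal 0 are illustrative:

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
        pass

    # No device is needed up to "ttgir"; "ptx" and "cubin" require device >= 0.
    ttgir = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
    ptx = triton.compile(kernel, "*fp32,i32,i32", device=0, constants={"BLOCK": 256}, output="ptx")
    cubin = triton.compile(kernel, "*fp32,i32,i32", device=0, constants={"BLOCK": 256}, output="cubin")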