[FRONTEND] Expose end-to-end compile to python frontend (#58)
This commit is contained in:
7
.github/workflows/integration-tests.yml
vendored
7
.github/workflows/integration-tests.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
cd python
|
||||
pip3 install -e '.[tests]'
|
||||
|
||||
- name: Run tests
|
||||
- name: Run lit tests
|
||||
run: |
|
||||
cd python
|
||||
LIT_TEST_DIR="build/$(ls build)/test"
|
||||
@@ -57,3 +57,8 @@ jobs:
|
||||
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
|
||||
fi
|
||||
lit -v "$LIT_TEST_DIR"
|
||||
|
||||
- name: Run python tests
|
||||
run: |
|
||||
cd python/tests
|
||||
# pytest
|
||||
|
@@ -69,7 +69,7 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
||||
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
|
||||
else()
|
||||
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
|
||||
set(LLVM_LIBRARIES
|
||||
set(LLVM_LIBRARIES
|
||||
libLLVMNVPTXCodeGen.a
|
||||
libLLVMNVPTXDesc.a
|
||||
libLLVMNVPTXInfo.a
|
||||
@@ -185,11 +185,18 @@ target_link_libraries(triton
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonDriver
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRExecutionEngine
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
@@ -100,7 +100,6 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
registerAsmPrinterCLOptions();
|
||||
|
||||
registerMLIRContextCLOptions();
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
|
||||
|
||||
@@ -118,7 +117,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
@@ -14,9 +14,14 @@ class ModuleOp;
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Translate TritonGPU dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
35
include/triton/Target/PTX/PTXTranslation.h
Normal file
35
include/triton/Target/PTX/PTXTranslation.h
Normal file
@@ -0,0 +1,35 @@
|
||||
#ifndef TRITON_TARGET_PTXTRANSLATION_H
|
||||
#define TRITON_TARGET_PTXTRANSLATION_H
|
||||
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace triton {
|
||||
|
||||
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
|
||||
int res;
|
||||
driver::dispatch::cuDeviceGetAttribute(&res, attr, device);
|
||||
return res;
|
||||
}
|
||||
|
||||
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
|
||||
std::string *ptxasPath);
|
||||
|
||||
// Translate TritonGPU IR to PTX code.
|
||||
std::tuple<std::string, // ptx code
|
||||
size_t, // PTX cc
|
||||
int, // PTX version
|
||||
std::string // ptxas path
|
||||
>
|
||||
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -17,6 +17,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
|
||||
MLIRGPUOps
|
||||
MLIRGPUToNVVMTransforms
|
||||
MLIRGPUTransforms
|
||||
TritonAnalysis
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
TritonGPUTransforms
|
||||
|
@@ -1 +1,2 @@
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(PTX)
|
||||
|
@@ -4,6 +4,8 @@
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
@@ -77,7 +79,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
auto context = module->getContext();
|
||||
DialectRegistry registry;
|
||||
registerLLVMDialectTranslation(registry);
|
||||
@@ -114,5 +116,26 @@ TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
return llvmModule;
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module->getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
pm.addPass(createConvertTritonGPUToLLVMPass());
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
llvm::errs() << "Pass execution failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
||||
return llvmir;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
9
lib/Target/PTX/CMakeLists.txt
Normal file
9
lib/Target/PTX/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_mlir_translation_library(TritonPTX
|
||||
PTXTranslation.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonLLVMIR
|
||||
)
|
41
lib/Target/PTX/PTXTranslation.cpp
Normal file
41
lib/Target/PTX/PTXTranslation.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
|
||||
std::string *ptxasPath) {
|
||||
CUdevice dev = (CUdevice)device;
|
||||
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
|
||||
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||
*cc = major * 10 + minor;
|
||||
*ptxasPath = driver::path_to_ptxas(*version); // assign version
|
||||
}
|
||||
|
||||
std::tuple<std::string, size_t, int, std::string>
|
||||
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
|
||||
int cc;
|
||||
int version;
|
||||
std::string ptxasPath;
|
||||
getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
|
||||
|
||||
llvm::LLVMContext ctx;
|
||||
auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
|
||||
auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
|
||||
return std::make_tuple(ptxCode, cc, version, ptxasPath);
|
||||
}
|
||||
|
||||
} // namespace triton
|
@@ -150,6 +150,7 @@ setup(
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
],
|
||||
test_suite="tests",
|
||||
extras_require={
|
||||
"tests": [
|
||||
"autopep8",
|
||||
|
@@ -1,6 +1,4 @@
|
||||
// #include "triton/codegen/pass.h"
|
||||
// #include "triton/codegen/target.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
@@ -17,14 +15,15 @@
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "Python.h"
|
||||
#include <Python.h>
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -40,21 +39,7 @@ namespace py = pybind11;
|
||||
// namespace ir = triton::ir;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::driver */
|
||||
/*****************************************************************************/
|
||||
// information query
|
||||
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
|
||||
int res;
|
||||
drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
|
||||
int res;
|
||||
drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
|
||||
return res;
|
||||
}
|
||||
using triton::cuGetInfo;
|
||||
|
||||
enum backend_t {
|
||||
HOST,
|
||||
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
|
||||
(CUstream)stream, nullptr, config);
|
||||
}
|
||||
|
||||
void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
|
||||
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
|
||||
uint64_t block_1, uint64_t block_2, void *args_ptr,
|
||||
size_t args_size, int64_t shared_mem) {
|
||||
void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
HIP_LAUNCH_PARAM_END};
|
||||
drv::dispatch::hipModuleLaunchKernel(
|
||||
(hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2,
|
||||
shared_mem, (hipStream_t)stream, nullptr, config);
|
||||
}
|
||||
|
||||
long pow2_divisor(long N) {
|
||||
if (N % 16 == 0)
|
||||
return 16;
|
||||
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
if (backend == CUDA)
|
||||
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
|
||||
device);
|
||||
if (backend == ROCM)
|
||||
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
if (backend == CUDA)
|
||||
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
|
||||
block_2, args_ptr, args_size, shared_mem);
|
||||
if (backend == ROCM)
|
||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
|
||||
block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
}
|
||||
|
||||
// ROCM
|
||||
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
|
||||
asm_map_t &asm_map,
|
||||
size_t n_shared_bytes,
|
||||
uint64_t dev) {
|
||||
py::bytes _assembly = asm_map["hsaco"];
|
||||
std::string assembly = py::cast<std::string>(_assembly);
|
||||
// HSA-CO -> hipModule
|
||||
hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
|
||||
// Handle to the kernel
|
||||
hipFunction_t fun;
|
||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||
// record asm
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
}
|
||||
|
||||
// ---------------------------------------
|
||||
// Compile Triton-IR to assembly
|
||||
// ---------------------------------------
|
||||
|
||||
// // CUDA
|
||||
// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string&
|
||||
// name, ir::module &ir,
|
||||
// uint64_t
|
||||
// device, int
|
||||
// num_warps, int
|
||||
// num_stages,
|
||||
// asm_map_t
|
||||
// &asm_map){
|
||||
|
||||
// int n_shared_bytes;
|
||||
// py::gil_scoped_release allow_threads;
|
||||
// llvm::LLVMContext ctx;
|
||||
// // device properties
|
||||
// CUdevice dev = (CUdevice)device;
|
||||
// size_t major =
|
||||
// cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev); size_t minor
|
||||
// = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev); size_t cc =
|
||||
// major*10 + minor; int version; std::string ptxas_path =
|
||||
// drv::path_to_ptxas(version);
|
||||
// // Triton-IR -> NVPTX LLVM-IR
|
||||
// triton::codegen::nvidia_cu_target target(cc);
|
||||
// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc,
|
||||
// num_warps, num_stages, n_shared_bytes); std::string tmp;
|
||||
// llvm::raw_string_ostream llir(tmp);
|
||||
// llir << *llvm;
|
||||
// llir.flush();
|
||||
// asm_map["llir"] = py::cast(tmp);
|
||||
// // LLVM-IR -> PTX
|
||||
// std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
||||
// asm_map["ptx"] = py::cast(ptx);
|
||||
// // PTX -> Binary
|
||||
// std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
||||
// if(!cubin.empty()){
|
||||
// py::bytes bytes(cubin);
|
||||
// asm_map["cubin"] = bytes;
|
||||
// }
|
||||
// return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||
// }
|
||||
|
||||
// // HIP
|
||||
// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string&
|
||||
// name, ir::module &ir,
|
||||
// uint64_t
|
||||
// device, int
|
||||
// num_warps,
|
||||
// int
|
||||
// num_stages,
|
||||
// asm_map_t
|
||||
// &asm_map){
|
||||
// llvm::LLVMContext ctx;
|
||||
// // Triton-IR -> NVPTX LLVM-IR
|
||||
// triton::codegen::amd_cl_target target;
|
||||
// int n_shared_bytes;
|
||||
// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70,
|
||||
// num_warps, num_stages, n_shared_bytes); std::string tmp;
|
||||
// llvm::raw_string_ostream llir(tmp);
|
||||
// llir << *llvm;
|
||||
// llir.flush();
|
||||
// asm_map["llir"] = py::cast(tmp);
|
||||
// // LLVM-IR -> HSA-CO
|
||||
// std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
|
||||
// asm_map["hsaco"] = py::cast(path);
|
||||
// return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||
// }
|
||||
|
||||
// void init_triton_codegen(py::module &&m) {
|
||||
// m.def(
|
||||
// "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device,
|
||||
// int num_warps, int num_stages) {
|
||||
// std::string name = ir.get_function_list()[0]->get_name();
|
||||
// // record asm as we generate
|
||||
// asm_map_t asm_map;
|
||||
// std::ostringstream ttir;
|
||||
// ir.print(ttir);
|
||||
// asm_map["ttir"] = py::cast(ttir.str());
|
||||
// llvm::LLVMContext ctx;
|
||||
// if(backend == CUDA)
|
||||
// return cu_compile_ttir(name, ir, device, num_warps, num_stages,
|
||||
// asm_map);
|
||||
// if(backend == ROCM)
|
||||
// return hip_compile_ttir(name, ir, device, num_warps, num_stages,
|
||||
// asm_map);
|
||||
// }, py::return_value_policy::take_ownership);
|
||||
// m.def("load_binary", [](backend_t backend, const std::string& name,
|
||||
// asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
// py::gil_scoped_release allow_threads;
|
||||
// if(backend == CUDA)
|
||||
// return cu_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
// if(backend == ROCM)
|
||||
// return hip_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
// }, py::return_value_policy::take_ownership);
|
||||
// }
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
/*****************************************************************************/
|
||||
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
|
||||
});
|
||||
}
|
||||
|
||||
void init_translation(py::module &m) {
|
||||
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule =
|
||||
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
|
||||
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
llvmModule->print(os, nullptr);
|
||||
os.flush();
|
||||
return str;
|
||||
});
|
||||
|
||||
m.def("translate_triton_gpu_to_ptx",
|
||||
[](mlir::ModuleOp module, uint64_t device) -> std::string {
|
||||
auto [ptxCode, cc, version, ptxasPath] =
|
||||
triton::translateTritonGPUToPTX(module, device);
|
||||
return ptxCode;
|
||||
});
|
||||
|
||||
m.def("compile_ptx_to_cubin",
|
||||
[](const std::string &ptxCode, uint64_t device) -> py::object {
|
||||
py::gil_scoped_release allow_threads;
|
||||
int version;
|
||||
int cc;
|
||||
std::string ptxasPath;
|
||||
triton::getCuCCAndVersionFromDevice(device, &cc, &version,
|
||||
&ptxasPath);
|
||||
|
||||
std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
|
||||
py::bytes bytes(cubin);
|
||||
return bytes;
|
||||
});
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
py::module subm = m.def_submodule("triton");
|
||||
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
init_translation(subm);
|
||||
}
|
||||
|
0
python/tests/__init__.py
Normal file
0
python/tests/__init__.py
Normal file
23
python/tests/test_compiler.py
Normal file
23
python/tests/test_compiler.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# trigger the torch.device implicitly to ensure cuda context initialization
|
||||
torch.zeros([10], device=torch.device('cuda'))
|
||||
|
||||
|
||||
def test_empty_kernel_cubin_compile():
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
|
||||
pass
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
cubin = triton.compile(kernel,
|
||||
"*fp32,i32,i32",
|
||||
device=device,
|
||||
constants={"BLOCK": 256},
|
||||
output="cubin")
|
||||
|
||||
print('cubin size:', len(cubin))
|
||||
assert len(cubin) > 0
|
@@ -791,13 +791,28 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
return mod
|
||||
|
||||
|
||||
def make_ptx(mod):
|
||||
# TODO
|
||||
return mod
|
||||
def make_ptx(mod, device):
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return: str
|
||||
'''
|
||||
return _triton.translate_triton_gpu_to_ptx(mod, device)
|
||||
|
||||
|
||||
def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
|
||||
assert output in ["ttir", "ttgir", "ptx"]
|
||||
def make_cubin(ptx, device):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
:param device: CUDA device
|
||||
:return: str
|
||||
'''
|
||||
return _triton.compile_ptx_to_cubin(ptx, device)
|
||||
|
||||
|
||||
def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
if output == "ttir":
|
||||
@@ -807,7 +822,15 @@ def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
# ptx
|
||||
|
||||
assert device >= 0, "device should be provided."
|
||||
|
||||
ptx = make_ptx(module, device)
|
||||
if output == "ptx":
|
||||
return make_ptx(module)
|
||||
return ptx
|
||||
|
||||
cubin = make_cubin(ptx, device)
|
||||
if output == "cubin":
|
||||
return cubin
|
||||
|
||||
assert False
|
||||
|
Reference in New Issue
Block a user