[FRONTEND] Expose end-to-end compile to python frontend (#58)
@@ -150,6 +150,7 @@ setup(
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
    test_suite="tests",
    extras_require={
        "tests": [
            "autopep8",
@@ -1,6 +1,4 @@
// #include "triton/codegen/pass.h"
// #include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"

#include "mlir/IR/Builders.h"
@@ -17,14 +15,15 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"

#include "llvm/Support/raw_ostream.h"

#include "Python.h"
#include <Python.h>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -40,21 +39,7 @@ namespace py = pybind11;
// namespace ir = triton::ir;
namespace drv = triton::driver;

/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
// information query
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
  int res;
  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
  return res;
}

template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
  int res;
  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
  return res;
}
using triton::cuGetInfo;

enum backend_t {
  HOST,
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
      (CUstream)stream, nullptr, config);
}

void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
                 uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
                 uint64_t block_1, uint64_t block_2, void *args_ptr,
                 size_t args_size, int64_t shared_mem) {
  void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
                    HIP_LAUNCH_PARAM_END};
  drv::dispatch::hipModuleLaunchKernel(
      (hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2,
      shared_mem, (hipStream_t)stream, nullptr, config);
}

long pow2_divisor(long N) {
  if (N % 16 == 0)
    return 16;
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
          device);
    if (backend == ROCM)
      return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
    return -1;
  });

@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
        if (backend == CUDA)
          cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
                     block_2, args_ptr, args_size, shared_mem);
        if (backend == ROCM)
          hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
                      block_1, block_2, args_ptr, args_size, shared_mem);
      });
}

@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}

// ROCM
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
                                               asm_map_t &asm_map,
                                               size_t n_shared_bytes,
                                               uint64_t dev) {
  py::bytes _assembly = asm_map["hsaco"];
  std::string assembly = py::cast<std::string>(_assembly);
  // HSA-CO -> hipModule
  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
  // Handle to the kernel
  hipFunction_t fun;
  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
  // record asm
  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}

// ---------------------------------------
// Compile Triton-IR to assembly
// ---------------------------------------

// // CUDA
// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
//     const std::string& name, ir::module &ir, uint64_t device,
//     int num_warps, int num_stages, asm_map_t &asm_map){
//   int n_shared_bytes;
//   py::gil_scoped_release allow_threads;
//   llvm::LLVMContext ctx;
//   // device properties
//   CUdevice dev = (CUdevice)device;
//   size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
//   size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
//   size_t cc = major*10 + minor;
//   int version;
//   std::string ptxas_path = drv::path_to_ptxas(version);
//   // Triton-IR -> NVPTX LLVM-IR
//   triton::codegen::nvidia_cu_target target(cc);
//   auto llvm = triton::codegen::add_passes_to_emit_bin(
//       ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
//   std::string tmp;
//   llvm::raw_string_ostream llir(tmp);
//   llir << *llvm;
//   llir.flush();
//   asm_map["llir"] = py::cast(tmp);
//   // LLVM-IR -> PTX
//   std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
//   asm_map["ptx"] = py::cast(ptx);
//   // PTX -> Binary
//   std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
//   if(!cubin.empty()){
//     py::bytes bytes(cubin);
//     asm_map["cubin"] = bytes;
//   }
//   return std::make_tuple(name, asm_map, n_shared_bytes);
// }

// // HIP
// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
//     const std::string& name, ir::module &ir, uint64_t device,
//     int num_warps, int num_stages, asm_map_t &asm_map){
//   llvm::LLVMContext ctx;
//   // Triton-IR -> NVPTX LLVM-IR
//   triton::codegen::amd_cl_target target;
//   int n_shared_bytes;
//   auto llvm = triton::codegen::add_passes_to_emit_bin(
//       ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
//   std::string tmp;
//   llvm::raw_string_ostream llir(tmp);
//   llir << *llvm;
//   llir.flush();
//   asm_map["llir"] = py::cast(tmp);
//   // LLVM-IR -> HSA-CO
//   std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
//   asm_map["hsaco"] = py::cast(path);
//   return std::make_tuple(name, asm_map, n_shared_bytes);
// }

// void init_triton_codegen(py::module &&m) {
//   m.def(
//       "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device,
//                          int num_warps, int num_stages) {
//         std::string name = ir.get_function_list()[0]->get_name();
//         // record asm as we generate
//         asm_map_t asm_map;
//         std::ostringstream ttir;
//         ir.print(ttir);
//         asm_map["ttir"] = py::cast(ttir.str());
//         llvm::LLVMContext ctx;
//         if(backend == CUDA)
//           return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
//         if(backend == ROCM)
//           return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
//       }, py::return_value_policy::take_ownership);
//   m.def("load_binary", [](backend_t backend, const std::string& name,
//                           asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
//     py::gil_scoped_release allow_threads;
//     if(backend == CUDA)
//       return cu_load_binary(name, asm_map, n_shared_bytes, dev);
//     if(backend == ROCM)
//       return hip_load_binary(name, asm_map, n_shared_bytes, dev);
//   }, py::return_value_policy::take_ownership);
// }

/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
      });
}

void init_translation(py::module &m) {
  m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
    llvm::LLVMContext llvmContext;
    auto llvmModule =
        ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);

    std::string str;
    llvm::raw_string_ostream os(str);
    llvmModule->print(os, nullptr);
    os.flush();
    return str;
  });

  m.def("translate_triton_gpu_to_ptx",
        [](mlir::ModuleOp module, uint64_t device) -> std::string {
          auto [ptxCode, cc, version, ptxasPath] =
              triton::translateTritonGPUToPTX(module, device);
          return ptxCode;
        });

  m.def("compile_ptx_to_cubin",
        [](const std::string &ptxCode, uint64_t device) -> py::object {
          py::gil_scoped_release allow_threads;
          int version;
          int cc;
          std::string ptxasPath;
          triton::getCuCCAndVersionFromDevice(device, &cc, &version,
                                              &ptxasPath);

          std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
          py::bytes bytes(cubin);
          return bytes;
        });
}

void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
  init_translation(subm);
}
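A rough sketch of how the three new bindings chain together from Python follows; the import alias mirrors the frontend changes below, while the import path, mod, and device are assumptions for illustration only:

import triton._C.libtriton.triton as _triton  # assumed import path, matching the _triton alias used in the frontend

# mod: an already-optimized TritonGPU dialect module; device: a CUDA device ordinal (assumed inputs)
llir = _triton.translate_triton_gpu_to_llvmir(mod)       # TritonGPU IR -> LLVM IR text
ptx = _triton.translate_triton_gpu_to_ptx(mod, device)   # TritonGPU IR -> PTX text
cubin = _triton.compile_ptx_to_cubin(ptx, device)        # PTX -> cubin bytes via ptxas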
python/tests/__init__.py (new file, 0 lines)
python/tests/test_compiler.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch

import triton
import triton.language as tl

# trigger the torch.device implicitly to ensure cuda context initialization
torch.zeros([10], device=torch.device('cuda'))


def test_empty_kernel_cubin_compile():
    @triton.jit
    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
        pass

    device = torch.cuda.current_device()
    cubin = triton.compile(kernel,
                           "*fp32,i32,i32",
                           device=device,
                           constants={"BLOCK": 256},
                           output="cubin")

    print('cubin size:', len(cubin))
    assert len(cubin) > 0
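The new test requires a CUDA-capable GPU (it initializes the CUDA context up front) and can be run with pytest, e.g. pytest python/tests/test_compiler.py.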
@@ -791,13 +791,28 @@ def optimize_tritongpu_ir(mod, num_stages):
    return mod


def make_ptx(mod):
    # TODO
    return mod
def make_ptx(mod, device):
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
    :return: str
    '''
    return _triton.translate_triton_gpu_to_ptx(mod, device)


def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
    assert output in ["ttir", "ttgir", "ptx"]
def make_cubin(ptx, device):
    '''
    Compile TritonGPU module to cubin.
    :param ptx: ptx code
    :param device: CUDA device
    :return: str
    '''
    return _triton.compile_ptx_to_cubin(ptx, device)


def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, "output should be one of [%s], but got \"%s\"" % (','.join(valid_outputs), output)
    # triton-ir
    module = make_triton_ir(fn, signature, constants, attributes)
    if output == "ttir":
@@ -807,7 +822,15 @@ def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num
    module = optimize_tritongpu_ir(module, num_stages)
    if output == "ttgir":
        return module.str()
    # ptx

    assert device >= 0, "device should be provided."

    ptx = make_ptx(module, device)
    if output == "ptx":
        return make_ptx(module)
        return ptx

    cubin = make_cubin(ptx, device)
    if output == "cubin":
        return cubin

    assert False
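For illustration, a minimal sketch of the staged outputs the updated compile() now supports; the kernel and signature are borrowed from the test above, and device is only required for the ptx and cubin stages:

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    pass

device = torch.cuda.current_device()
ttir = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttir")    # Triton IR text
ttgir = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")  # TritonGPU IR text
ptx = triton.compile(kernel, "*fp32,i32,i32", device=device, constants={"BLOCK": 256}, output="ptx")
cubin = triton.compile(kernel, "*fp32,i32,i32", device=device, constants={"BLOCK": 256}, output="cubin")  # bytes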