[FRONTEND] Expose end-to-end compile to python frontend (#58)

Yan Chunwei authored 2022-08-18 01:42:48 +08:00; committed by GitHub
parent 95bbac41e7
commit b1673caaf6
15 changed files with 228 additions and 165 deletions

View File

@@ -49,7 +49,7 @@ jobs:
cd python
pip3 install -e '.[tests]'
-      - name: Run tests
+      - name: Run lit tests
run: |
cd python
LIT_TEST_DIR="build/$(ls build)/test"
@@ -57,3 +57,8 @@ jobs:
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
fi
lit -v "$LIT_TEST_DIR"
- name: Run python tests
run: |
cd python/tests
# pytest
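The pytest invocation is left disabled here; once the python tests stabilize, the step would presumably boil down to something like the sketch below (the exact flags are an assumption, not part of this commit):

import sys
import pytest

# run the new test package added by this commit and propagate the exit code
sys.exit(pytest.main(["-v", "python/tests"]))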

View File

@@ -69,7 +69,7 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
@@ -185,11 +185,18 @@ target_link_libraries(triton
TritonTransforms
TritonGPUTransforms
TritonDriver
TritonLLVMIR
TritonPTX
${dialect_libs}
${conversion_libs}
# optimizations
MLIRPass
MLIRTransforms
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

View File

@@ -100,7 +100,6 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::InitLLVM y(argc, argv);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
@@ -118,7 +117,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}
llvm::LLVMContext llvmContext;
-  auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
+  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}

View File

@@ -14,9 +14,14 @@ class ModuleOp;
namespace mlir {
namespace triton {
// Translate TritonGPU dialect to LLVMIR; returns null on failure.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module);
// Translate mlir LLVM dialect to LLVMIR; returns null on failure.
std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,35 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include "triton/driver/dispatch.h"
#include <string>
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace triton {
// Query an integer-valued attribute of a CUDA device.
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
int res;
driver::dispatch::cuDeviceGetAttribute(&res, attr, device);
return res;
}
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath);
// Translate TritonGPU IR to PTX code.
std::tuple<std::string, // ptx code
size_t, // PTX cc
int, // PTX version
std::string // ptxas path
>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device);
} // namespace triton
#endif
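getCuCCAndVersionFromDevice encodes the compute capability as major * 10 + minor (e.g. 80 for sm_80). A quick way to cross-check that encoding from Python, assuming torch is available as it is in the tests added below:

import torch

major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor  # same encoding as getCuCCAndVersionFromDevice
print(cc)                # e.g. 80 on an A100, 70 on a V100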

View File

@@ -17,6 +17,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
MLIRGPUOps
MLIRGPUToNVVMTransforms
MLIRGPUTransforms
+  TritonAnalysis
TritonIR
TritonGPUIR
TritonGPUTransforms

View File

@@ -1 +1,2 @@
add_subdirectory(LLVMIR)
+add_subdirectory(PTX)

View File

@@ -4,6 +4,8 @@
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
@@ -77,7 +79,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
}
std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
registerLLVMDialectTranslation(registry);
@@ -114,5 +116,26 @@ TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return llvmModule;
}
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
pm.addPass(createConvertTritonGPUToLLVMPass());
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
return llvmir;
}
} // namespace triton
} // namespace mlir
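translateTritonGPUToLLVMIR runs the TritonGPU-to-LLVM conversion pass and then reuses translateLLVMToLLVMIR, so the whole lowering is reachable from Python through the binding added later in this commit. A minimal sketch, assuming mod is a TritonGPU-dialect module (e.g. the result of optimize_tritongpu_ir) and that the extension module is importable as triton._C.libtriton.triton, as python/triton/compiler.py does:

import triton._C.libtriton.triton as _triton

# mod: a TritonGPU-dialect ModuleOp (e.g. from optimize_tritongpu_ir)
llir = _triton.translate_triton_gpu_to_llvmir(mod)
print(llir[:200])  # textual LLVM IR of the lowered kernel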

View File

@@ -0,0 +1,9 @@
add_mlir_translation_library(TritonPTX
PTXTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
TritonLLVMIR
)

View File

@@ -0,0 +1,41 @@
#include "triton/Target/PTX/PTXTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/llvm.h"
namespace triton {
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
*cc = major * 10 + minor;
  *ptxasPath = driver::path_to_ptxas(*version); // also sets *version via its reference parameter
}
std::tuple<std::string, size_t, int, std::string>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
int cc;
int version;
std::string ptxasPath;
getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
llvm::LLVMContext ctx;
  auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
  if (!llModule) // translation reports its own error; avoid a null dereference
    return std::make_tuple("", cc, version, ptxasPath);
  auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
return std::make_tuple(ptxCode, cc, version, ptxasPath);
}
} // namespace triton
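With this file in place, the frontend's make_ptx (shown at the bottom of this commit) can go straight from a TritonGPU module to PTX text. A sketch of the user-visible path, mirroring the new test below (device ordinal 0 is an assumption):

import triton
import triton.language as tl

@triton.jit
def kernel(X, stride_xm, BLOCK: tl.constexpr):
    pass

# output="ptx" stops the pipeline right after llir_to_ptx
ptx = triton.compile(kernel, "*fp32,i32", device=0,
                     constants={"BLOCK": 256}, output="ptx")
assert ".visible .entry" in ptx  # PTX kernel entry directive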

View File

@@ -150,6 +150,7 @@ setup(
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6",
],
test_suite="tests",
extras_require={
"tests": [
"autopep8",

View File

@@ -1,6 +1,4 @@
// #include "triton/codegen/pass.h"
// #include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "mlir/IR/Builders.h"
@@ -17,14 +15,15 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"
#include "Python.h"
#include <Python.h>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -40,21 +39,7 @@ namespace py = pybind11;
// namespace ir = triton::ir;
namespace drv = triton::driver;
-/*****************************************************************************/
-/* Python bindings for triton::driver */
-/*****************************************************************************/
-// information query
-template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
-  int res;
-  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
-  return res;
-}
-template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
-  int res;
-  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
-  return res;
-}
+using triton::cuGetInfo;
enum backend_t {
HOST,
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
(CUstream)stream, nullptr, config);
}
-void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
-                 uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
-                 uint64_t block_1, uint64_t block_2, void *args_ptr,
-                 size_t args_size, int64_t shared_mem) {
-  void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
-                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
-                    HIP_LAUNCH_PARAM_END};
-  drv::dispatch::hipModuleLaunchKernel(
-      (hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2,
-      shared_mem, (hipStream_t)stream, nullptr, config);
-}
long pow2_divisor(long N) {
if (N % 16 == 0)
return 16;
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
device);
-    if (backend == ROCM)
-      return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
return -1;
});
@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
if (backend == CUDA)
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
block_2, args_ptr, args_size, shared_mem);
-          if (backend == ROCM)
-            hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
-                        block_1, block_2, args_ptr, args_size, shared_mem);
});
}
@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
-// ROCM
-std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
-                                               asm_map_t &asm_map,
-                                               size_t n_shared_bytes,
-                                               uint64_t dev) {
-  py::bytes _assembly = asm_map["hsaco"];
-  std::string assembly = py::cast<std::string>(_assembly);
-  // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
-  // Handle to the kernel
-  hipFunction_t fun;
-  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
-  // record asm
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
-// ---------------------------------------
-// Compile Triton-IR to assembly
-// ---------------------------------------
-// // CUDA
-// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
-//     const std::string &name, ir::module &ir, uint64_t device,
-//     int num_warps, int num_stages, asm_map_t &asm_map) {
-//   int n_shared_bytes;
-//   py::gil_scoped_release allow_threads;
-//   llvm::LLVMContext ctx;
-//   // device properties
-//   CUdevice dev = (CUdevice)device;
-//   size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-//   size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-//   size_t cc = major * 10 + minor;
-//   int version;
-//   std::string ptxas_path = drv::path_to_ptxas(version);
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::nvidia_cu_target target(cc);
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(
-//       ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
-//   std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> PTX
-//   std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
-//   asm_map["ptx"] = py::cast(ptx);
-//   // PTX -> Binary
-//   std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
-//   if (!cubin.empty()) {
-//     py::bytes bytes(cubin);
-//     asm_map["cubin"] = bytes;
-//   }
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-// // HIP
-// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
-//     const std::string &name, ir::module &ir, uint64_t device,
-//     int num_warps, int num_stages, asm_map_t &asm_map) {
-//   llvm::LLVMContext ctx;
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::amd_cl_target target;
-//   int n_shared_bytes;
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(
-//       ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
-//   std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> HSA-CO
-//   std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
-//   asm_map["hsaco"] = py::cast(path);
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-// void init_triton_codegen(py::module &&m) {
-//   m.def(
-//       "compile_ttir",
-//       [](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
-//          int num_stages) {
-//         std::string name = ir.get_function_list()[0]->get_name();
-//         // record asm as we generate
-//         asm_map_t asm_map;
-//         std::ostringstream ttir;
-//         ir.print(ttir);
-//         asm_map["ttir"] = py::cast(ttir.str());
-//         llvm::LLVMContext ctx;
-//         if (backend == CUDA)
-//           return cu_compile_ttir(name, ir, device, num_warps, num_stages,
-//                                  asm_map);
-//         if (backend == ROCM)
-//           return hip_compile_ttir(name, ir, device, num_warps, num_stages,
-//                                   asm_map);
-//       },
-//       py::return_value_policy::take_ownership);
-//   m.def("load_binary",
-//         [](backend_t backend, const std::string &name, asm_map_t &asm_map,
-//            size_t n_shared_bytes, uint64_t dev) {
-//           py::gil_scoped_release allow_threads;
-//           if (backend == CUDA)
-//             return cu_load_binary(name, asm_map, n_shared_bytes, dev);
-//           if (backend == ROCM)
-//             return hip_load_binary(name, asm_map, n_shared_bytes, dev);
-//         },
-//         py::return_value_policy::take_ownership);
-// }
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
});
}
void init_translation(py::module &m) {
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
llvm::LLVMContext llvmContext;
    auto llvmModule =
        ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
    if (!llvmModule) // translateTritonGPUToLLVMIR returns null on failure
      throw std::runtime_error("translate TritonGPU to LLVM IR failed");
    std::string str;
    llvm::raw_string_ostream os(str);
    llvmModule->print(os, nullptr);
os.flush();
return str;
});
m.def("translate_triton_gpu_to_ptx",
[](mlir::ModuleOp module, uint64_t device) -> std::string {
auto [ptxCode, cc, version, ptxasPath] =
triton::translateTritonGPUToPTX(module, device);
return ptxCode;
});
m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, uint64_t device) -> py::object {
py::gil_scoped_release allow_threads;
int version;
int cc;
std::string ptxasPath;
triton::getCuCCAndVersionFromDevice(device, &cc, &version,
&ptxasPath);
std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
py::bytes bytes(cubin);
return bytes;
});
}
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_translation(subm);
}
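Taken together, the three bindings mirror the stages the Python compile() helper walks through; a rough Python-side equivalent of its tail end, with mod and device as in python/triton/compiler.py:

import triton._C.libtriton.triton as _triton

# mod: optimized TritonGPU module; device: CUDA device ordinal
llir = _triton.translate_triton_gpu_to_llvmir(mod)       # LLVM IR text
ptx = _triton.translate_triton_gpu_to_ptx(mod, device)   # PTX text
cubin = _triton.compile_ptx_to_cubin(ptx, device)        # bytes, via ptxas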

python/tests/__init__.py (new, empty file)

View File

@@ -0,0 +1,23 @@
import torch
import triton
import triton.language as tl
# allocate a CUDA tensor to implicitly initialize the CUDA context
torch.zeros([10], device=torch.device('cuda'))
def test_empty_kernel_cubin_compile():
@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass
device = torch.cuda.current_device()
cubin = triton.compile(kernel,
"*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
output="cubin")
print('cubin size:', len(cubin))
assert len(cubin) > 0
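An analogous check for the intermediate PTX output would only change the output argument; a sketch (not part of this commit; the ".visible .entry" check assumes NVPTX's usual kernel directive):

def test_empty_kernel_ptx_compile():
    @triton.jit
    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
        pass
    device = torch.cuda.current_device()
    ptx = triton.compile(kernel, "*fp32,i32,i32", device=device,
                         constants={"BLOCK": 256}, output="ptx")
    assert ".visible .entry" in ptx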

View File

@@ -791,13 +791,28 @@ def optimize_tritongpu_ir(mod, num_stages):
return mod
-def make_ptx(mod):
-    # TODO
-    return mod
def make_ptx(mod, device):
'''
Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
    :param device: CUDA device ordinal
    :return: str
'''
return _triton.translate_triton_gpu_to_ptx(mod, device)
-def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
-    assert output in ["ttir", "ttgir", "ptx"]
def make_cubin(ptx, device):
'''
    Compile PTX code to cubin with ptxas.
    :param ptx: PTX code
    :param device: CUDA device ordinal
    :return: bytes
'''
return _triton.compile_ptx_to_cubin(ptx, device)
def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, 'output should be one of [%s], but got "%s"' % (','.join(valid_outputs), output)
# triton-ir
module = make_triton_ir(fn, signature, constants, attributes)
if output == "ttir":
@@ -807,7 +822,15 @@ def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
# ptx
assert device >= 0, "device should be provided."
ptx = make_ptx(module, device)
if output == "ptx":
return make_ptx(module)
return ptx
cubin = make_cubin(ptx, device)
if output == "cubin":
return cubin
assert False
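The pipeline is strictly staged, so each output below is a prefix of the work done for the next; a usage sketch, assuming a jitted kernel as in the new test and device ordinal 0:

# each valid output corresponds to one stage of the new end-to-end pipeline
for out in ("ttir", "ttgir", "ptx", "cubin"):
    artifact = triton.compile(kernel, "*fp32,i32,i32", device=0,
                              constants={"BLOCK": 256}, output=out)
    print(out, type(artifact), len(artifact))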