diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 9763ca8b0..17427a481 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -49,7 +49,7 @@ jobs:
         cd python
         pip3 install -e '.[tests]'
 
-    - name: Run tests
+    - name: Run lit tests
      run: |
        cd python
        LIT_TEST_DIR="build/$(ls build)/test"
@@ -57,3 +57,8 @@ jobs:
          echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
        fi
        lit -v "$LIT_TEST_DIR"
+
+    - name: Run python tests
+      run: |
+        cd python/tests
+        # pytest
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c8ec261f6..1ffb44152 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -69,7 +69,7 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
 # sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
 else()
     set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
-    set(LLVM_LIBRARIES 
+    set(LLVM_LIBRARIES
         libLLVMNVPTXCodeGen.a
         libLLVMNVPTXDesc.a
         libLLVMNVPTXInfo.a
@@ -185,11 +185,18 @@ target_link_libraries(triton
     TritonTransforms
     TritonGPUTransforms
     TritonDriver
+    TritonLLVMIR
+    TritonPTX
     ${dialect_libs}
     ${conversion_libs}
     # optimizations
     MLIRPass
     MLIRTransforms
+    MLIRIR
+    MLIRLLVMIR
+    MLIRSupport
+    MLIRTargetLLVMIRExport
+    MLIRExecutionEngine
 )
 target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index ad5ec5f65..d0a766ac9 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -100,7 +100,6 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   llvm::InitLLVM y(argc, argv);
 
   registerAsmPrinterCLOptions();
-  registerMLIRContextCLOptions();
 
   llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
@@ -118,7 +117,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   }
 
   llvm::LLVMContext llvmContext;
-  auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
+  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
   }
diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
index 0ec11d524..01411414b 100644
--- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h
+++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
@@ -14,9 +14,14 @@ class ModuleOp;
 namespace mlir {
 namespace triton {
 
+// Translate TritonGPU dialect to LLVMIR, return null if failed.
+std::unique_ptr<llvm::Module>
+translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
+                           mlir::ModuleOp module);
+
 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
 
 } // namespace triton
 } // namespace mlir
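Note: LLVMIRTranslation.h now exposes two entry points: `translateTritonGPUToLLVMIR`, which lowers a TritonGPU-dialect module all the way to LLVM IR, and the renamed `translateLLVMToLLVMIR`, which handles only the final MLIR-LLVM-dialect step. The first surfaces in Python through the `translate_triton_gpu_to_llvmir` binding added further down in this patch; a rough sketch of its use, assuming `mod` is a TritonGPU-dialect module built as in python/triton/compiler.py:

    import triton._C.libtriton.triton as _triton

    llvm_ir = _triton.translate_triton_gpu_to_llvmir(mod)  # textual LLVM IR
    assert "define" in llvm_ir  # LLVM function definitions should be present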
diff --git a/include/triton/Target/PTX/PTXTranslation.h b/include/triton/Target/PTX/PTXTranslation.h
new file mode 100644
index 000000000..45f8e5240
--- /dev/null
+++ b/include/triton/Target/PTX/PTXTranslation.h
@@ -0,0 +1,35 @@
+#ifndef TRITON_TARGET_PTXTRANSLATION_H
+#define TRITON_TARGET_PTXTRANSLATION_H
+
+#include "triton/driver/dispatch.h"
+
+#include <string>
+#include <tuple>
+
+namespace mlir {
+
+class ModuleOp;
+
+} // namespace mlir
+
+namespace triton {
+
+template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
+  int res;
+  driver::dispatch::cuDeviceGetAttribute(&res, attr, device);
+  return res;
+}
+
+void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
+                                 std::string *ptxasPath);
+
+// Translate TritonGPU IR to PTX code.
+std::tuple<std::string, int, int, std::string>
+translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device);
+
+} // namespace triton
+
+#endif
diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
index ad971c9e6..945e847d4 100644
--- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
+++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
     MLIRGPUOps
     MLIRGPUToNVVMTransforms
     MLIRGPUTransforms
+    TritonAnalysis
     TritonIR
     TritonGPUIR
     TritonGPUTransforms
diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt
index 88b49be6b..9b24f0ff2 100644
--- a/lib/Target/CMakeLists.txt
+++ b/lib/Target/CMakeLists.txt
@@ -1 +1,2 @@
-add_subdirectory(LLVMIR)
\ No newline at end of file
+add_subdirectory(LLVMIR)
+add_subdirectory(PTX)
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index ed4931a8d..99d4710ca 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -4,6 +4,8 @@
 #include "mlir/ExecutionEngine/OptUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
@@ -77,7 +79,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
 }
 
 std::unique_ptr<llvm::Module>
-TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
+translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
   auto context = module->getContext();
   DialectRegistry registry;
   registerLLVMDialectTranslation(registry);
@@ -114,5 +116,26 @@ TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
   return llvmModule;
 }
 
+std::unique_ptr<llvm::Module>
+translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
+                           mlir::ModuleOp module) {
+  mlir::PassManager pm(module->getContext());
+  applyPassManagerCLOptions(pm);
+
+  pm.addPass(createConvertTritonGPUToLLVMPass());
+
+  if (failed(pm.run(module))) {
+    llvm::errs() << "Pass execution failed";
+    return nullptr;
+  }
+
+  auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
+  if (!llvmir) {
+    llvm::errs() << "Translate to LLVM IR failed";
+  }
+
+  return llvmir;
+}
+
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Target/PTX/CMakeLists.txt b/lib/Target/PTX/CMakeLists.txt
new file mode 100644
index 000000000..69aa5710c
--- /dev/null
+++ b/lib/Target/PTX/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_translation_library(TritonPTX
+  PTXTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  TritonLLVMIR
+  )
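Note: `getCuCCAndVersionFromDevice`, declared in PTXTranslation.h above and implemented below, encodes the compute capability as a single integer, cc = major * 10 + minor, which is the form the driver helpers expect. A quick worked example of the encoding:

    # e.g. an sm_80 device reports major=8, minor=0
    major, minor = 8, 0
    cc = major * 10 + minor
    assert cc == 80  # the value later handed to llir_to_ptx / ptx_to_cubin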
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
new file mode 100644
index 000000000..b286e612a
--- /dev/null
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -0,0 +1,41 @@
+#include "triton/Target/PTX/PTXTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "triton/driver/dispatch.h"
+#include "triton/driver/llvm.h"
+
+namespace triton {
+
+void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
+                                 std::string *ptxasPath) {
+  CUdevice dev = (CUdevice)device;
+  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+  *cc = major * 10 + minor;
+  *ptxasPath = driver::path_to_ptxas(*version); // assigns version as well
+}
+
+std::tuple<std::string, int, int, std::string>
+translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
+  int cc;
+  int version;
+  std::string ptxasPath;
+  getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
+
+  llvm::LLVMContext ctx;
+  auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
+  auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
+  return std::make_tuple(ptxCode, cc, version, ptxasPath);
+}
+
+} // namespace triton
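Note: `translateTritonGPUToPTX` returns the tuple (ptx, cc, version, ptxasPath), but only the PTX string crosses into Python via the binding added below; cc, version, and the ptxas path are recomputed on demand. An illustrative sketch, with `mod` and `device` as in the earlier snippet:

    ptx = _triton.translate_triton_gpu_to_ptx(mod, device)
    assert isinstance(ptx, str)
    assert ".visible .entry" in ptx  # a kernel entry point in the PTX text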
diff --git a/python/setup.py b/python/setup.py
index 472a59601..5c994f2d5 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -150,6 +150,7 @@ setup(
         "License :: OSI Approved :: MIT License",
         "Programming Language :: Python :: 3.6",
     ],
+    test_suite="tests",
     extras_require={
         "tests": [
             "autopep8",
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 16c8e0b66..8b7d93e89 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1,6 +1,4 @@
-// #include "triton/codegen/pass.h"
-// #include "triton/codegen/target.h"
-#include "triton/driver/error.h"
+#include "triton/driver/error.h"
 #include "triton/driver/llvm.h"
 
 #include "mlir/IR/Builders.h"
@@ -17,14 +15,15 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "triton/Target/PTX/PTXTranslation.h"
 
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-
 #include "llvm/Support/raw_ostream.h"
 
-#include "Python.h"
+#include <Python.h>
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
@@ -40,21 +39,7 @@
 namespace py = pybind11;
 
 // namespace ir = triton::ir;
 namespace drv = triton::driver;
 
-/*****************************************************************************/
-/* Python bindings for triton::driver                                        */
-/*****************************************************************************/
-// information query
-template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
-  int res;
-  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
-  return res;
-}
-
-template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
-  int res;
-  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
-  return res;
-}
+using triton::cuGetInfo;
 
 enum backend_t {
   HOST,
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
                 (CUstream)stream, nullptr, config);
 }
 
-void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
-                 uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
-                 uint64_t block_1, uint64_t block_2, void *args_ptr,
-                 size_t args_size, int64_t shared_mem) {
-  void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
-                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
-                    HIP_LAUNCH_PARAM_END};
-  drv::dispatch::hipModuleLaunchKernel(
-      (hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1,
-      block_2, shared_mem, (hipStream_t)stream, nullptr, config);
-}
-
 long pow2_divisor(long N) {
   if (N % 16 == 0)
     return 16;
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
         if (backend == CUDA)
           return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
               device);
-        if (backend == ROCM)
-          return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
         return -1;
       });
@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
         if (backend == CUDA)
           cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
                      block_1, block_2, args_ptr, args_size, shared_mem);
-        if (backend == ROCM)
-          hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
-                      block_1, block_2, args_ptr, args_size, shared_mem);
       });
 }
@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
   return std::make_tuple((uint64_t)mod, (uint64_t)fun);
 }
 
-// ROCM
-std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
-                                               asm_map_t &asm_map,
-                                               size_t n_shared_bytes,
-                                               uint64_t dev) {
-  py::bytes _assembly = asm_map["hsaco"];
-  std::string assembly = py::cast<std::string>(_assembly);
-  // HSA-CO -> hipModule
-  hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
-  // Handle to the kernel
-  hipFunction_t fun;
-  drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
-  // record asm
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
-
-// ---------------------------------------
-// Compile Triton-IR to assembly
-// ---------------------------------------
-
-// // CUDA
-// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
-//     const std::string &name, ir::module &ir, uint64_t device, int num_warps,
-//     int num_stages, asm_map_t &asm_map) {
-//   int n_shared_bytes;
-//   py::gil_scoped_release allow_threads;
-//   llvm::LLVMContext ctx;
-//   // device properties
-//   CUdevice dev = (CUdevice)device;
-//   size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-//   size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-//   size_t cc = major * 10 + minor;
-//   int version;
-//   std::string ptxas_path = drv::path_to_ptxas(version);
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::nvidia_cu_target target(cc);
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(
-//       ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
-//   std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> PTX
-//   std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
-//   asm_map["ptx"] = py::cast(ptx);
-//   // PTX -> Binary
-//   std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
-//   if (!cubin.empty()) {
-//     py::bytes bytes(cubin);
-//     asm_map["cubin"] = bytes;
-//   }
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-
-// // HIP
-// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
-//     const std::string &name, ir::module &ir, uint64_t device, int num_warps,
-//     int num_stages, asm_map_t &asm_map) {
-//   llvm::LLVMContext ctx;
-//   // Triton-IR -> NVPTX LLVM-IR
-//   triton::codegen::amd_cl_target target;
-//   int n_shared_bytes;
-//   auto llvm = triton::codegen::add_passes_to_emit_bin(
-//       ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
-//   std::string tmp;
-//   llvm::raw_string_ostream llir(tmp);
-//   llir << *llvm;
-//   llir.flush();
-//   asm_map["llir"] = py::cast(tmp);
-//   // LLVM-IR -> HSA-CO
-//   std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
-//   asm_map["hsaco"] = py::cast(path);
-//   return std::make_tuple(name, asm_map, n_shared_bytes);
-// }
-
-// void init_triton_codegen(py::module &&m) {
-//   m.def(
-//       "compile_ttir",
-//       [](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
-//          int num_stages) {
-//         std::string name = ir.get_function_list()[0]->get_name();
-//         // record asm as we generate
-//         asm_map_t asm_map;
-//         std::ostringstream ttir;
-//         ir.print(ttir);
-//         asm_map["ttir"] = py::cast(ttir.str());
-//         llvm::LLVMContext ctx;
-//         if (backend == CUDA)
-//           return cu_compile_ttir(name, ir, device, num_warps, num_stages,
-//                                  asm_map);
-//         if (backend == ROCM)
-//           return hip_compile_ttir(name, ir, device, num_warps, num_stages,
-//                                   asm_map);
-//       },
-//       py::return_value_policy::take_ownership);
-//   m.def(
-//       "load_binary",
-//       [](backend_t backend, const std::string &name, asm_map_t &asm_map,
-//          size_t n_shared_bytes, uint64_t dev) {
-//         py::gil_scoped_release allow_threads;
-//         if (backend == CUDA)
-//           return cu_load_binary(name, asm_map, n_shared_bytes, dev);
-//         if (backend == ROCM)
-//           return hip_load_binary(name, asm_map, n_shared_bytes, dev);
-//       },
-//       py::return_value_policy::take_ownership);
-// }
-
 /*****************************************************************************/
 /* Python bindings for triton::ir                                            */
 /*****************************************************************************/
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
       });
 }
 
+void init_translation(py::module &m) {
+  m.def("translate_triton_gpu_to_llvmir",
+        [](mlir::ModuleOp op) -> std::string {
+          llvm::LLVMContext llvmContext;
+          auto llvmModule =
+              ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+
+          std::string str;
+          llvm::raw_string_ostream os(str);
+          llvmModule->print(os, nullptr);
+          os.flush();
+          return str;
+        });
+
+  m.def("translate_triton_gpu_to_ptx",
+        [](mlir::ModuleOp module, uint64_t device) -> std::string {
+          auto [ptxCode, cc, version, ptxasPath] =
+              triton::translateTritonGPUToPTX(module, device);
+          return ptxCode;
+        });
+
+  m.def("compile_ptx_to_cubin",
+        [](const std::string &ptxCode, uint64_t device) -> py::object {
+          py::gil_scoped_release allow_threads;
+          int version;
+          int cc;
+          std::string ptxasPath;
+          triton::getCuCCAndVersionFromDevice(device, &cc, &version,
+                                              &ptxasPath);
+
+          std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
+          py::bytes bytes(cubin);
+          return bytes;
+        });
+}
+
 void init_triton(py::module &m) {
   py::module subm = m.def_submodule("triton");
   // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
   init_triton_runtime(std::move(subm.def_submodule("runtime")));
   init_triton_ir(std::move(subm.def_submodule("ir")));
+  init_translation(subm);
 }
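Note: `compile_ptx_to_cubin` releases the GIL while ptxas runs, so cubin compilation can overlap with other Python work, and it hands the binary back as raw bytes. Continuing the sketch from above:

    cubin = _triton.compile_ptx_to_cubin(ptx, device)
    assert isinstance(cubin, bytes) and len(cubin) > 0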
diff --git a/python/tests/__init__.py b/python/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/tests/test_compiler.py b/python/tests/test_compiler.py
new file mode 100644
index 000000000..1ef75ccdc
--- /dev/null
+++ b/python/tests/test_compiler.py
@@ -0,0 +1,23 @@
+import torch
+
+import triton
+import triton.language as tl
+
+# allocate a CUDA tensor up front to make sure the CUDA context is initialized
+torch.zeros([10], device=torch.device('cuda'))
+
+
+def test_empty_kernel_cubin_compile():
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
+        pass
+
+    device = torch.cuda.current_device()
+    cubin = triton.compile(kernel,
+                           "*fp32,i32,i32",
+                           device=device,
+                           constants={"BLOCK": 256},
+                           output="cubin")
+
+    print('cubin size:', len(cubin))
+    assert len(cubin) > 0
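Note: the "Run python tests" CI step above still has `pytest` commented out, so for now the new suite is run locally; one way to do so, equivalent to running `pytest python/tests` from the repository root:

    import pytest
    pytest.main(["-v", "python/tests/test_compiler.py"])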
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index da85b2ade..321954a8f 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -791,13 +791,28 @@ def optimize_tritongpu_ir(mod, num_stages):
     return mod
 
 
-def make_ptx(mod):
-    # TODO
-    return mod
+def make_ptx(mod, device):
+    '''
+    Translate TritonGPU module to PTX code.
+    :param mod: a TritonGPU dialect module
+    :param device: CUDA device
+    :return: str
+    '''
+    return _triton.translate_triton_gpu_to_ptx(mod, device)
 
 
-def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
-    assert output in ["ttir", "ttgir", "ptx"]
+def make_cubin(ptx, device):
+    '''
+    Compile PTX code to cubin.
+    :param ptx: PTX code
+    :param device: CUDA device
+    :return: bytes
+    '''
+    return _triton.compile_ptx_to_cubin(ptx, device)
+
+
+def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
+    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
+    assert output in valid_outputs, "output should be one of [%s], but got \"%s\"" % (','.join(valid_outputs), output)
     # triton-ir
     module = make_triton_ir(fn, signature, constants, attributes)
     if output == "ttir":
@@ -807,7 +822,15 @@ def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
-    # ptx
+
+    assert device >= 0, "device should be provided."
+
+    ptx = make_ptx(module, device)
     if output == "ptx":
-        return make_ptx(module)
+        return ptx
+
+    cubin = make_cubin(ptx, device)
+    if output == "cubin":
+        return cubin
+
+    assert False
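Note: with these changes `compile` becomes a staged pipeline (ttir, ttgir, ptx, cubin), returning at whichever stage `output` names. The first two stages never touch the GPU, while "ptx" and "cubin" require a valid `device`. For example, reusing the kernel and device from the test above:

    ttgir = triton.compile(kernel, "*fp32,i32,i32",
                           constants={"BLOCK": 256}, output="ttgir")
    ptx = triton.compile(kernel, "*fp32,i32,i32", device=device,
                         constants={"BLOCK": 256}, output="ptx")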