[RUNTIME] Major code cleanup (#711)

This PR does the following:
- CUDA utilities (e.g., cuGetInfo) won't be compiled as part of libtriton.so anymore.
- Refactoring driver/llvm.cc to split it between PTX codegen and python.
- By extension this will also deprecate include/external so Triton won't have to live with a copy of some CUDA/Hip headers anymore.
- `triton-translate` becomes a `triton.tools.aot` Python utility that re-uses functions from the triton.compile sub-module.
This commit is contained in:
Philippe Tillet
2022-09-26 16:38:06 -07:00
committed by GitHub
parent 8bb09f83ee
commit 1e91ed30d0
28 changed files with 509 additions and 31483 deletions

View File

@@ -1,7 +1,4 @@
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
@@ -10,6 +7,9 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -24,10 +24,14 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/SourceMgr.h"
#include <Python.h>
#include <cctype>
#include <fstream>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -40,10 +44,6 @@
#include <string>
namespace py = pybind11;
// namespace ir = triton::ir;
namespace drv = triton::driver;
using triton::cuGetInfo;
enum backend_t {
HOST,
@@ -51,306 +51,6 @@ enum backend_t {
ROCM,
};
void cu_enable_peer_access(uint64_t peer_ptr) {
CUcontext context;
drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT,
peer_ptr);
try {
drv::dispatch::cuCtxEnablePeerAccess(context, 0);
} catch (drv::exception::cuda::peer_access_already_enabled) {
}
}
void host_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
uint64_t block_1, uint64_t block_2, void *args_ptr,
size_t args_size, int64_t shared_mem) {
throw std::runtime_error("unsupported");
// auto hst = kernel->module()->hst();
// hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
// char* params = new char[args_size];
// std::memcpy((void*)params, (void*)args, args_size);
// for(size_t i = 0; i < grid[0]; i++)
// for(size_t j = 0; j < grid[1]; j++)
// for(size_t k = 0; k < grid[2]; k++)
// hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn,
// (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}
void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
uint64_t block_1, uint64_t block_2, void *args_ptr,
size_t args_size, int64_t shared_mem) {
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END};
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2, shared_mem,
(CUstream)stream, nullptr, config);
}
long pow2_divisor(long N) {
if (N % 16 == 0)
return 16;
if (N % 8 == 0)
return 8;
if (N % 4 == 0)
return 4;
if (N % 2 == 0)
return 2;
return 1;
}
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
if (py::hasattr(dtype, "cache_key_part")) {
// Presumed to be a triton.language.dtype.
return std::string(py::str(py::getattr(dtype, "cache_key_part")));
} else {
// Remove 'torch.' prefix from repr of torch.dtype.
py::object repr = py::repr(dtype);
size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
const char *repr_ptr = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
throw std::logic_error("invalid dtype: " +
std::string(repr_ptr, repr_len));
}
return std::string(repr_ptr + 6, repr_len - 6);
}
}
size_t get_pointer_range_size(uint64_t addr) {
if (addr == 0)
return 0;
size_t size;
drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE,
(CUdeviceptr)addr);
return size;
}
// Launch
void parse_args(py::list &args, py::list do_not_specialize,
const std::string &func_key, py::list &arg_names,
std::string &cache_key, std::string &params,
size_t &params_size, py::dict constants, int num_warps,
int num_stages) {
size_t len = PyList_Size(args.ptr());
params.reserve(8 * len); // 8 max bytes by argument
char *params_ptr = &params[0];
cache_key = func_key;
cache_key += "-" + std::to_string(num_warps);
cache_key += "-" + std::to_string(num_stages);
cache_key += "-";
for (int i = 0; i < len; i++) {
cache_key += "_";
py::int_ py_i = py::int_(i);
bool specialize = !do_not_specialize.contains(py_i);
py::object arg = args[i];
auto arg_ptr = arg.ptr();
// argument is `long`
if (PyLong_Check(arg_ptr)) {
int overflow;
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
// values equal to 1 are specialized
if (specialize && (value == 1)) {
cache_key += "1";
continue;
}
// int32, uint32, int64, and uint64 have different kernels
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
cache_key += "int32";
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
cache_key += "uint32";
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
cache_key += "int64";
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
cache_key += "uint64";
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
if (!specialize)
continue;
// values divisible by small powers of 2 are specialized
cache_key += "[multipleof(";
cache_key += std::to_string(pow2_divisor(value));
cache_key += ")]";
continue;
}
// argument is `float`
if (PyFloat_Check(arg_ptr)) {
cache_key += "float32";
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
cache_key += "bool";
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is tensor
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// update cache key
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
size_t range_size = get_pointer_range_size(value);
cache_key += std::to_string(
std::min(pow2_divisor(value), pow2_divisor(range_size)));
cache_key += ")]";
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
py::object repr = py::repr(value);
const char *start = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
size_t len = PyUnicode_GET_LENGTH(repr.ptr());
cache_key += std::string(start, len);
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
if (ty_str == "NoneType") {
cache_key += "None";
continue;
}
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
void parse_args(py::list &args, py::list &arg_names, std::string &params,
size_t &params_size, py::dict constants) {
size_t len = PyList_Size(args.ptr());
params.reserve(8 * len); // 8 max bytes by argument
char *params_ptr = params.data();
for (int i = 0; i < len; i++) {
py::object arg = args[i];
auto arg_ptr = arg.ptr();
if (PyLong_Check(arg_ptr)) {
int overflow{};
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
continue;
}
if (PyFloat_Check(arg_ptr)) {
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is torch.tensor, get data_ptr as memory address
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// update cache key
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
continue;
}
// argument is `LoadedBinary`
if (py::hasattr(arg, "get_sass")) {
// Do nothing, just a placeholder here to indicate validity.
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
void init_triton_runtime(py::module &&m) {
// wrap backend_t
py::enum_<backend_t>(m, "backend")
@@ -358,192 +58,8 @@ void init_triton_runtime(py::module &&m) {
.value("CUDA", CUDA)
// .value("ROCM", ROCM)
.export_values();
// enable peer-to-peer
m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
if (backend != CUDA)
throw std::runtime_error("P2P only supported on CUDA devices!");
cu_enable_peer_access(peer_ptr);
});
// get range size for the given pointer
m.def("get_pointer_range_size", &get_pointer_range_size);
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize,
const std::string &func_key, py::list &arg_names,
py::object device, py::int_ stream, py::dict bin_cache,
py::int_ num_warps, py::int_ num_stages,
py::function add_to_cache, py::object grid) {
// parse arguments to compute cache key, compile-time constants and packed
// kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
std::string cache_key;
std::string params;
size_t params_size;
py::dict constants;
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
params_size, constants, _num_warps, _num_stages);
// get cached binary
py::str key(cache_key);
py::bool_ noop = false;
if (!bin_cache.contains(key)) {
noop = add_to_cache(key, args, device, num_warps, num_stages);
}
if (noop)
return (py::object)py::none();
py::object bin = bin_cache[key];
// get grid
py::sequence seq;
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
// enqueue
uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
if (grid_0 * grid_1 * grid_2 > 0) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return bin;
});
m.def("cc", [](backend_t backend, uint64_t device) -> int {
if (backend == CUDA) {
CUdevice dev = (CUdevice)device;
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
return major * 10 + minor;
}
return -1;
});
m.def("launch_binary", [](py::object binary, py::list args,
py::list do_not_specialize, py::list arg_names,
py::int_ stream, py::int_ num_warps,
py::int_ num_stages, py::object grid) {
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
// get grid
py::sequence seq;
py::dict constants;
std::string params;
size_t params_size{};
parse_args(args, arg_names, params, params_size, constants);
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
const int numGrids = grid_0 * grid_1 * grid_2;
if (numGrids) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return binary;
});
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
return 0;
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
device);
return -1;
});
// query DRAM & L2 cache
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
return -1;
});
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
return -1;
});
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
return -1;
});
// query clock rate (in kilohertz)
m.def("clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
return -1;
});
m.def("num_sm", [](backend_t backend, uint64_t device) {
if (backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
return -1;
});
// enqueue
m.def("enqueue",
[](backend_t backend, uint64_t stream, uint64_t kernel, uint64_t grid_0,
uint64_t grid_1, uint64_t grid_2, uint64_t block_0, uint64_t block_1,
uint64_t block_2, const std::string &args, int64_t shared_mem) {
void *args_ptr = (void *)args.data();
size_t args_size = args.size();
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
if (backend == HOST)
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
block_1, block_2, args_ptr, args_size, shared_mem);
if (backend == CUDA)
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
block_2, args_ptr, args_size, shared_mem);
});
}
/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;
/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -783,6 +299,38 @@ void init_triton_ir(py::module &&m) {
return self.lookupSymbol<mlir::FuncOp>(funcName);
});
m.def(
"parse_mlir_module",
[](const std::string &inputFilename, mlir::MLIRContext &context) {
// open file
std::string errorMessage;
auto input = mlir::openInputFile(inputFilename, &errorMessage);
if (!input)
throw std::runtime_error(errorMessage);
// initialize registry
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
// parse module
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
mlir::OwningOpRef<mlir::ModuleOp> module(
mlir::parseSourceFile(sourceMgr, &context));
if (!module)
throw std::runtime_error("Parse MLIR file failed.");
return module->clone();
},
ret::take_ownership);
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
@@ -1643,84 +1191,86 @@ void init_triton_ir(py::module &&m) {
}
void init_triton_translation(py::module &m) {
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
llvm::LLVMContext llvmContext;
auto llvmModule =
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
std::string str;
llvm::raw_string_ostream os(str);
llvmModule->print(os, nullptr);
os.flush();
return str;
using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
auto pass = std::make_unique<mlir::Allocation>(module);
return pass->getSharedMemorySize();
});
m.def("translate_triton_gpu_to_ptx",
[](mlir::ModuleOp module, uint64_t device)
-> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
auto [ptxCode, cc, version, ptxasPath] =
triton::translateTritonGPUToPTX(module, device);
m.def(
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op) {
llvm::LLVMContext llvmContext;
auto llvmModule =
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
mlir::PassManager pm(module->getContext());
auto pass = std::make_unique<mlir::Allocation>(module);
size_t size = pass->getSharedMemorySize();
std::string str;
llvm::raw_string_ostream os(str);
llvmModule->print(os, nullptr);
os.flush();
return str;
},
ret::take_ownership);
return std::make_tuple(ptxCode, size);
});
m.def(
"translate_llvmir_to_ptx",
[](const std::string llvmIR, int capability, int version) -> std::string {
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to PTX
auto ptxCode =
triton::translateLLVMIRToPTX(*module, capability, version);
return ptxCode;
},
ret::take_ownership);
m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, uint64_t device) -> py::object {
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability) -> py::object {
py::gil_scoped_release allow_threads;
int version;
int cc;
std::string ptxasPath;
triton::getCuCCAndVersionFromDevice(device, &cc, &version,
&ptxasPath);
std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptxCode << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
" " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if (err != 0) {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
py::bytes bytes(cubin);
return bytes;
});
m.def(
"load_binary",
[](const std::string &name, const std::string &data,
size_t n_shared_bytes, uint64_t device) {
py::gil_scoped_release allow_threads;
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, data.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
fun);
drv::dispatch::cuFuncGetAttribute(
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
if (n_shared_bytes > 49152 && shared_optin > 49152) {
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
drv::dispatch::cuDeviceGetAttribute(
&shared_total,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
drv::dispatch::cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncSetAttribute(
fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
(uint64_t)n_spills);
},
py::return_value_policy::take_ownership);
}
void init_triton(py::module &m) {