[RUNTIME] Major code cleanup (#711)
This PR does the following:
- CUDA utilities (e.g., cuGetInfo) are no longer compiled into libtriton.so.
- driver/llvm.cc is refactored and split between PTX codegen and the Python bindings.
- By extension, this also deprecates include/external, so Triton no longer has to carry a copy of some CUDA/HIP headers.
- `triton-translate` becomes a `triton.tools.aot` Python utility that reuses functions from the `triton.compiler` sub-module.
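On the Python side, the net effect is that compilation is now driven through small composable steps in `triton.compiler` instead of a single driver call. Below is a rough sketch of the new flow using only functions that appear in this diff; error handling, caching, and the `output=` early exits are omitted, and the helper name `ttgir_to_cubin` is purely illustrative (the authoritative sequence is the `_compile` hunk further down):

import torch
import triton
import triton._C.libtriton.triton as _triton

def ttgir_to_cubin(module, device):
    # TritonGPU dialect -> LLVM IR (new translate_triton_gpu_to_llvmir binding)
    llvm_ir = triton.compiler.make_llvm_ir(module)
    # locate ptxas and derive a PTX ISA version from the CUDA toolkit version
    ptxas, cuda_version = triton.compiler.path_to_ptxas()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = major * 10 + minor
    ptx_version = triton.compiler.ptx_get_version(cuda_version)
    # LLVM IR -> PTX, then PTX -> cubin via the ptxas executable
    ptx = triton.compiler.make_ptx(llvm_ir, compute_capability, ptx_version)
    shem_size = _triton.get_shared_memory_size(module)
    cubin = triton.compiler.make_cubin(ptx, ptxas, compute_capability)
    return cubin, ptx, shem_size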
@@ -1,7 +1,4 @@
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
@@ -10,6 +7,9 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -24,10 +24,14 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/raw_ostream.h"

#include "llvm/Support/SourceMgr.h"

#include <Python.h>
#include <cctype>
#include <fstream>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -40,10 +44,6 @@
#include <string>

namespace py = pybind11;
// namespace ir = triton::ir;
namespace drv = triton::driver;

using triton::cuGetInfo;

enum backend_t {
  HOST,
@@ -51,306 +51,6 @@ enum backend_t {
  ROCM,
};

void cu_enable_peer_access(uint64_t peer_ptr) {
  CUcontext context;
  drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT,
                                       peer_ptr);
  try {
    drv::dispatch::cuCtxEnablePeerAccess(context, 0);
  } catch (drv::exception::cuda::peer_access_already_enabled) {
  }
}

void host_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
                  uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
                  uint64_t block_1, uint64_t block_2, void *args_ptr,
                  size_t args_size, int64_t shared_mem) {
  throw std::runtime_error("unsupported");
  // auto hst = kernel->module()->hst();
  // hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
  // char* params = new char[args_size];
  // std::memcpy((void*)params, (void*)args, args_size);
  // for(size_t i = 0; i < grid[0]; i++)
  //   for(size_t j = 0; j < grid[1]; j++)
  //     for(size_t k = 0; k < grid[2]; k++)
  //       hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn,
  //       (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}

void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
                uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
                uint64_t block_1, uint64_t block_2, void *args_ptr,
                size_t args_size, int64_t shared_mem) {
  void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
                    CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
                    CU_LAUNCH_PARAM_END};
  drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                block_0, block_1, block_2, shared_mem,
                                (CUstream)stream, nullptr, config);
}

long pow2_divisor(long N) {
  if (N % 16 == 0)
    return 16;
  if (N % 8 == 0)
    return 8;
  if (N % 4 == 0)
    return 4;
  if (N % 2 == 0)
    return 2;
  return 1;
}

// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
  if (py::hasattr(dtype, "cache_key_part")) {
    // Presumed to be a triton.language.dtype.
    return std::string(py::str(py::getattr(dtype, "cache_key_part")));
  } else {
    // Remove 'torch.' prefix from repr of torch.dtype.
    py::object repr = py::repr(dtype);
    size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
    const char *repr_ptr = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
    if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
      throw std::logic_error("invalid dtype: " +
                             std::string(repr_ptr, repr_len));
    }
    return std::string(repr_ptr + 6, repr_len - 6);
  }
}

size_t get_pointer_range_size(uint64_t addr) {
  if (addr == 0)
    return 0;
  size_t size;
  drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE,
                                       (CUdeviceptr)addr);
  return size;
}

// Launch
void parse_args(py::list &args, py::list do_not_specialize,
                const std::string &func_key, py::list &arg_names,
                std::string &cache_key, std::string &params,
                size_t &params_size, py::dict constants, int num_warps,
                int num_stages) {
  size_t len = PyList_Size(args.ptr());
  params.reserve(8 * len); // 8 max bytes by argument
  char *params_ptr = &params[0];
  cache_key = func_key;
  cache_key += "-" + std::to_string(num_warps);
  cache_key += "-" + std::to_string(num_stages);
  cache_key += "-";
  for (int i = 0; i < len; i++) {
    cache_key += "_";
    py::int_ py_i = py::int_(i);
    bool specialize = !do_not_specialize.contains(py_i);
    py::object arg = args[i];
    auto arg_ptr = arg.ptr();

    // argument is `long`
    if (PyLong_Check(arg_ptr)) {
      int overflow;
      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
      // values equal to 1 are specialized
      if (specialize && (value == 1)) {
        cache_key += "1";
        continue;
      }
      // int32, uint32, int64, and uint64 have different kernels
      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
        cache_key += "int32";
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow && 0x8000'0000LL <= value &&
                 value <= 0xFFFF'FFFFLL) {
        cache_key += "uint32";
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow) {
        cache_key += "int64";
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &value, 8);
        params_ptr += 8;
      } else {
        if (PyErr_Occurred()) {
          throw std::logic_error("An error occurred?");
        }
        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
        if (PyErr_Occurred()) {
          throw std::runtime_error("integer overflow in argument: " +
                                   std::string(py::str(arg)));
        }
        cache_key += "uint64";
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &unsigned_value, 8);
        params_ptr += 8;
      }
      if (!specialize)
        continue;
      // values divisible by small powers of 2 are specialized
      cache_key += "[multipleof(";
      cache_key += std::to_string(pow2_divisor(value));
      cache_key += ")]";
      continue;
    }
    // argument is `float`
    if (PyFloat_Check(arg_ptr)) {
      cache_key += "float32";
      float value = PyFloat_AsDouble(arg_ptr);
      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
      continue;
    }
    // argument is `bool`
    if (PyBool_Check(arg_ptr)) {
      cache_key += "bool";
      bool value = arg_ptr == Py_True ? true : false;
      std::memcpy(params_ptr, &value, 1);
      params_ptr += 1;
      continue;
    }
    // argument is tensor
    if (py::hasattr(arg, "data_ptr")) {
      py::object data_ptr = arg.attr("data_ptr")();
      long value = data_ptr.cast<long>();
      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
      // copy param
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      // update cache key
      cache_key += dtype_cache_key_part(arg.attr("dtype"));
      cache_key += "*";
      cache_key += "[multipleof(";
      size_t range_size = get_pointer_range_size(value);
      cache_key += std::to_string(
          std::min(pow2_divisor(value), pow2_divisor(range_size)));
      cache_key += ")]";
      continue;
    }
    // argument is `constexpr`
    if (py::hasattr(arg, "value")) {
      py::object value = arg.attr("value");
      py::object name = arg_names[i];
      constants[name] = value;
      py::object repr = py::repr(value);
      const char *start = (const char *)PyUnicode_1BYTE_DATA(repr.ptr());
      size_t len = PyUnicode_GET_LENGTH(repr.ptr());
      cache_key += std::string(start, len);
      continue;
    }
    std::string ty_str =
        arg.attr("__class__").attr("__name__").cast<std::string>();
    if (ty_str == "NoneType") {
      cache_key += "None";
      continue;
    }
    std::string err_msg = "Received type '" + ty_str + "' for argument " +
                          std::to_string(i) + "." +
                          " Only int, float, bool, torch.Tensor, and "
                          "triton.language.constexpr are supported.";
    throw std::runtime_error(err_msg);
  }
  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}

void parse_args(py::list &args, py::list &arg_names, std::string &params,
                size_t &params_size, py::dict constants) {
  size_t len = PyList_Size(args.ptr());
  params.reserve(8 * len); // 8 max bytes by argument
  char *params_ptr = params.data();
  for (int i = 0; i < len; i++) {
    py::object arg = args[i];
    auto arg_ptr = arg.ptr();

    if (PyLong_Check(arg_ptr)) {
      int overflow{};
      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);

      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow && 0x8000'0000LL <= value &&
                 value <= 0xFFFF'FFFFLL) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &value, 8);
        params_ptr += 8;
      } else {
        if (PyErr_Occurred()) {
          throw std::logic_error("An error occurred?");
        }
        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
        if (PyErr_Occurred()) {
          throw std::runtime_error("integer overflow in argument: " +
                                   std::string(py::str(arg)));
        }
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &unsigned_value, 8);
        params_ptr += 8;
      }
      continue;
    }

    if (PyFloat_Check(arg_ptr)) {
      float value = PyFloat_AsDouble(arg_ptr);
      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
      continue;
    }

    // argument is `bool`
    if (PyBool_Check(arg_ptr)) {
      bool value = arg_ptr == Py_True ? true : false;
      std::memcpy(params_ptr, &value, 1);
      params_ptr += 1;
      continue;
    }
    // argument is torch.tensor, get data_ptr as memory address
    if (py::hasattr(arg, "data_ptr")) {
      py::object data_ptr = arg.attr("data_ptr")();
      long value = data_ptr.cast<long>();
      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
      // copy param
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      // update cache key
      continue;
    }
    // argument is `constexpr`
    if (py::hasattr(arg, "value")) {
      py::object value = arg.attr("value");
      py::object name = arg_names[i];
      constants[name] = value;
      continue;
    }
    // argument is `LoadedBinary`
    if (py::hasattr(arg, "get_sass")) {
      // Do nothing, just a placeholder here to indicate validity.
      continue;
    }

    std::string ty_str =
        arg.attr("__class__").attr("__name__").cast<std::string>();
    std::string err_msg = "Received type '" + ty_str + "' for argument " +
                          std::to_string(i) + "." +
                          " Only int, float, bool, torch.Tensor, and "
                          "triton.language.constexpr are supported.";
    throw std::runtime_error(err_msg);
  }

  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}

void init_triton_runtime(py::module &&m) {
  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
@@ -358,192 +58,8 @@ void init_triton_runtime(py::module &&m) {
      .value("CUDA", CUDA)
      // .value("ROCM", ROCM)
      .export_values();

  // enable peer-to-peer
  m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
    if (backend != CUDA)
      throw std::runtime_error("P2P only supported on CUDA devices!");
    cu_enable_peer_access(peer_ptr);
  });

  // get range size for the given pointer
  m.def("get_pointer_range_size", &get_pointer_range_size);

  // cache key
  m.def("launch", [](py::list args, py::list do_not_specialize,
                     const std::string &func_key, py::list &arg_names,
                     py::object device, py::int_ stream, py::dict bin_cache,
                     py::int_ num_warps, py::int_ num_stages,
                     py::function add_to_cache, py::object grid) {
    // parse arguments to compute cache key, compile-time constants and packed
    // kernel arguments
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());
    std::string cache_key;
    std::string params;
    size_t params_size;
    py::dict constants;
    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
               params_size, constants, _num_warps, _num_stages);

    // get cached binary
    py::str key(cache_key);
    py::bool_ noop = false;
    if (!bin_cache.contains(key)) {
      noop = add_to_cache(key, args, device, num_warps, num_stages);
    }
    if (noop)
      return (py::object)py::none();
    py::object bin = bin_cache[key];

    // get grid
    py::sequence seq;
    if (!PySequence_Check(grid.ptr()))
      seq = grid(constants);
    else
      seq = grid;
    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    // enqueue
    uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));

    // actually launch
    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
                      CU_LAUNCH_PARAM_END};
    uint64_t _stream = PyLong_AsLong(stream.ptr());
    if (grid_0 * grid_1 * grid_2 > 0) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps * 32, 1, 1, shared_mem,
                                    (CUstream)_stream, nullptr, config);
    }
    return bin;
  });

  m.def("cc", [](backend_t backend, uint64_t device) -> int {
    if (backend == CUDA) {
      CUdevice dev = (CUdevice)device;
      int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
      int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
      return major * 10 + minor;
    }
    return -1;
  });

  m.def("launch_binary", [](py::object binary, py::list args,
                            py::list do_not_specialize, py::list arg_names,
                            py::int_ stream, py::int_ num_warps,
                            py::int_ num_stages, py::object grid) {
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());

    // get grid
    py::sequence seq;
    py::dict constants;
    std::string params;
    size_t params_size{};
    parse_args(args, arg_names, params, params_size, constants);
    if (!PySequence_Check(grid.ptr()))
      seq = grid(constants);
    else
      seq = grid;

    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));

    // actually launch
    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
                      CU_LAUNCH_PARAM_END};
    uint64_t _stream = PyLong_AsLong(stream.ptr());
    const int numGrids = grid_0 * grid_1 * grid_2;
    if (numGrids) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps * 32, 1, 1, shared_mem,
                                    (CUstream)_stream, nullptr, config);
    }
    return binary;
  });

  // query maximum shared memory
  m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
    if (backend == HOST)
      return 0;
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
          device);
    return -1;
  });

  // query DRAM & L2 cache
  m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
    return -1;
  });
  m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
    return -1;
  });
  m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
    return -1;
  });

  // query clock rate (in kilohertz)
  m.def("clock_rate", [](backend_t backend, uint64_t device) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
    return -1;
  });

  m.def("num_sm", [](backend_t backend, uint64_t device) {
    if (backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
    return -1;
  });

  // enqueue
  m.def("enqueue",
        [](backend_t backend, uint64_t stream, uint64_t kernel, uint64_t grid_0,
           uint64_t grid_1, uint64_t grid_2, uint64_t block_0, uint64_t block_1,
           uint64_t block_2, const std::string &args, int64_t shared_mem) {
          void *args_ptr = (void *)args.data();
          size_t args_size = args.size();
          // release the gil in case the enqueue blocks
          // cuda will block if too many ops are enqueued
          py::gil_scoped_release allow_threads;
          if (backend == HOST)
            host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
                         block_1, block_2, args_ptr, args_size, shared_mem);
          if (backend == CUDA)
            cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
                       block_2, args_ptr, args_size, shared_mem);
        });
}

/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;

/*****************************************************************************/
/* Python bindings for triton::ir */
/*****************************************************************************/
@@ -783,6 +299,38 @@ void init_triton_ir(py::module &&m) {
        return self.lookupSymbol<mlir::FuncOp>(funcName);
      });

  m.def(
      "parse_mlir_module",
      [](const std::string &inputFilename, mlir::MLIRContext &context) {
        // open file
        std::string errorMessage;
        auto input = mlir::openInputFile(inputFilename, &errorMessage);
        if (!input)
          throw std::runtime_error(errorMessage);

        // initialize registry
        mlir::DialectRegistry registry;
        registry.insert<mlir::triton::TritonDialect,
                        mlir::triton::gpu::TritonGPUDialect,
                        mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
                        mlir::StandardOpsDialect, mlir::scf::SCFDialect>();

        context.appendDialectRegistry(registry);
        context.loadAllAvailableDialects();
        context.allowUnregisteredDialects();

        // parse module
        llvm::SourceMgr sourceMgr;
        sourceMgr.AddNewSourceBuffer(std::move(input), llvm::SMLoc());
        mlir::OwningOpRef<mlir::ModuleOp> module(
            mlir::parseSourceFile(sourceMgr, &context));
        if (!module)
          throw std::runtime_error("Parse MLIR file failed.");

        return module->clone();
      },
      ret::take_ownership);

  py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
      // .def_property_readonly("attrs", &ir::function::attrs)
      // .def("add_attr", &ir::function::add_attr);
@@ -1643,84 +1191,86 @@ void init_triton_ir(py::module &&m) {
}

void init_triton_translation(py::module &m) {
  m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
    llvm::LLVMContext llvmContext;
    auto llvmModule =
        ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);

    std::string str;
    llvm::raw_string_ostream os(str);
    llvmModule->print(os, nullptr);
    os.flush();
    return str;
  using ret = py::return_value_policy;

  m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
    auto pass = std::make_unique<mlir::Allocation>(module);
    return pass->getSharedMemorySize();
  });

  m.def("translate_triton_gpu_to_ptx",
        [](mlir::ModuleOp module, uint64_t device)
            -> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
          auto [ptxCode, cc, version, ptxasPath] =
              triton::translateTritonGPUToPTX(module, device);
  m.def(
      "translate_triton_gpu_to_llvmir",
      [](mlir::ModuleOp op) {
        llvm::LLVMContext llvmContext;
        auto llvmModule =
            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);

          mlir::PassManager pm(module->getContext());
          auto pass = std::make_unique<mlir::Allocation>(module);
          size_t size = pass->getSharedMemorySize();
        std::string str;
        llvm::raw_string_ostream os(str);
        llvmModule->print(os, nullptr);
        os.flush();
        return str;
      },
      ret::take_ownership);

          return std::make_tuple(ptxCode, size);
        });
  m.def(
      "translate_llvmir_to_ptx",
      [](const std::string llvmIR, int capability, int version) -> std::string {
        // create LLVM module from C++
        llvm::LLVMContext context;
        std::unique_ptr<llvm::MemoryBuffer> buffer =
            llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
        llvm::SMDiagnostic error;
        std::unique_ptr<llvm::Module> module =
            llvm::parseIR(buffer->getMemBufferRef(), error, context);
        // translate module to PTX
        auto ptxCode =
            triton::translateLLVMIRToPTX(*module, capability, version);
        return ptxCode;
      },
      ret::take_ownership);

  m.def("compile_ptx_to_cubin",
        [](const std::string &ptxCode, uint64_t device) -> py::object {
        [](const std::string &ptxCode, const std::string &ptxasPath,
           int capability) -> py::object {
          py::gil_scoped_release allow_threads;
          int version;
          int cc;
          std::string ptxasPath;
          triton::getCuCCAndVersionFromDevice(device, &cc, &version,
                                              &ptxasPath);

          std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
          // compile ptx with ptxas
          char _fsrc[L_tmpnam];
          char _flog[L_tmpnam];
          std::tmpnam(_fsrc);
          std::tmpnam(_flog);
          std::string fsrc = _fsrc;
          std::string flog = _flog;
          std::string fbin = fsrc + ".o";
          const char *_fbin = fbin.c_str();
          std::ofstream ofs(fsrc);
          ofs << ptxCode << std::endl;
          ofs.close();
          std::string cmd;
          int err;
          cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
                " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
          err = system(cmd.c_str());
          if (err != 0) {
            std::ifstream _log(_flog);
            std::string log(std::istreambuf_iterator<char>(_log), {});
            unlink(_fsrc);
            unlink(_flog);
            throw std::runtime_error("Internal Triton PTX codegen error: \n" +
                                     log);
          }
          std::ifstream _cubin(_fbin, std::ios::binary);
          std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
          _cubin.close();
          unlink(_fsrc);
          unlink(_flog);
          unlink(_fbin);

          py::bytes bytes(cubin);
          return bytes;
        });

  m.def(
      "load_binary",
      [](const std::string &name, const std::string &data,
         size_t n_shared_bytes, uint64_t device) {
        py::gil_scoped_release allow_threads;
        // create driver handles
        CUfunction fun;
        CUmodule mod;
        drv::dispatch::cuModuleLoadData(&mod, data.c_str());
        drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
        // get allocated registers and spilled registers from the function
        int n_regs = 0;
        int n_spills = 0;
        drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
                                          fun);
        drv::dispatch::cuFuncGetAttribute(
            &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
        n_spills /= 4;
        // set dynamic shared memory if necessary
        int shared_optin;
        drv::dispatch::cuDeviceGetAttribute(
            &shared_optin,
            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
        if (n_shared_bytes > 49152 && shared_optin > 49152) {
          drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
          int shared_total, shared_static;
          drv::dispatch::cuDeviceGetAttribute(
              &shared_total,
              CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
          drv::dispatch::cuFuncGetAttribute(
              &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
          drv::dispatch::cuFuncSetAttribute(
              fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
              shared_optin - shared_static);
        }
        return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
                               (uint64_t)n_spills);
      },
      py::return_value_policy::take_ownership);
}

void init_triton(py::module &m) {

@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -843,7 +844,11 @@ def optimize_tritongpu_ir(mod, num_stages):
    return mod


def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
def make_llvm_ir(mod):
    return _triton.translate_triton_gpu_to_llvmir(mod)


def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
@@ -851,17 +856,17 @@ def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
    - PTX code
    - shared memory allocation size
    '''
    return _triton.translate_triton_gpu_to_ptx(mod, device)
    return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)


def make_cubin(ptx, device):
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
    '''
    Compile TritonGPU module to cubin.
    :param ptx: ptx code
    :param device: CUDA device
    :return: str
    '''
    return _triton.compile_ptx_to_cubin(ptx, device)
    return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)

def ptx_get_kernel_name(ptx: str) -> str:
@@ -877,6 +882,46 @@ def ptx_get_kernel_name(ptx: str) -> str:
    return line.split()[-1]


@functools.lru_cache
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    version = major * 1000 + minor * 10
    if version >= 11040:
        return 74
    if version >= 11030:
        return 73
    if version >= 11020:
        return 72
    if version >= 11010:
        return 71
    if version >= 11000:
        return 70
    if version >= 10020:
        return 65
    if version >= 10010:
        return 64
    if version >= 10000:
        return 63
    raise RuntimeError("Triton only support CUDA 10.0 or higher")


def path_to_ptxas():
    prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
    for prefix in prefixes:
        ptxas = os.path.join(prefix, "bin", "ptxas")
        if os.path.exists(ptxas):
            result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
            if result is not None:
                version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                if version is not None:
                    return ptxas, version.group(1)
    raise RuntimeError("Cannot find ptxas")


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])

@@ -895,17 +940,24 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
    # tritongpu-ir
    module = make_tritongpu_ir(module, num_warps)
    module = optimize_tritongpu_ir(module, num_stages)

    if output == "ttgir":
        return module.str()

    # llvm-ir
    llvm_ir = make_llvm_ir(module)

    assert device >= 0, "device should be provided."
    ptx, shem_size = make_ptx(module, device)
    ptxas, cuda_version = path_to_ptxas()
    compute_capability = torch.cuda.get_device_capability(device)
    compute_capability = compute_capability[0] * 10 + compute_capability[1]
    ptx_version = ptx_get_version(cuda_version)
    ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
    shem_size = _triton.get_shared_memory_size(module)
    kernel_name = ptx_get_kernel_name(ptx)
    if output == "ptx":
        return ptx, shem_size, kernel_name

    cubin = make_cubin(ptx, device)
    cubin = make_cubin(ptx, ptxas, compute_capability)
    if output == "cubin":
        return cubin, ptx, shem_size, kernel_name

@@ -980,6 +1032,7 @@ def generate_launcher(identifier, constants, signature):
    src = f"""
#include \"cuda.h\"
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
@@ -993,13 +1046,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
    PyErr_SetString(PyExc_RuntimeError, err);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}

static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
@@ -1021,6 +1077,7 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (CUdeviceptr)0;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
@@ -1039,10 +1096,12 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"launcher\",
@@ -1050,6 +1109,7 @@ static struct PyModuleDef ModuleDef = {{
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
@@ -1251,7 +1311,10 @@ class CompiledKernel:
            self.asm["ptx"] = f.read()

        device = torch.cuda.current_device()
        mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        global cuda_utils
        if cuda_utils is None:
            cuda_utils = CudaUtils()
        mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        self.cu_module = mod
        self.cu_function = func

@@ -1261,3 +1324,118 @@ class CompiledKernel:
        stream = torch.cuda.current_stream().cuda_stream
        self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
        return


class CudaUtils(object):

    def __new__(cls):
        if not hasattr(cls, 'instance'):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def _generate_src(self):
        return """
#include <cuda.h>

#include \"cuda.h\"
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{
  if (code != CUDA_SUCCESS)
  {
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str;
    cuGetErrorString(code, &str);
    char err[1024] = {0};
    strcat(err, prefix);
    strcat(err, str);
    PyErr_SetString(PyExc_RuntimeError, err);
  }
}

#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }

static PyObject* loadBinary(PyObject* self, PyObject* args) {
  const char* name;
  const char* data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if(!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  Py_BEGIN_ALLOW_THREADS;
  // create driver handles
  CUDA_CHECK(cuModuleLoadData(&mod, data));
  CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  n_spills /= 4;
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
  if (shared > 49152 && shared_optin > 49152) {
    CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total, shared_static;
    CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
    CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
  }
  Py_END_ALLOW_THREADS;

  if(PyErr_Occurred()) {
    return NULL;
  }
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
}

static PyMethodDef ModuleMethods[] = {
  {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  \"cuda_utils\",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit_cuda_utils(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
"""

    def __init__(self):
        src = self._generate_src()
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = CacheManager(key)
        fname = "cuda_utils.so"
        if not cache.has_file(fname):
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = _build("cuda_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("cuda_utils", cache._make_path(fname))
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary


cuda_utils = None
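The `CudaUtils` class above replaces the `load_binary` binding that used to live in libtriton.so: a tiny CUDA-only extension module is compiled with the system toolchain on first use and cached, so the main library no longer needs the CUDA headers. A minimal sketch of how `CompiledKernel` consumes it, mirroring the hunk above (the wrapper name `_load` and its argument names are illustrative only):

def _load(kernel_name, cubin_bytes, shared_mem_bytes, device):
    # lazily build/load the cuda_utils extension the first time it is needed
    global cuda_utils
    if cuda_utils is None:
        cuda_utils = CudaUtils()
    return cuda_utils.load_binary(kernel_name, cubin_bytes, shared_mem_bytes, device)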
python/triton/tools/aot.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import argparse

import triton
import triton._C.libtriton.triton as libtriton

if __name__ == '__main__':

    # valid source and target formats
    VALID_FORMATS = ['llvm-ir', 'ptx', 'triton-ir', 'triton-gpu-ir']

    # set up the argument parser
    # TODO: conditional requirements
    parser = argparse.ArgumentParser()
    parser.add_argument('src', help="Source file to compile")
    parser.add_argument('--target', required=True,
                        help="Target format, one of: " + ', '.join(VALID_FORMATS))
    parser.add_argument('--sm', type=int, help="Compute capability to compile for")
    parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")

    # parse the args
    args = parser.parse_args()

    # TODO: clean-up and re-use triton.compiler primitive functions
    # check for validity of format arguments
    if args.target not in VALID_FORMATS:
        print("Invalid target format: " + args.target)
        exit(0)

    # parse source file to MLIR module
    context = libtriton.ir.context()
    module = libtriton.ir.parse_mlir_module(args.src, context)
    module.context = context

    # optimize triton-ir
    module = triton.compiler.optimize_triton_ir(module)
    if args.target == 'triton-ir':
        print(module.str())
        exit(0)

    # triton-ir -> triton-gpu-ir
    module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
    module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
    if args.target == 'triton-gpu-ir':
        print(module.str())
        exit(0)

    # triton-gpu-ir -> llvm-ir
    module = triton.compiler.make_llvm_ir(module)
    if args.target == 'llvm-ir':
        print(module)
        exit(0)

    if not args.sm:
        raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
    if not args.ptx_version:
        raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")

    # llvm-ir -> ptx
    module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
    assert args.target == 'ptx'
    print(module)
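Usage note: with this script in place, the old `triton-translate` workflow becomes an ordinary Python invocation, for example something like the command below (the input file name, SM number, and PTX version are only illustrative, and whether the script is run directly or via -m depends on how the package is installed):

# python -m triton.tools.aot kernel.ttir --target=ptx --sm=80 --ptx-version=74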