[GENERAL] Removed deprecated driver files and added basic compatibility with ROCm (#268)

- Removed the driver module -- the accelerator runtime is now handled by PyTorch
- Added basic ROCm support based on @micmelesse's PR -- an empty kernel can now be executed on AMD devices without any compile-time changes
- PREFER_SHARED is now only requested for kernels whose shared-memory usage exceeds 49152 bytes (the default per-block limit); otherwise it can cause poor L1 performance for broadcast tensors (see the sketch after the commit metadata below)
Author:    Philippe Tillet
Date:      2021-09-09 00:04:28 -07:00
Committer: GitHub
Parent:    8bedcce9be
Commit:    94c83d30ce

47 changed files with 1376 additions and 30232 deletions
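
Below is a minimal sketch of the PREFER_SHARED opt-in logic described in the last bullet above, written directly against the CUDA driver API. The helper name maybe_prefer_shared and the omission of error checking are illustrative only; they are not part of this commit.

#include <cuda.h>

// Only trade L1 capacity for shared memory when the kernel's footprint
// exceeds the 49152-byte default per-block allocation.
void maybe_prefer_shared(CUfunction fun, CUdevice dev, size_t n_shared_bytes) {
  int shared_optin = 0;
  cuDeviceGetAttribute(&shared_optin,
                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
  if (n_shared_bytes > 49152 && shared_optin > 49152) {
    cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
    int shared_static = 0;
    cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
    // Let the kernel use everything beyond its static shared memory.
    cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                       shared_optin - shared_static);
  }
}

Kernels below the threshold keep the default cache configuration, which preserves L1 capacity for broadcast tensors.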


@@ -1,7 +1,7 @@
#include "triton/codegen/pass.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/stream.h"
#include "triton/codegen/target.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h"
@@ -15,7 +15,9 @@
#include <pybind11/stl.h>
#include <regex>
#include <string>
#include <sstream>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace py = pybind11;
namespace ir = triton::ir;
@@ -24,72 +26,213 @@ namespace drv = triton::driver;
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
// information query
template<CUdevice_attribute attr>
int cuGetInfo(CUdevice device) {
int res;
drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
return res;
}
void init_triton_driver(py::module &&m) {
// base device
py::class_<drv::device>(m, "device");
// cuda device
py::class_<drv::cu_device, drv::device>(m, "cu_device")
.def(py::init([](int dev_id, bool take_ownership) {
CUdevice handle;
drv::dispatch::cuDeviceGet(&handle, dev_id);
return new drv::cu_device(handle, take_ownership);
}))
.def("max_shared_memory", [](drv::cu_device *self) {
return self->max_shared_memory();
})
.def("enable_peer_access", [](drv::cu_device *self, unsigned long long int peer_mem_ptr) {
self->enable_peer_access(peer_mem_ptr);
});
// host device
py::class_<drv::host_device, drv::device>(m, "host_device")
.def(py::init<>());
template<hipDeviceAttribute_t attr>
int hipGetInfo(hipDevice_t device) {
int res;
drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
return res;
}
// base stream
py::class_<drv::stream>(m, "stream");
// host stream
py::class_<drv::host_stream, drv::stream>(m, "host_stream")
.def(py::init<>());
// cuda stream
py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
// pybind11 doesn't support opaque pointers (e.g., CUstream), so
// we assume the handle has been converted to a uint64_t
.def(py::init([](uint64_t handle, bool take_ownership) {
return std::unique_ptr<drv::cu_stream>(new drv::cu_stream((CUstream)handle, take_ownership));
}))
.def("enqueue", [](drv::cu_stream *self, drv::kernel *kernel,
size_t grid_0, size_t grid_1, size_t grid_2,
size_t block_0, size_t block_1, size_t block_2,
const std::string &args,
size_t shared_mem) {
return self->enqueue(kernel, {grid_0, grid_1, grid_2}, {block_0, block_1, block_2},
(void *)args.data(), args.size(), shared_mem);
});
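// Aside (illustrative only, not part of this commit): with this convention the
// Python caller passes raw driver handles as plain integers, e.g.
// torch.cuda.current_stream().cuda_stream, and the C++ side reinterprets them.
// stream_from_python is a hypothetical helper sketching the cast:
static CUstream stream_from_python(uint64_t handle) {
  return reinterpret_cast<CUstream>(handle);
}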
enum backend_t {
HOST,
CUDA,
ROCM,
};
py::class_<drv::module>(m, "module");
void cu_enable_peer_access(uint64_t peer_ptr){
CUcontext context;
drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_ptr);
try {
drv::dispatch::cuCtxEnablePeerAccess(context, 0);
} catch (drv::exception::cuda::peer_access_already_enabled) {}
}
py::class_<drv::cu_module, drv::module>(m, "cu_module")
.def("ptx", &drv::cu_module::ptx)
.def("cubin", [](drv::cu_module *self) { return py::bytes(self->cubin()); })
.def("llir", &drv::cu_module::llir);
void host_enqueue(uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
uint64_t block_0, uint64_t block_1, uint64_t block_2,
void* args_ptr, size_t args_size, int64_t shared_mem){
throw std::runtime_error("unsupported");
// auto hst = kernel->module()->hst();
// hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
// char* params = new char[args_size];
// std::memcpy((void*)params, (void*)args, args_size);
// for(size_t i = 0; i < grid[0]; i++)
// for(size_t j = 0; j < grid[1]; j++)
// for(size_t k = 0; k < grid[2]; k++)
// hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}
py::class_<drv::kernel>(m, "kernel");
void cu_enqueue(uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
uint64_t block_0, uint64_t block_1, uint64_t block_2,
void* args_ptr, size_t args_size, int64_t shared_mem){
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (CUstream)stream, nullptr, config);
}
void hip_enqueue(uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
uint64_t block_0, uint64_t block_1, uint64_t block_2,
void* args_ptr, size_t args_size, int64_t shared_mem) {
void *config[] = {
HIP_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
HIP_LAUNCH_PARAM_END
};
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);
}
void init_triton_runtime(py::module &&m) {
// wrap backend_t
py::enum_<backend_t>(m, "backend")
.value("HOST", HOST)
.value("CUDA", CUDA)
.value("ROCM", ROCM)
.export_values();
// enable peer-to-peer
m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
if (backend != CUDA)
throw std::runtime_error("P2P only supported on CUDA devices!");
cu_enable_peer_access(peer_ptr);
}
);
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
return 0;
if(backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
if(backend == ROCM)
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
return -1;
});
// enqueue
m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
uint64_t block_0, uint64_t block_1, uint64_t block_2,
const std::string &args, int64_t shared_mem){
void* args_ptr = (void*)args.data();
size_t args_size = args.size();
if(backend == HOST)
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
if(backend == CUDA)
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
if(backend == ROCM)
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
});
}
/*****************************************************************************/
/* Python bindings for triton::codegen */
/*****************************************************************************/
typedef std::map<std::string, std::string> asm_map_t;
std::tuple<uint64_t, uint64_t> cu_compile_llir(const std::string& name, size_t n_shared_bytes, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map, int cc, int version){
// LLVM-IR -> PTX
std::string ptx = drv::llir_to_ptx(llvm, cc, version);
asm_map["ptx"] = ptx;
// PTX -> Binary
CUmodule mod = drv::ptx_to_cumodule(ptx, cc);
// Handle to the kernel
CUfunction fun;
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// Dynamic shared memory
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
if(n_shared_bytes > 49152 && shared_optin > 49152){
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
int n_spills, n_reg;
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
// record asm
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
std::tuple<uint64_t, uint64_t> hip_compile_llir(const std::string& name, llvm::Module* llvm, uint64_t dev, asm_map_t& asm_map){
// LLVM-IR -> HSA-CO
std::string path = drv::llir_to_amdgpu(llvm, "gfx908");
// HSA-CO -> hipModule
hipModule_t mod = drv::amdgpu_to_hipmodule(path);
// Handle to the kernel
hipFunction_t fun;
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// record asm
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
void init_triton_codegen(py::module &&m) {
m.def(
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
drv::module *mod;
drv::kernel *ker;
size_t shared_mem;
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
std::stringstream ss;
ir::print(ir, ss);
return std::make_tuple(mod, ker, shared_mem, ss.str());
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate
asm_map_t asm_map;
std::ostringstream ttir;
ir::print(ir, ttir);
asm_map["ttir"] = ttir.str();
llvm::LLVMContext ctx;
if(backend == CUDA){
// device properties
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
int version;
drv::dispatch::cuDriverGetVersion(&version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
llvm::raw_string_ostream llir(asm_map["llir"]);
llir << *llvm;
llir.flush();
// LLVM-IR -> Bin
uint64_t mod, fun;
std::tie(mod, fun) = cu_compile_llir(name, n_shared_bytes, &*llvm, device, asm_map, cc, version);
return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
}
if(backend == ROCM){
// Triton-IR -> AMDGPU LLVM-IR
triton::codegen::amd_cl_target target;
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
llvm::raw_string_ostream llir(asm_map["llir"]);
llir << *llvm;
llir.flush();
// LLVM-IR -> Bin
uint64_t mod, fun;
std::tie(mod, fun) = hip_compile_llir(name, &*llvm, device, asm_map);
return std::make_tuple(mod, fun, asm_map, n_shared_bytes);
}
},
py::return_value_policy::take_ownership);
}
@@ -302,7 +445,7 @@ void init_triton_ir(py::module &&m) {
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_driver(std::move(subm.def_submodule("driver")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_triton_frontend(std::move(subm.def_submodule("frontend")));
}