[FRONTEND] Make triton.compile work without a cuda context (#708)

This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
This commit is contained in:
Jason Ansel
2022-09-24 13:41:47 -07:00
committed by GitHub
parent 3ac929b48b
commit 998fd5f9af
4 changed files with 53 additions and 14 deletions

View File

@@ -436,7 +436,7 @@ typedef std::map<std::string, py::object> asm_map_t;
void init_triton_codegen(py::module &&m) {
m.def("compile_ttir",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
py::gil_scoped_release allow_threads;
std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate
@@ -454,10 +454,12 @@ void init_triton_codegen(py::module &&m) {
name, triton::codegen::create_extern_lib(name, path));
}
// device properties
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
if (cc == 0) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
cc = major*10 + minor;
}
int version;
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR