[FRONTEND] Bunch of fixes here and there (#436)

Author: Philippe Tillet
Date: 2022-01-20 10:55:59 -08:00
Committed by: GitHub
Parent: e0c5709cc8
Commit: 4c97d1ecd7
7 changed files with 71 additions and 39 deletions

@@ -329,7 +329,6 @@ void init_triton_runtime(py::module &&m) {
// cuda will block if too many ops are enqueued
Py_BEGIN_ALLOW_THREADS
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
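
The hunk above keeps the GIL released around the launch, since the driver call can block once too many ops are enqueued. A minimal sketch of that pattern, assuming a raw CUfunction/CUstream rather than Triton's drv::dispatch wrapper (names here are illustrative, not the actual Triton code):

#include <Python.h>
#include <cuda.h>

// Sketch: drop the GIL while a potentially blocking CUDA driver call runs,
// so other Python threads keep making progress in the meantime.
CUresult launch_blocking(CUfunction kernel, CUstream stream,
                         unsigned grid_0, unsigned grid_1, unsigned grid_2,
                         unsigned num_warps, unsigned shared_mem, void **config) {
  CUresult res;                 // declared outside: survives the macro block
  Py_BEGIN_ALLOW_THREADS        // saves the thread state, releases the GIL
  res = cuLaunchKernel(kernel, grid_0, grid_1, grid_2,
                       num_warps * 32, 1, 1,   // one warp = 32 threads
                       shared_mem, stream,
                       /*kernelParams=*/nullptr, /*extra=*/config);
  Py_END_ALLOW_THREADS          // reacquires the GIL
  return res;
}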
@@ -466,6 +465,9 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
int n_shared_bytes;
Py_BEGIN_ALLOW_THREADS
llvm::LLVMContext ctx;
// device properties
CUdevice dev = (CUdevice)device;
@@ -476,7 +478,6 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
drv::dispatch::cuDriverGetVersion(&version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
@@ -492,6 +493,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;
}
Py_END_ALLOW_THREADS
return std::make_tuple(name, asm_map, n_shared_bytes);
}
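
The second and third hunks go together: Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS expand to an opening and closing brace around a saved PyThreadState, so anything declared between them goes out of scope at the closing macro. Hoisting int n_shared_bytes; above the block keeps it visible at the return statement after Py_END_ALLOW_THREADS. A hypothetical minimal repro of the scoping rule (illustrative function, not Triton code):

#include <Python.h>

// CPython defines the macros roughly as:
//   Py_BEGIN_ALLOW_THREADS  =>  { PyThreadState *_save; _save = PyEval_SaveThread();
//   Py_END_ALLOW_THREADS    =>    PyEval_RestoreThread(_save); }
// so the pair delimits a C block.
int compute_shared_bytes() {
  int n_shared_bytes;     // declared before the block, usable after it
  Py_BEGIN_ALLOW_THREADS  // GIL released; long-running compilation goes here
  n_shared_bytes = 4096;  // stand-in for the codegen pass that fills it in
  Py_END_ALLOW_THREADS    // GIL reacquired; the block closes here
  return n_shared_bytes;  // legal only because the declaration was hoisted
}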