Improve ROCm support. (#780)

- updates to support ROCm 5.2
- workarounds in tests where NV tools were used unconditionally
- implemented `get_num_blocks()` and `add_memfence()` for AMD GPU
- backported from history some atomics
- added bf16 support
- minor warnings cleanup
- added dockerfile to run on a ROCm enabled machine

Co-authored-by: B1tway <andrew.shukshov@gmail.com>
Co-authored-by: Andrey Shukshov <36711069+B1tway@users.noreply.github.com>
This commit is contained in:
Daniil Fukalov
2022-10-14 21:33:42 +03:00
committed by GitHub
parent 94d5c2e8b5
commit 406d03bfaf
17 changed files with 435 additions and 155 deletions

View File

@@ -97,7 +97,9 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);
#ifdef DEBUG_ROCM
drv::dispatch::hipGetLastError();
#endif
}
void init_triton_runtime(py::module &&m) {
@@ -249,7 +251,7 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
llir.flush();
asm_map["llir"] = py::cast(tmp);
// LLVM-IR -> HSA-CO
std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
std::string path = drv::llir_to_amdgpu(llvm.get(), STRINGIFY(MI_GPU_ARCH));
asm_map["hsaco"] = py::cast(path);
return std::make_tuple(name, asm_map, n_shared_bytes);
}
@@ -266,14 +268,14 @@ void init_triton_codegen(py::module &&m) {
llvm::LLVMContext ctx;
if(backend == CUDA)
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
if(backend == ROCM)
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
assert(backend == ROCM);
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
}, py::return_value_policy::take_ownership);
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
if(backend == CUDA)
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
if(backend == ROCM)
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
assert(backend == ROCM);
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
}, py::return_value_policy::take_ownership);
}