print irs

This commit is contained in:
Michael Melesse
2022-10-28 17:46:52 +00:00
parent aa556d4f1b
commit 6e50f8b2c0
2 changed files with 10 additions and 5 deletions

View File

@@ -525,18 +525,22 @@ void init_triton_codegen(py::module &&m) {
int version;
// std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> AMDGCN LLVM-IR
std::cout << "ttir:" << std::endl;
std::cout << "\t" << ttir.str() << std::endl;
std::cout << "\t" << tmp << std::endl;
triton::codegen::amd_cl_target target;
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
std::cout << "llir:" << std::endl;
std::cout << "\t" << llir.str() << std::endl;
llir.flush();
// LLVM-IR -> AMDGPU
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
amdgcn = std::get<0>(amdgpu);
hsaco_path = std::get<1>(amdgpu);
std::cout << "amdgcn:" << std::endl;
std::cout << "\t" << amdgcn << std::endl;
}
asm_map_t asm_map;
asm_map["ttir"] = py::cast(ttir.str());

View File

@@ -113,14 +113,14 @@ def check_type_supported(dtype):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
# check_type_supported(dtype_x)
check_type_supported(dtype_x)
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
@@ -885,9 +885,9 @@ def test_f16_to_f8_rounding():
for dtype in dtypes_with_bfloat16
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
@@ -946,6 +946,7 @@ reduce_configs2 = [
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
# triton kernel