diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 2a87fb8c9..d35245942 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -3864,7 +3864,7 @@ void generator::visit_function(ir::function* fn) {
   for(ir::basic_block *block: blocks)
     visit_basic_block(block);
   // finalize
-  std::cout << "\t// verifyFunction" << std::endl;
+  std::cout << "\t// finalize" << std::endl;
   finalize_function(fn);
   // verifyFunction
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a6e343b80..0badd9007 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -109,7 +109,7 @@ def check_type_supported(dtype):
         pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 
 
-@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index c04540ea4..f6ab062bf 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1250,6 +1250,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
 
 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
             num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
+    print("compiler.py: compile")
+    print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
     # cache manager
@@ -1272,14 +1274,23 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     # retrieve cached shared object if it exists
     fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
     fn_cache_manager = CacheManager(fn_cache_key)
-    ptx_name = f"{name}.ptx"
-    cubin_name = f"{name}.cubin"
+    if torch.version.hip is not None:
+        amdgpu_name = f"{name}.gcn"
+        hipmodule_name = f"{name}.hipmodule"
+        assembly_name = amdgpu_name
+        binary_name = hipmodule_name
+    else:
+        ptx_name = f"{name}.ptx"
+        cubin_name = f"{name}.cubin"
+        assembly_name = ptx_name
+        binary_name = cubin_name
+
     data_name = f"{name}.json"
     ttir_name = f"{name}.ttir"
     llir_name = f"{name}.llir"
-    if not fn_cache_manager.has_file(cubin_name) or \
+    if not fn_cache_manager.has_file(binary_name) or \
        not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name) or \
+       not fn_cache_manager.has_file(assembly_name) or \
        not fn_cache_manager.has_file(ttir_name) or \
        not fn_cache_manager.has_file(llir_name):
@@ -1287,14 +1298,14 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
             asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                 extern_libs, "hipmodule", cc)
             # cache AMD assembly and binary
-            fn_cache_manager.put(asm["hipmodule"], cubin_name)
-            fn_cache_manager.put(asm["amdgpu"], ptx_name, binary=False)
+            fn_cache_manager.put(asm["hipmodule"], binary_name)
+            fn_cache_manager.put(asm["amdgpu"], assembly_name, binary=False)
         else:
             asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                 extern_libs, "cubin", cc)
             # cache Nvidia assembly and binary
-            fn_cache_manager.put(asm["cubin"], cubin_name)
-            fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
+            fn_cache_manager.put(asm["cubin"], binary_name)
+            fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)
         # cache triton and llvm ir
         fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
@@ -1317,6 +1328,8 @@ class CompiledKernel:
     launch_exit_hook = None
 
     def __init__(self, fn_name, so_path, cache_dir, device):
+        print("compiler.py: CompiledKernel:__init__")
+        print(f"\t{self, fn_name, so_path, cache_dir, device}")
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
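
For context, the compiler.py hunk above keys the cache-artifact names off the PyTorch build: `torch.version.hip` is a version string on ROCm builds and None on CUDA builds. Below is a minimal standalone sketch of that selection logic, not the patch itself; `cache_artifact_names` and the example kernel name are hypothetical and used only for illustration.

    import torch

    def cache_artifact_names(name: str):
        # Return (assembly, binary) cache file names for the current backend,
        # mirroring the branch added in compile().
        if torch.version.hip is not None:
            # ROCm build of PyTorch: GCN assembly plus a loadable HIP module
            return f"{name}.gcn", f"{name}.hipmodule"
        # CUDA build: PTX assembly plus a cubin
        return f"{name}.ptx", f"{name}.cubin"

    assembly_name, binary_name = cache_artifact_names("matmul_kernel")
    print(assembly_name, binary_name)

Centralizing the two names behind backend-neutral variables (assembly_name, binary_name) is what lets the rest of the cache lookup and the fn_cache_manager.put calls stay identical for both backends.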