add prints

This commit is contained in:
Michael Melesse
2022-10-24 18:28:28 +00:00
parent 8da4323514
commit 9184b5cf65
3 changed files with 23 additions and 10 deletions

View File

@@ -3864,7 +3864,7 @@ void generator::visit_function(ir::function* fn) {
for(ir::basic_block *block: blocks)
visit_basic_block(block);
// finalize
std::cout << "\t// verifyFunction" << std::endl;
std::cout << "\t// finalize" << std::endl;
finalize_function(fn);
// verifyFunction

View File

@@ -109,7 +109,7 @@ def check_type_supported(dtype):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128

View File

@@ -1250,6 +1250,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
print("compiler.py: compile")
print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
# we get the kernel, i.e. the first function generated in the module
assert len(configs) == 1
# cache manager
@@ -1272,14 +1274,23 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
# retrieve cached shared object if it exists
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
if torch.version.hip is not None:
amdgpu_name = f"{name}.gcn"
hipmodule_name = f"{name}.hipmodule"
assembly_name = amdgpu_name
binary_name = hipmodule_name
else:
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
assembly_name = ptx_name
binary_name = cubin_name
data_name = f"{name}.json"
ttir_name = f"{name}.ttir"
llir_name = f"{name}.llir"
if not fn_cache_manager.has_file(cubin_name) or \
if not fn_cache_manager.has_file(binary_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name) or \
not fn_cache_manager.has_file(assembly_name) or \
not fn_cache_manager.has_file(ttir_name) or \
not fn_cache_manager.has_file(llir_name):
@@ -1287,14 +1298,14 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "hipmodule", cc)
# cache AMD assembly and binary
fn_cache_manager.put(asm["hipmodule"], cubin_name)
fn_cache_manager.put(asm["amdgpu"], ptx_name, binary=False)
fn_cache_manager.put(asm["hipmodule"], binary_name)
fn_cache_manager.put(asm["amdgpu"], assembly_name, binary=False)
else:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
# cache Nvidia assembly and binary
fn_cache_manager.put(asm["cubin"], cubin_name)
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
fn_cache_manager.put(asm["cubin"], binary_name)
fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)
# cache triton and llvm ir
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
@@ -1317,6 +1328,8 @@ class CompiledKernel:
launch_exit_hook = None
def __init__(self, fn_name, so_path, cache_dir, device):
print("compiler.py: CompiledKernel:__init__")
print(f"\t{self, fn_name, so_path, cache_dir, device}")
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)