add prints
@@ -3864,7 +3864,7 @@ void generator::visit_function(ir::function* fn) {
   for(ir::basic_block *block: blocks)
     visit_basic_block(block);
   // finalize
-  std::cout << "\t// verifyFunction" << std::endl;
+  std::cout << "\t// finalize" << std::endl;
   finalize_function(fn);
 
   // verifyFunction
@@ -109,7 +109,7 @@ def check_type_supported(dtype):
         pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
 
 
-@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
+@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
 
@@ -1250,6 +1250,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
 
 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
             num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
+    print("compiler.py: compile")
+    print(f"\t{fn, signature, device, constants, num_warps, num_stages, extern_libs, configs, cc, warm_cache_only}")
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
     # cache manager
@@ -1272,14 +1274,23 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     # retrieve cached shared object if it exists
     fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
     fn_cache_manager = CacheManager(fn_cache_key)
-    ptx_name = f"{name}.ptx"
-    cubin_name = f"{name}.cubin"
+    if torch.version.hip is not None:
+        amdgpu_name = f"{name}.gcn"
+        hipmodule_name = f"{name}.hipmodule"
+        assembly_name = amdgpu_name
+        binary_name = hipmodule_name
+    else:
+        ptx_name = f"{name}.ptx"
+        cubin_name = f"{name}.cubin"
+        assembly_name = ptx_name
+        binary_name = cubin_name
+
     data_name = f"{name}.json"
     ttir_name = f"{name}.ttir"
     llir_name = f"{name}.llir"
-    if not fn_cache_manager.has_file(cubin_name) or \
+    if not fn_cache_manager.has_file(binary_name) or \
        not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name) or \
+       not fn_cache_manager.has_file(assembly_name) or \
       not fn_cache_manager.has_file(ttir_name) or \
        not fn_cache_manager.has_file(llir_name):
 
@@ -1287,14 +1298,14 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
             asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                 extern_libs, "hipmodule", cc)
             # cache AMD assembly and binary
-            fn_cache_manager.put(asm["hipmodule"], cubin_name)
-            fn_cache_manager.put(asm["amdgpu"], ptx_name, binary=False)
+            fn_cache_manager.put(asm["hipmodule"], binary_name)
+            fn_cache_manager.put(asm["amdgpu"], assembly_name, binary=False)
         else:
             asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                                 extern_libs, "cubin", cc)
             # cache Nvidia assembly and binary
-            fn_cache_manager.put(asm["cubin"], cubin_name)
-            fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
+            fn_cache_manager.put(asm["cubin"], binary_name)
+            fn_cache_manager.put(asm["ptx"], assembly_name, binary=False)
 
         # cache triton and llvm ir
         fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
@@ -1317,6 +1328,8 @@ class CompiledKernel:
     launch_exit_hook = None
 
     def __init__(self, fn_name, so_path, cache_dir, device):
+        print("compiler.py: CompiledKernel:__init__")
+        print(f"\t{self, fn_name, so_path, cache_dir, device}")
        # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
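
For context on the last hunk: CompiledKernel.__init__ loads the generated launcher shared object through importlib, as the unchanged lines above show. Below is a minimal sketch of that loading pattern, assuming only that so_path points at an importable extension module; the helper name load_launcher and the module name "launcher" are illustrative, not part of Triton's API.

import importlib.util


def load_launcher(so_path):
    # Load a compiled shared object (or any importable module file) as a
    # Python module, mirroring the spec_from_file_location pattern above.
    spec = importlib.util.spec_from_file_location("launcher", so_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod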
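
The compile() hunks replace the hard-coded cubin_name/ptx_name cache entries with backend-dependent binary_name/assembly_name pairs (.gcn/.hipmodule on ROCm, .ptx/.cubin on CUDA). A minimal standalone sketch of that selection logic follows; the function artifact_names and the is_hip flag are hypothetical stand-ins for the torch.version.hip check, not the actual Triton implementation.

def artifact_names(name, is_hip):
    # Pick cache-file names for the assembly and binary artifacts depending
    # on the backend, as the commit does inside compile().
    if is_hip:
        assembly_name = f"{name}.gcn"        # AMD GCN assembly
        binary_name = f"{name}.hipmodule"    # HIP module binary
    else:
        assembly_name = f"{name}.ptx"        # NVIDIA PTX assembly
        binary_name = f"{name}.cubin"        # CUDA binary
    return assembly_name, binary_name


# Example:
#   artifact_names("kernel0", is_hip=True)  -> ("kernel0.gcn", "kernel0.hipmodule")
#   artifact_names("kernel0", is_hip=False) -> ("kernel0.ptx", "kernel0.cubin")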