[TESTING] Added infrastructure for executing TTGIR program and test for layout conversions (#885)

This commit is contained in:
Philippe Tillet
2022-11-18 07:46:45 +01:00
committed by GitHub
parent 9ea6135eb5
commit dab4855bdf
6 changed files with 243 additions and 67 deletions

View File

@@ -1028,7 +1028,7 @@ def binary_name_to_header_name(name):
return f"{name}.h"
def generate_launcher(identifier, constants, signature):
def generate_launcher(constants, signature):
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
@@ -1184,6 +1184,9 @@ class CacheManager:
def put(self, data, filename, binary=True):
if not self.cache_dir:
return
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
with FileLock(self.lock_path):
@@ -1296,16 +1299,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
return module, md5, True, False
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
# we get the kernel, i.e. the first function generated in the module
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
# cache manager
name = fn.__name__
#
def make_stub(name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(signature, constants)
so_cache_manager = CacheManager(so_cache_key)
@@ -1313,57 +1308,129 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
# retrieve stub from cache if it exists
if not so_cache_manager.has_file(so_name):
with tempfile.TemporaryDirectory() as tmpdir:
src = generate_launcher(name, constants, signature)
src = generate_launcher(constants, signature)
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(fn.__name__, src_path, tmpdir)
so = _build(name, src_path, tmpdir)
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
so_path = so_cache_manager._make_path(so_name)
return so_cache_manager._make_path(so_name)
def convert_type_repr(x):
match = re.search('!tt\.ptr<(.*)>', x)
if match is not None:
return '*' + convert_type_repr(match.group(1))
return x
def make_hash(fn, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
# has to be a path to a file
context = _triton.ir.context()
asm, md5 = dict(), dict()
constants = kwargs.get("constants", dict())
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
kwargs["configs"] = configs
name = fn.__name__
first_stage = 0
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
name, ir = os.path.basename(fn).split(".")
assert ir == "ttgir"
asm[ir] = _triton.ir.parse_mlir_module(fn, context)
function = asm[ir].get_single_function()
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = 2
# cache manager
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# initialize compilation params
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
# load metadata if any
metadata = None
if fn_cache_manager.has_file(f'{name}.json'):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
context = _triton.ir.context()
force_compile = False
# ast -> triton-ir (or read from cache)
ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
# triton-ir -> triton-gpu-ir (or read from cache)
ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
# triton-gpu-ir -> llvm-ir (or read from cache)
llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
if llvm_cached:
shmem_size = metadata["shared"]
else:
shmem_size = _triton.get_shared_memory_size(ttgir)
# llvm-ir -> ptx (or read from cache)
ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
run_if_found = lambda path: Path(path).read_text(),
run_if_not_found = lambda: llir_to_ptx(llir))
# ptx -> cubin (or read from cache)
cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found= lambda: ptx_to_cubin(ptx, device))
# dump new metadata
kernel_name = ptx_get_kernel_name(ptx)
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
"md5": { "cubin": cubin_md5, "ptx": ptx_md5,
"llir": llir_md5,
"ttir": ttir_md5, "ttgir": ttgir_md5 }}
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
# build compilation stages
stages = {
"ast" : (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs)),
"ptx": (lambda path: Path(path).read_text(),
llir_to_ptx),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, device))
}
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
# return handle to compiled kernel
return CompiledKernel(so_path, metadata, asm)
@@ -1395,7 +1462,7 @@ class CompiledKernel:
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
return runner
def get_sass(self, fun=None):
if 'sass' in self.asm: