[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)
Also removes a couple of tests that relied on the now-deleted triton.compiler._compile helper
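With this change each compilation stage is cached as its own artifact ({name}.ttir, {name}.ttgir, {name}.llir, {name}.ptx, {name}.cubin) and {name}.json records an md5 per stage. On the next compile, read_or_execute reloads each cached file; if its md5 no longer matches the recorded one (for instance because the user hand-patched the PTX), every later stage is rebuilt from the patched artifact. The snippet below is a minimal, self-contained sketch of that md5-based staleness check; it mirrors the logic of read_or_execute but uses made-up file and function names rather than Triton's actual API.

import hashlib
import json
from pathlib import Path

def load_stage(cache_dir: Path, file_name: str, metadata: dict, rebuild):
    """Return (artifact_bytes, md5, was_patched). Sketch only; not Triton's API."""
    path = cache_dir / file_name
    suffix = file_name.split(".")[1]
    if path.exists():
        data = path.read_bytes()
        md5 = hashlib.md5(data).hexdigest()
        # An md5 mismatch means the cached file was edited by hand, so the
        # caller must rebuild every downstream stage from this patched artifact.
        patched = bool(metadata) and md5 != metadata.get("md5", {}).get(suffix)
        return data, md5, patched
    data = rebuild()
    path.write_bytes(data)
    return data, hashlib.md5(data).hexdigest(), False

if __name__ == "__main__":
    cache = Path("/tmp/demo_triton_cache")
    cache.mkdir(parents=True, exist_ok=True)
    meta_path = cache / "kernel.json"
    metadata = json.loads(meta_path.read_text()) if meta_path.exists() else {}
    ptx, ptx_md5, patched = load_stage(cache, "kernel.ptx", metadata,
                                       rebuild=lambda: b"// freshly generated PTX\n")
    print("kernel.ptx was hand-patched:", patched)
    meta_path.write_text(json.dumps({"md5": {"ptx": ptx_md5}}))

Run it twice, edit /tmp/demo_triton_cache/kernel.ptx by hand, and the next run reports the patch; that is the same signal compile() below uses to decide whether to regenerate the cubin from the edited PTX.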
@@ -261,7 +261,7 @@ void init_triton_ir(py::module &&m) {
            },
            ret::reference)
       .def("dump", [](mlir::OpState &self) { self->dump(); })
-      .def("str",
+      .def("__str__",
            [](mlir::OpState &self) -> std::string {
              std::string str;
              llvm::raw_string_ostream os(str);
@@ -1280,8 +1280,8 @@ void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;

   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
-        .getInt();
+    auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
+    return shared.getInt();
   });

   m.def(
@@ -1,33 +0,0 @@
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
-    offsets = tl.arange(0, BLOCK_SIZE)
-    x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
-    x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
-    x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
-    x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
-
-    y1 = tl.sin(x1)
-    y2 = tl.libdevice.sin(x2)
-    y3 = tl.libdevice.div_rn(x3, x3)
-    y4 = tl.libdevice.fma_rd(x4, x4, x4)
-
-    tl.store(x1_ptr + offsets, y1, mask=offsets < n)
-    tl.store(x2_ptr + offsets, y2, mask=offsets < n)
-    tl.store(x3_ptr + offsets, y3, mask=offsets < n)
-    tl.store(x4_ptr + offsets, y4, mask=offsets < n)
-
-
-def test_empty_kernel_cubin_compile():
-    kernel = triton.compiler._compile(math_kernel,
-                                      "*fp32,*fp32,*fp32,*fp32,i32",
-                                      device=0,
-                                      constants={"BLOCK_SIZE": 256},
-                                      output="ttgir")  # "cubin"
-    assert kernel
-    # TODO: Check if the values are correct.
-    # TODO: Cover all the math operators
@@ -1,80 +0,0 @@
-import triton
-import triton.language as tl
-
-
-# TODO: function with no arguments don't work
-@triton.jit
-def binop_type_check(X):
-    # 0d-tensor is not allowed.
-    # zero_0d = tl.zeros([], dtype=tl.float32)
-    zero_1d = tl.zeros([2], dtype=tl.float32)
-    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
-    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
-
-    # scalar + scalar -> scalar
-    a0 = 0.0 + 0.0
-    # # scalar + 0D -> 0D
-    # a1 = 0.0 + zero_0d
-    # a2 = zero_0d + 0.0
-    # scalar + 1D -> 1D
-    a3 = 0.0 + zero_1d
-    a4 = zero_1d + 0.0
-    # scalar + 2D -> 2D
-    a5 = 0.0 + zero_2d_22
-    a6 = zero_2d_22 + 0.0
-
-    # # 0D + 0D -> 0D
-    # b1 = zero_0d + zero_0d
-    # # 0D + 1D -> 1D
-    # b2 = zero_0d + zero_1d
-    # b3 = zero_1d + zero_0d
-    # # 0D + 2D -> 2D
-    # b4 = zero_0d + zero_2d_22
-    # b5 = zero_2d_22 + zero_0d
-
-    # 1D + 1D -> 1D
-    c1 = zero_1d + zero_1d
-    # 1D + 2D -> 2D
-    c2 = zero_1d + zero_2d_21
-    c3 = zero_1d + zero_2d_22
-    c4 = zero_2d_21 + zero_1d
-    c5 = zero_2d_22 + zero_1d
-
-    # 2D + 2D -> 2D
-    d1 = zero_2d_21 + zero_2d_21
-    d2 = zero_2d_22 + zero_2d_22
-    d3 = zero_2d_21 + zero_2d_22
-    d4 = zero_2d_22 + zero_2d_21
-
-    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
-    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
-
-
-def test_binop_type_check():
-    kernel = triton.compiler._compile(binop_type_check,
-                                      signature="*fp32",
-                                      device=0,
-                                      output="ttir")
-    assert (kernel)
-    # TODO: Check types of the results
-
-
-@triton.jit
-def reduce_type_check(ptr):
-    v_32 = tl.load(ptr + tl.arange(0, 32))
-    v_scalar = tl.min(v_32, axis=0)
-    tl.store(ptr, v_scalar)
-    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
-    v_64 = tl.max(v_64x128, axis=1)
-    tl.store(ptr + tl.arange(0, 64), v_64)
-    v_128 = tl.max(v_64x128, axis=0)
-    tl.store(ptr + tl.arange(0, 128), v_128)
-
-
-def test_reduce_type_check():
-    kernel = triton.compiler._compile(reduce_type_check,
-                                      signature="*fp32",
-                                      device=0,
-                                      output="ttir")
-    assert (kernel)
-    # TODO: Check types of the results
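Note: the two deleted test files above exercised the removed triton.compiler._compile(..., output=...) helper. A rough equivalent under the staged entry points introduced by this PR is sketched below; the call signatures are taken from this diff, the kernel and argument values are placeholders, and the sketch is not guaranteed to run unchanged against other Triton revisions.

import triton
import triton.language as tl

@triton.jit
def dummy_kernel(x_ptr):
    # placeholder kernel, only here to exercise the frontend
    tl.store(x_ptr + tl.arange(0, 16), tl.zeros([16], dtype=tl.float32))

# Roughly what the deleted tests' _compile(..., output="ttir") call becomes:
ttir = triton.compiler.ast_to_ttir(dummy_kernel, signature="*fp32",
                                   specialization=triton.compiler.instance_descriptor(),
                                   constants={})
print(ttir)  # str(module) should work given the __str__ binding added above; module.str() also exists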
@@ -16,8 +16,9 @@ import tempfile
 import warnings
 from collections import namedtuple
 from sysconfig import get_paths
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple, Union

+from pathlib import Path
 import setuptools
 import torch
 from filelock import FileLock
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
 # ------------------------------------------------------------------------------


-def make_triton_ir(fn, signature, specialization, constants):
+def build_triton_ir(fn, signature, specialization, constants):
+    # canonicalize signature
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
     context = _triton.ir.context()
     context.load_triton()
     # create kernel prototype
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
     ret.context = context
     return ret, generator

-
 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
     pm.run(mod)
     return mod

+def ast_to_ttir(fn, signature, specialization, constants):
+    mod, _ = build_triton_ir(fn, signature, specialization, constants)
+    return optimize_triton_ir(mod)

-def make_tritongpu_ir(mod, num_warps):
+def ttir_to_ttgir(mod, num_warps, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
-    pm.run(mod)
-    return mod
-
-
-def optimize_tritongpu_ir(mod, num_stages):
-    pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
     _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))


-def make_llvm_ir(mod):
+def ttgir_to_llir(mod, extern_libs):
+    if extern_libs:
+        add_external_libs(mod, extern_libs)
     return _triton.translate_triton_gpu_to_llvmir(mod)


-def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
         - PTX code
         - shared memory alloaction size
     '''
+    if compute_capability is None:
+        device = torch.cuda.current_device()
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = compute_capability[0] * 10 + compute_capability[1]
+    if ptx_version is None:
+        _, cuda_version = path_to_ptxas()
+        ptx_version = ptx_get_version(cuda_version)
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)


-def make_cubin(ptx: str, ptxas: str, compute_capability: int):
+def ptx_to_cubin(ptx: str, device: int):
     '''
     Compile TritonGPU module to cubin.
     :param ptx: ptx code
     :param device: CUDA device
     :return: str
     '''
+    ptxas, _ = path_to_ptxas()
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)

@@ -978,46 +991,6 @@ def path_to_ptxas():
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])


-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
-    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
-
-    # triton-ir
-    module, _ = make_triton_ir(fn, signature, specialization, constants)
-    module = optimize_triton_ir(module)
-    if output == "ttir":
-        return module.str()
-
-    # tritongpu-ir
-    module = make_tritongpu_ir(module, num_warps)
-    module = optimize_tritongpu_ir(module, num_stages)
-    if output == "ttgir":
-        return module.str()
-
-    if extern_libs:
-        add_external_libs(module, extern_libs)
-
-    # llvm-ir
-    llvm_ir = make_llvm_ir(module)
-
-    assert device >= 0, "device should be provided."
-    ptxas, cuda_version = path_to_ptxas()
-    compute_capability = torch.cuda.get_device_capability(device)
-    compute_capability = compute_capability[0] * 10 + compute_capability[1]
-    ptx_version = ptx_get_version(cuda_version)
-    ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
-    shem_size = _triton.get_shared_memory_size(module)
-    kernel_name = ptx_get_kernel_name(ptx)
-    if output == "ptx":
-        return ptx, shem_size, kernel_name
-
-    cubin = make_cubin(ptx, ptxas, compute_capability)
-    if output == "cubin":
-        return cubin, ptx, shem_size, kernel_name
-
-    assert False


 # ------------------------------------------------------------------------------
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key


+def read_or_execute(cache_manager, force_compile, file_name, metadata,
+                    run_if_found: Callable[[str], bytes] = None,
+                    run_if_not_found: Callable = None):
+    if not force_compile and cache_manager.has_file(file_name):
+        module = run_if_found(cache_manager._make_path(file_name))
+        data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+        md5 = hashlib.md5(data).hexdigest()
+        suffix = file_name.split(".")[1]
+        has_changed = metadata and md5 != metadata["md5"][suffix]
+        return module, md5, has_changed, True
+    module = run_if_not_found()
+    data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+    md5 = hashlib.md5(data).hexdigest()
+    cache_manager.put(data, file_name, True)
+    return module, md5, True, False
+
+
 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
     if isinstance(signature, str):
         signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
             so = _build(fn.__name__, src_path, tmpdir)
             with open(so, "rb") as f:
                 so_cache_manager.put(f.read(), so_name, binary=True)
+    so_path = so_cache_manager._make_path(so_name)
-    # retrieve cached shared object if it exists
+    # create cache manager
     fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
     fn_cache_manager = CacheManager(fn_cache_key)
-    ptx_name = f"{name}.ptx"
-    cubin_name = f"{name}.cubin"
-    data_name = f"{name}.json"
-    if not fn_cache_manager.has_file(cubin_name) or \
-       not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name):
-        cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
-        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
-        fn_cache_manager.put(cubin, cubin_name)
-        fn_cache_manager.put(ptx, ptx_name, binary=False)
-        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
+    # load metadata if any
+    metadata = None
+    if fn_cache_manager.has_file(f'{name}.json'):
+        with open(fn_cache_manager._make_path(f"{name}.json")) as f:
+            metadata = json.load(f)
+    context = _triton.ir.context()
+    force_compile = False
+    # ast -> triton-ir (or read from cache)
+    ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
+                                                       run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                       run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
+    # triton-ir -> triton-gpu-ir (or read from cache)
+    ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
+                                                         run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                         run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
+    # triton-gpu-ir -> llvm-ir (or read from cache)
+    llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
+                                                                 run_if_found = lambda path: Path(path).read_bytes(),
+                                                                 run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
+    if llvm_cached:
+        shmem_size = metadata["shared"]
+    else:
+        shmem_size = _triton.get_shared_memory_size(ttgir)
+    # llvm-ir -> ptx (or read from cache)
+    ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
+                                                     run_if_found = lambda path: Path(path).read_text(),
+                                                     run_if_not_found = lambda: llir_to_ptx(llir))
+    # ptx -> cubin (or read from cache)
+    cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
+                                                         run_if_found = lambda path: Path(path).read_bytes(),
+                                                         run_if_not_found = lambda: ptx_to_cubin(ptx, device))
+    # dump new metadata
+    kernel_name = ptx_get_kernel_name(ptx)
+    metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
+                "md5": {"cubin": cubin_md5, "ptx": ptx_md5,
+                        "llir": llir_md5,
+                        "ttir": ttir_md5, "ttgir": ttgir_md5}}
+    fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)

-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
+    return CompiledKernel(so_path, metadata, asm)


 class CompiledKernel:

-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, so_path, metadata, asm):

         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1359,18 +1376,11 @@ class CompiledKernel:
         spec.loader.exec_module(mod)
         self.c_wrapper = getattr(mod, "launch")
         # initialize metadata
-        with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
-            metadata = json.load(f)
         self.shared = metadata["shared"]
         self.num_warps = metadata["num_warps"]
         self.num_stages = metadata["num_stages"]
         # initialize asm dict
-        self.asm = dict()
-        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
-            self.asm["cubin"] = f.read()
-        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
-            self.asm["ptx"] = f.read()
+        self.asm = asm

         device = torch.cuda.current_device()
         global cuda_utils
         if cuda_utils is None:
@@ -38,14 +38,13 @@ if __name__ == '__main__':
        exit(0)

    # triton-ir -> triton-gpu-ir
-    module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
-    module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
    if args.target == 'triton-gpu-ir':
        print(module.str())
        exit(0)

    # triton-gpu-ir -> llvm-ir
-    module = triton.compiler.make_llvm_ir(module)
+    module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
    if args.target == 'llvm-ir':
        print(module)
        exit(0)
@@ -56,6 +55,6 @@ if __name__ == '__main__':
        raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")

    # llvm-ir -> ptx
-    module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
+    module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
    assert args.target == 'ptx'
    print(module)