[TESTING] Added infrastructure for executing TTGIR program and test for layout conversions (#885)
@@ -1399,37 +1399,35 @@ struct BroadcastOpConversion
     Value result = op.result();
     auto srcTy = op.src().getType().cast<RankedTensorType>();
     auto resultTy = result.getType().cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    assert(srcLayout && (srcLayout == resultLayout) &&
-           "Unexpected layout of BroadcastOp");
+    auto srcLayout = srcTy.getEncoding();
+    auto resultLayout = resultTy.getEncoding();
     auto srcShape = srcTy.getShape();
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
-    auto order = srcLayout.getOrder();
+    auto order = triton::gpu::getOrder(srcLayout);

     SmallVector<int64_t> srcLogicalShape(2 * rank);
     SmallVector<unsigned> srcLogicalOrder(2 * rank);
     SmallVector<int64_t> resultLogicalShape(2 * rank);
     SmallVector<unsigned> broadcastDims;
     for (unsigned d = 0; d < rank; ++d) {
-      unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
-                                   resultLayout.getThreadsPerWarp()[d] *
-                                   resultLayout.getWarpsPerCTA()[d];
+      unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
+                                   triton::gpu::getThreadsPerWarp(resultLayout)[d] *
+                                   triton::gpu::getWarpsPerCTA(resultLayout)[d];
       int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
       if (srcShape[d] != resultShape[d]) {
        assert(srcShape[d] == 1);
        broadcastDims.push_back(d);
        srcLogicalShape[d] = 1;
        srcLogicalShape[d + rank] =
-            std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
+            std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
       } else {
        srcLogicalShape[d] = numCtas;
-        srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+        srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
       }
       resultLogicalShape[d] = numCtas;
-      resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+      resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];

       srcLogicalOrder[d] = order[d] + rank;
       srcLogicalOrder[d + rank] = order[d];
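Note (illustrative, not part of this commit): the loop above splits every tensor dimension d into a CTA-tile count (stored at index d) and a per-thread size (stored at index d + rank). A rough Python sketch of that bookkeeping, with made-up layout parameters and assuming source and result share the same blocked layout:

    from math import ceil

    # hypothetical 2-D blocked layout (illustration only, not taken from the commit)
    rank = 2
    size_per_thread = [1, 4]
    threads_per_warp = [4, 8]
    warps_per_cta = [2, 2]
    src_shape, result_shape = (128, 1), (128, 128)   # broadcast along dim 1

    src_logical = [0] * (2 * rank)
    result_logical = [0] * (2 * rank)
    for d in range(rank):
        per_cta = size_per_thread[d] * threads_per_warp[d] * warps_per_cta[d]
        num_ctas = ceil(result_shape[d] / per_cta)
        if src_shape[d] != result_shape[d]:              # broadcast dimension
            src_logical[d], src_logical[d + rank] = 1, max(1, size_per_thread[d])
        else:
            src_logical[d], src_logical[d + rank] = num_ctas, size_per_thread[d]
        result_logical[d], result_logical[d + rank] = num_ctas, size_per_thread[d]

    print(src_logical)      # [16, 1, 1, 4]
    print(result_logical)   # [16, 2, 1, 4]
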
@@ -163,7 +163,19 @@ void init_triton_ir(py::module &&m) {

   py::class_<mlir::Type>(m, "type")
       .def("is_integer", &mlir::Type::isInteger)
-      .def("is_fp16", &mlir::Type::isF16);
+      .def("is_fp16", &mlir::Type::isF16)
+      .def("__str__", [](mlir::Type &self) {
+        std::string str;
+        llvm::raw_string_ostream os(str);
+        self.print(os);
+        return os.str();
+      });
+
+  py::class_<mlir::FunctionType>(m, "function_type")
+      .def("param_types", [](mlir::FunctionType &self) {
+        return std::vector<mlir::Type>(self.getInputs().begin(),
+                                       self.getInputs().end());
+      });

   py::class_<mlir::Value>(m, "value")
       .def("set_attr",
@@ -314,6 +326,13 @@ void init_triton_ir(py::module &&m) {
       .def("get_function",
            [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
              return self.lookupSymbol<mlir::FuncOp>(funcName);
            })
+      .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
+        llvm::SmallVector<mlir::FuncOp> funcs;
+        self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
+        if (funcs.size() != 1)
+          throw std::runtime_error("Expected a single function");
+        return funcs[0];
+      });

   m.def(
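Note (illustrative, not part of this commit): together with the __str__/param_types bindings added above, get_single_function is what lets the Python frontend recover a kernel signature from a standalone TTGIR file, mirroring the compile() changes later in this commit. A minimal sketch; the file path is hypothetical:

    import triton._C.libtriton.triton as _triton

    path = "kernel.ttgir"                       # hypothetical TTGIR file with exactly one function
    context = _triton.ir.context()
    module = _triton.ir.parse_mlir_module(path, context)
    fn = module.get_single_function()           # raises if the module holds != 1 function
    param_tys = [str(ty) for ty in fn.type.param_types()]
    print(param_tys)                            # e.g. ['!tt.ptr<f16>', '!tt.ptr<f16>']
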
@@ -363,6 +382,7 @@ void init_triton_ir(py::module &&m) {
             self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
           },
           ret::reference)
       .def_property_readonly("type", &mlir::FuncOp::getType)
+      .def("reset_type", &mlir::FuncOp::setType);

   py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -1274,8 +1294,8 @@ void init_triton_ir(py::module &&m) {
 void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;

-  m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
+  m.def("get_shared_memory_size", [](mlir::ModuleOp mod) {
+    auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
     return shared.getInt();
   });

python/tests/test_backend.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import triton
import triton.language as tl
import torch
import pytest
from .test_core import numpy_random, to_triton


class MmaLayout:
    def __init__(self, version, warps_per_cta):
        self.version = version
        self.warps_per_cta = str(warps_per_cta)

    def __str__(self):
        return f"#triton_gpu.mma<{{version={self.version}, warpsPerCTA={self.warps_per_cta}}}>"


class BlockedLayout:
    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        self.sz_per_thread = str(size_per_thread)
        self.threads_per_warp = str(threads_per_warp)
        self.warps_per_cta = str(warps_per_cta)
        self.order = str(order)

    def __str__(self):
        return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"


layouts = [
    # MmaLayout(version=1, warps_per_cta=[1, 4]),
    MmaLayout(version=2, warps_per_cta=[1, 4]),
    # MmaLayout(version=1, warps_per_cta=[4, 1]),
    MmaLayout(version=2, warps_per_cta=[4, 1]),
    BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
    BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
    BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
    BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
    BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1])
]


@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
    if str(src_layout) == str(dst_layout):
        pytest.skip()
    if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
        pytest.skip()

    ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<128> : tensor<128x1xi32, #src>
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
    %2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
    %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
    %5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
    %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
    %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
    %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
    %9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
    %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
    %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
    %3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
    %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
    %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
    %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
    tt.store %14, %13 : tensor<128x128xf16, #dst>
    return
  }
}
"""

    x = to_triton(numpy_random(shape, dtype_str=dtype))
    z = torch.empty_like(x)

    # write the IR to a temporary file using mkstemp
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)
        kernel[(1,1,1)](x.data_ptr(), z.data_ptr())

    assert torch.equal(z, x)
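Note (illustrative, not part of this commit): the helper classes above only exist to render TTGIR layout attributes, which the test splices into the IR template as #src and #dst. For example:

    print(BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]))
    # #triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[4, 8], warpsPerCTA=[2, 2], order=[1, 0]}>
    print(MmaLayout(version=2, warps_per_cta=[1, 4]))
    # #triton_gpu.mma<{version=2, warpsPerCTA=[1, 4]}>
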
@@ -16,7 +16,7 @@ def test_empty_kernel_cubin_compile():

     device = torch.cuda.current_device()
     kernel = triton.compile(empty_kernel,
-                            "*fp32,i32,i32",
+                            signature="*fp32,i32,i32",
                             device=device,
                             constants={"BLOCK": 256})

@@ -1028,7 +1028,7 @@ def binary_name_to_header_name(name):
     return f"{name}.h"


-def generate_launcher(identifier, constants, signature):
+def generate_launcher(constants, signature):
     arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

     def _extracted_type(ty):
@@ -1184,6 +1184,9 @@ class CacheManager:
     def put(self, data, filename, binary=True):
         if not self.cache_dir:
             return
+        binary = isinstance(data, bytes)
+        if not binary:
+            data = str(data)
         assert self.lock_path is not None
         filepath = self._make_path(filename)
         with FileLock(self.lock_path):
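Note (illustrative, not part of this commit): with this change CacheManager.put infers the write mode from the data itself, which is what lets the new compile() loop below cache arbitrary stage outputs without caring about their type. A hedged sketch where cache, cubin_bytes and ttgir_module are placeholders:

    cache.put(cubin_bytes, "kernel.cubin")    # bytes         -> written as binary
    cache.put(ttgir_module, "kernel.ttgir")   # anything else -> str(...)-ified and written as text
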
@@ -1296,16 +1299,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
         cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
     return module, md5, True, False


-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-    # we get the kernel, i.e. the first function generated in the module
-    if configs is None:
-        configs = [instance_descriptor()]
-    assert len(configs) == 1
-    # cache manager
-    name = fn.__name__
-    #
+def make_stub(name, signature, constants):
     # name of files that are cached
     so_cache_key = make_so_cache_key(signature, constants)
     so_cache_manager = CacheManager(so_cache_key)
@@ -1313,57 +1308,129 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     # retrieve stub from cache if it exists
     if not so_cache_manager.has_file(so_name):
         with tempfile.TemporaryDirectory() as tmpdir:
-            src = generate_launcher(name, constants, signature)
+            src = generate_launcher(constants, signature)
             src_path = os.path.join(tmpdir, "main.c")
             with open(src_path, "w") as f:
                 f.write(src)
-            so = _build(fn.__name__, src_path, tmpdir)
+            so = _build(name, src_path, tmpdir)
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)
-    so_path = so_cache_manager._make_path(so_name)
+    return so_cache_manager._make_path(so_name)


+def convert_type_repr(x):
+    match = re.search('!tt\.ptr<(.*)>', x)
+    if match is not None:
+        return '*' + convert_type_repr(match.group(1))
+    return x

+def make_hash(fn, **kwargs):
+    if isinstance(fn, triton.runtime.JITFunction):
+        configs = kwargs["configs"]
+        signature = kwargs["signature"]
+        constants = kwargs.get("constants", dict())
+        num_warps = kwargs.get("num_warps", 4)
+        num_stages = kwargs.get("num_stages", 3)
+        # Get unique key for the compiled code
+        get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
+        configs_key = [get_conf_key(conf) for conf in configs]
+        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
+        return hashlib.md5(key.encode("utf-8")).hexdigest()
+    assert isinstance(fn, str)
+    return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()



+# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+def compile(fn, **kwargs):
+    # we get the kernel, i.e. the first function generated in the module
+    # if fn is not a JITFunction, then it
+    # has to be a path to a file
+    context = _triton.ir.context()
+    asm, md5 = dict(), dict()
+    constants = kwargs.get("constants", dict())
+    if isinstance(fn, triton.runtime.JITFunction):
+        configs = kwargs.get("configs", None)
+        signature = kwargs["signature"]
+        if configs is None:
+            configs = [instance_descriptor()]
+        assert len(configs) == 1
+        kwargs["configs"] = configs
+        name = fn.__name__
+        first_stage = 0
+        if isinstance(signature, str):
+            signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
+        kwargs["signature"] = signature
+    else:
+        assert isinstance(fn, str)
+        name, ir = os.path.basename(fn).split(".")
+        assert ir == "ttgir"
+        asm[ir] = _triton.ir.parse_mlir_module(fn, context)
+        function = asm[ir].get_single_function()
+        param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
+        signature = {k: v for k, v in enumerate(param_tys)}
+        first_stage = 2

+    # cache manager
+    so_path = make_stub(name, signature, constants)
     # create cache manager
-    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
-    fn_cache_manager = CacheManager(fn_cache_key)
+    fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
+    # determine name and extension type of provided function
+    if isinstance(fn, triton.runtime.JITFunction):
+        name, ext = fn.__name__, "ast"
+    else:
+        name, ext = os.path.basename(fn).split(".")
+    # initialize compilation params
+    num_warps = kwargs.get("num_warps", 4)
+    num_stages = kwargs.get("num_stages", 3)
+    extern_libs = kwargs.get("extern_libs", dict())
+    device = kwargs.get("device", torch.cuda.current_device())
     # load metadata if any
     metadata = None
     if fn_cache_manager.has_file(f'{name}.json'):
         with open(fn_cache_manager._make_path(f"{name}.json")) as f:
             metadata = json.load(f)
-    context = _triton.ir.context()
-    force_compile = False
-    # ast -> triton-ir (or read from cache)
-    ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
-                                                       run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
-                                                       run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
-    # triton-ir -> triton-gpu-ir (or read from cache)
-    ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
-                                                         run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
-                                                         run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
-    # triton-gpu-ir -> llvm-ir (or read from cache)
-    llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
-                                                                 run_if_found = lambda path: Path(path).read_bytes(),
-                                                                 run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
-    if llvm_cached:
-        shmem_size = metadata["shared"]
-    else:
-        shmem_size = _triton.get_shared_memory_size(ttgir)
-    # llvm-ir -> ptx (or read from cache)
-    ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
-                                                     run_if_found = lambda path: Path(path).read_text(),
-                                                     run_if_not_found = lambda: llir_to_ptx(llir))
-    # ptx -> cubin (or read from cache)
-    cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
-                                                         run_if_found = lambda path: Path(path).read_bytes(),
-                                                         run_if_not_found= lambda: ptx_to_cubin(ptx, device))
-    # dump new metadata
-    kernel_name = ptx_get_kernel_name(ptx)
-    metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
-                "md5": { "cubin": cubin_md5, "ptx": ptx_md5,
-                         "llir": llir_md5,
-                         "ttir": ttir_md5, "ttgir": ttgir_md5 }}
+    metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
+    # build compilation stages
+    stages = {
+        "ast" : (lambda path: fn, None),
+        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
+        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
+        "llir": (lambda path: Path(path).read_bytes(),
+                 lambda src: ttgir_to_llir(src, extern_libs)),
+        "ptx": (lambda path: Path(path).read_text(),
+                llir_to_ptx),
+        "cubin": (lambda path: Path(path).read_bytes(),
+                  lambda src: ptx_to_cubin(src, device))
+    }
+    first_stage = list(stages.keys()).index(ext)
+    asm = dict()
+    module = fn
+    # run compilation pipeline and populate metadata
+    for ir, (parse, compile) in list(stages.items())[first_stage:]:
+        path = fn_cache_manager._make_path(f"{name}.{ir}")
+        if ir == ext:
+            next_module = parse(fn)
+        elif os.path.exists(path) and\
+                os.path.getctime(path) == metadata["ctime"][ir]:
+            next_module = parse(path)
+        else:
+            next_module = compile(module)
+            fn_cache_manager.put(next_module, f"{name}.{ir}")
+        if os.path.exists(path):
+            metadata["ctime"][ir] = os.path.getctime(path)
+        asm[ir] = next_module if ir == "cubin" else str(next_module)
+        if ir == "llir" and "shared" not in metadata:
+            metadata["shared"] = _triton.get_shared_memory_size(module)
+        if ir == "ptx":
+            metadata["name"] = ptx_get_kernel_name(next_module)
+        module = next_module
+    # write-back metadata
     fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)

-    asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
     # return handle to compiled kernel
     return CompiledKernel(so_path, metadata, asm)

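Note (illustrative, not part of this commit): the stages dict plus first_stage = list(stages.keys()).index(ext) is what makes a .ttgir input enter the pipeline halfway through. A standalone sketch of just that selection logic, with the stage bodies stubbed out and a callable check standing in for the isinstance(fn, triton.runtime.JITFunction) test:

    STAGES = ["ast", "ttir", "ttgir", "llir", "ptx", "cubin"]

    def stages_to_run(fn_or_path):
        # JIT functions enter at "ast"; a file enters at the stage named by its extension
        ext = "ast" if callable(fn_or_path) else fn_or_path.rsplit(".", 1)[1]
        return STAGES[STAGES.index(ext):]

    print(stages_to_run("kernel.ttgir"))   # ['ttgir', 'llir', 'ptx', 'cubin']
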
@@ -1395,7 +1462,7 @@ class CompiledKernel:
             if stream is None:
                 stream = torch.cuda.current_stream().cuda_stream
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
             return
         return runner

     def get_sass(self, fun=None):
         if 'sass' in self.asm:
@@ -275,7 +275,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
             raise TypeError(f"Callable constexpr at index {i} is not supported")
     device = 0
     if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
-        bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
+        bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
         if not warmup:
             bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
         self.cache[key] = bin
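Note (illustrative, not part of this commit): after this change compile() is called with keyword arguments only, and it additionally accepts a path to a .ttgir file, in which case the signature is recovered from the parsed module instead of being passed in. A hedged usage sketch, where empty_kernel, device, x, z and the file name are placeholders drawn from the tests above:

    import triton

    # JIT path: same pipeline as before, keyword arguments only
    binary = triton.compile(empty_kernel, signature="*fp32,i32,i32",
                            device=device, constants={"BLOCK": 256})

    # TTGIR path (new): compilation starts at the ttgir stage
    kernel = triton.compile("kernel.ttgir")
    kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())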