[FRONTEND] Backport new runtime from master
(#706)
This PR merges the new runtime back into the `triton-mlir` branch. This adds caching and just-in-time compilation functionality to the triton-mlir project, and paves the way for re-using tests from the master branch.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -7,4 +9,5 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
|
||||
pass
|
||||
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
|
||||
X = torch.randn(1, device="cuda")
|
||||
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
|
||||
|
@@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
/*****************************************************************************/
|
||||
typedef std::map<std::string, py::object> asm_map_t;
|
||||
|
||||
// ---------------------------------------
|
||||
// Load provided assembly code into driver
|
||||
// ---------------------------------------
|
||||
|
||||
// CUDA
|
||||
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
|
||||
asm_map_t &asm_map,
|
||||
size_t n_shared_bytes,
|
||||
uint64_t dev) {
|
||||
// load assembly
|
||||
std::string assembly;
|
||||
if (asm_map.find("cubin") != asm_map.end())
|
||||
assembly = py::cast<std::string>(asm_map["cubin"]);
|
||||
else
|
||||
assembly = py::cast<std::string>(asm_map["ptx"]);
|
||||
// create driver handles
|
||||
CUfunction fun;
|
||||
CUmodule mod;
|
||||
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
|
||||
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
drv::dispatch::cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
dev);
|
||||
if (n_shared_bytes > 49152 && shared_optin > 49152) {
|
||||
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
||||
int shared_total, shared_static;
|
||||
int n_spills, n_reg;
|
||||
drv::dispatch::cuDeviceGetAttribute(
|
||||
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
||||
dev);
|
||||
drv::dispatch::cuFuncGetAttribute(&shared_static,
|
||||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
||||
drv::dispatch::cuFuncGetAttribute(&n_spills,
|
||||
CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
||||
drv::dispatch::cuFuncSetAttribute(
|
||||
fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static);
|
||||
}
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
/*****************************************************************************/
|
||||
@@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) {
|
||||
|
||||
m.def(
|
||||
"load_binary",
|
||||
[](backend_t backend, const std::string &name, asm_map_t &asm_map,
|
||||
size_t n_shared_bytes, uint64_t dev) {
|
||||
[](const std::string &name, const std::string &data,
|
||||
size_t n_shared_bytes, uint64_t device) {
|
||||
py::gil_scoped_release allow_threads;
|
||||
assert(backend == CUDA); // Only CUDA is supported now.
|
||||
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
// create driver handles
|
||||
CUfunction fun;
|
||||
CUmodule mod;
|
||||
drv::dispatch::cuModuleLoadData(&mod, data.c_str());
|
||||
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
||||
// get allocated registers and spilled registers from the function
|
||||
int n_regs = 0;
|
||||
int n_spills = 0;
|
||||
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
|
||||
fun);
|
||||
drv::dispatch::cuFuncGetAttribute(
|
||||
&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||
n_spills /= 4;
|
||||
// set dynamic shared memory if necessary
|
||||
int shared_optin;
|
||||
drv::dispatch::cuDeviceGetAttribute(
|
||||
&shared_optin,
|
||||
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
|
||||
if (n_shared_bytes > 49152 && shared_optin > 49152) {
|
||||
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
||||
int shared_total, shared_static;
|
||||
drv::dispatch::cuDeviceGetAttribute(
|
||||
&shared_total,
|
||||
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
|
||||
drv::dispatch::cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
||||
drv::dispatch::cuFuncSetAttribute(
|
||||
fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static);
|
||||
}
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
|
||||
(uint64_t)n_spills);
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
@@ -2,8 +2,9 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# TODO: function with no arguments don't work
|
||||
@triton.jit
|
||||
def cast_check():
|
||||
def cast_check(X):
|
||||
zero_0d = tl.zeros([], dtype=tl.float32)
|
||||
zero_1d = tl.zeros([2], dtype=tl.float32)
|
||||
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
|
||||
@@ -48,9 +49,9 @@ def cast_check():
|
||||
|
||||
|
||||
def test_cast_check():
|
||||
kernel = triton.compile(cast_check,
|
||||
signature="",
|
||||
device=0,
|
||||
output="ttir")
|
||||
kernel = triton.compiler._compile(cast_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttgir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
||||
|
@@ -2,7 +2,6 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.runtime as runtime
|
||||
|
||||
# trigger the torch.device implicitly to ensure cuda context initialization
|
||||
torch.zeros([10], device=torch.device('cuda'))
|
||||
@@ -16,30 +15,18 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
|
||||
def test_empty_kernel_cubin_compile():
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
cubin = triton.compile(empty_kernel,
|
||||
"*fp32,i32,i32",
|
||||
device=device,
|
||||
constants={"BLOCK": 256},
|
||||
output="cubin")
|
||||
kernel = triton.compile(empty_kernel,
|
||||
"*fp32,i32,i32",
|
||||
device=device,
|
||||
constants={"BLOCK": 256})
|
||||
|
||||
print('cubin size:', len(cubin))
|
||||
assert len(cubin) > 0
|
||||
assert len(kernel.asm["cubin"]) > 0
|
||||
|
||||
|
||||
def test_empty_kernel_launch():
|
||||
device = torch.cuda.current_device()
|
||||
binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
|
||||
constants={"BLOCK": 256},
|
||||
num_warps=4,
|
||||
num_stages=3)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
|
||||
)
|
||||
|
||||
A = torch.zeros([1024], device="cuda")
|
||||
runtime.launch_kernel(kernel=binary,
|
||||
grid=grid,
|
||||
device=device,
|
||||
X=A,
|
||||
stride_xm=256,
|
||||
BLOCK=tl.constexpr(256))
|
||||
empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)
|
||||
|
@@ -23,11 +23,11 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
|
||||
|
||||
|
||||
def test_empty_kernel_cubin_compile():
|
||||
kernel = triton.compile(math_kernel,
|
||||
"*fp32,*fp32,*fp32,*fp32,i32",
|
||||
device=0,
|
||||
constants={"BLOCK_SIZE": 256},
|
||||
output="ttgir") # "cubin"
|
||||
kernel = triton.compiler._compile(math_kernel,
|
||||
"*fp32,*fp32,*fp32,*fp32,i32",
|
||||
device=0,
|
||||
constants={"BLOCK_SIZE": 256},
|
||||
output="ttgir") # "cubin"
|
||||
assert kernel
|
||||
# TODO: Check if the values are correct.
|
||||
# TODO: Cover all the math operators
|
||||
|
@@ -4,7 +4,6 @@ from torch.testing import assert_allclose
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.runtime as runtime
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -40,29 +39,9 @@ def kernel(x_ptr, stride_xm,
|
||||
[2, 128, 64]
|
||||
])
|
||||
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
|
||||
# TODO: this is to initialize the cuda context since it is not properly
|
||||
# dealed with in the existing runtime, remove this when the runtime
|
||||
# is updated
|
||||
torch.zeros([10], device=torch.device('cuda'))
|
||||
device = torch.cuda.current_device()
|
||||
binary = runtime.build_kernel(kernel,
|
||||
"*fp32,i32,*fp32,i32",
|
||||
constants={"SIZE_M": SIZE_M,
|
||||
"SIZE_N": SIZE_N},
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=3)
|
||||
grid = lambda META: (1, )
|
||||
|
||||
x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
|
||||
runtime.launch_kernel(kernel=binary,
|
||||
device=device,
|
||||
grid=grid,
|
||||
x_ptr=x,
|
||||
stride_xm=x.stride(0),
|
||||
z_ptr=z,
|
||||
stride_zn=z.stride(0),
|
||||
SIZE_M=tl.constexpr(SIZE_M),
|
||||
SIZE_N=tl.constexpr(SIZE_N))
|
||||
kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
|
||||
golden_z = torch.t(x)
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
@@ -3,7 +3,6 @@ from torch.testing import assert_allclose
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.runtime as runtime
|
||||
|
||||
|
||||
def vecadd_no_scf_tester(num_warps, block_size):
|
||||
@@ -22,27 +21,13 @@ def vecadd_no_scf_tester(num_warps, block_size):
|
||||
z_ptrs = z_ptr + offset
|
||||
tl.store(z_ptrs, z)
|
||||
|
||||
torch.zeros([10], device=torch.device('cuda'))
|
||||
device = torch.cuda.current_device()
|
||||
binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
|
||||
constants={"BLOCK_SIZE_N": block_size},
|
||||
num_warps=num_warps,
|
||||
num_stages=3)
|
||||
|
||||
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
||||
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
|
||||
|
||||
assert x.shape.numel() % block_size == 0, "Only test load without mask here"
|
||||
grid = lambda EA: (x.shape.numel() // block_size,)
|
||||
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
|
||||
|
||||
runtime.launch_kernel(kernel=binary,
|
||||
grid=grid,
|
||||
device=device,
|
||||
x_ptr=x,
|
||||
y_ptr=y,
|
||||
z_ptr=z,
|
||||
BLOCK_SIZE_N=tl.constexpr(block_size))
|
||||
golden_z = x + y
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
@@ -7,8 +7,9 @@ __version__ = '2.0.0'
|
||||
import torch
|
||||
# submodules
|
||||
from .utils import *
|
||||
from .runtime import jit, Config, autotune, heuristics
|
||||
from .compiler import compile
|
||||
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from . import language
|
||||
from . import testing
|
||||
from . import ops
|
||||
|
@@ -1,10 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import sysconfig
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from sysconfig import get_paths
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
@@ -85,7 +101,7 @@ class enter_sub_region:
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = function_types
|
||||
@@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.lscope = dict()
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.function_name = function_name
|
||||
self.is_kernel = is_kernel
|
||||
self.last_node = None
|
||||
self.builtins = {
|
||||
@@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
|
||||
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
continue
|
||||
else:
|
||||
pass
|
||||
if i in self.attributes:
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
@@ -746,21 +762,40 @@ class OutOfResources(Exception):
|
||||
return (type(self), (self.required, self.limit, self.name))
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
def kernel_suffix(signature, specialization):
|
||||
# suffix format:
|
||||
# <argid><'c' if equal to 1><'d' if divisible by 16>
|
||||
suffix = ''
|
||||
for i, _ in enumerate(signature):
|
||||
suffix += str(i)
|
||||
if i in specialization.equal_to_1:
|
||||
suffix += 'c'
|
||||
if i in specialization.divisible_by_16:
|
||||
suffix += 'd'
|
||||
return suffix
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, specialization, constants):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
|
||||
if signature.replace(' ', '') != '':
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
else:
|
||||
arg_types = []
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
|
||||
constants = {cst_key(key): value for key, value in constants.items()}
|
||||
# visit kernel AST
|
||||
gscope = fn.__globals__.copy()
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
|
||||
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
|
||||
tys = list(signature.values())
|
||||
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
|
||||
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
|
||||
all_constants = constants.copy()
|
||||
all_constants.update(new_constants)
|
||||
arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
|
||||
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except Exception as e:
|
||||
@@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
raise e
|
||||
raise CompilationError(fn.src, node) from e
|
||||
ret = generator.module
|
||||
# module takes ownership of the MLIR context
|
||||
# module takes ownership of the context
|
||||
ret.context = context
|
||||
return ret
|
||||
return ret, generator
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
@@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str:
|
||||
return line.split()[-1]
|
||||
|
||||
|
||||
def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
|
||||
|
||||
|
||||
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
module, _ = make_triton_ir(fn, signature, specialization, constants)
|
||||
module = optimize_triton_ir(module)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
@@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d
|
||||
|
||||
cubin = make_cubin(ptx, device)
|
||||
if output == "cubin":
|
||||
return cubin, shem_size, kernel_name
|
||||
return cubin, ptx, shem_size, kernel_name
|
||||
|
||||
assert False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# compiler
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ty_to_cpp(ty):
|
||||
if ty[0] == '*':
|
||||
return "CUdeviceptr"
|
||||
return {
|
||||
"i1": "int32_t",
|
||||
"i8": "int8_t",
|
||||
"i16": "int16_t",
|
||||
"i32": "int32_t",
|
||||
"i64": "int64_t",
|
||||
"u32": "uint32_t",
|
||||
"u64": "uint64_t",
|
||||
"fp32": "float",
|
||||
}[ty]
|
||||
|
||||
|
||||
def generate_name_initializer(signature):
|
||||
src = "int i = 0;\n"
|
||||
tys = signature.split(',')
|
||||
for i, ty in enumerate(tys):
|
||||
src
|
||||
|
||||
|
||||
def binary_name_to_header_name(name):
|
||||
if len(name) > 128:
|
||||
# avoid filename too long errors (filename limit is 255)
|
||||
name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest()
|
||||
return f"{name}.h"
|
||||
|
||||
|
||||
def generate_launcher(identifier, constants, signature):
|
||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||
|
||||
def _extracted_type(ty):
|
||||
if ty[0] == '*':
|
||||
return "PyObject*"
|
||||
return {
|
||||
'i1': 'int32_t',
|
||||
'i32': 'int32_t',
|
||||
'i64': 'int64_t',
|
||||
'u32': 'uint32_t',
|
||||
'u64': 'uint64_t',
|
||||
'fp32': 'float',
|
||||
'fp64': 'double',
|
||||
}[ty]
|
||||
|
||||
def format_of(ty):
|
||||
return {
|
||||
"PyObject*": "O",
|
||||
"float": "f",
|
||||
"double": "d",
|
||||
"long": "l",
|
||||
"uint32_t": "I",
|
||||
"int32_t": "i",
|
||||
"uint64_t": "K",
|
||||
"int64_t": "L",
|
||||
}[ty]
|
||||
|
||||
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <Python.h>
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
{{
|
||||
if (code != CUDA_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [CUDA]: ";
|
||||
const char* str;
|
||||
cuGetErrorString(code, &str);
|
||||
char err[1024] = {{0}};
|
||||
strcat(err, prefix);
|
||||
strcat(err, str);
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||
if(gridX*gridY*gridZ > 0){{
|
||||
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
|
||||
if (PyLong_Check(obj)) {{
|
||||
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
return (CUdeviceptr)0;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
}}
|
||||
return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return (CUdeviceptr)0;
|
||||
}}
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int shared_memory;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
return NULL;
|
||||
}}
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
PyMODINIT_FUNC PyInit_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
|
||||
return src
|
||||
|
||||
|
||||
def default_cache_dir():
|
||||
return os.path.join(os.environ["HOME"], ".triton", "cache")
|
||||
|
||||
|
||||
class CacheManager:
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
|
||||
def _make_path(self, filename):
|
||||
return os.path.join(self.cache_dir, filename)
|
||||
|
||||
def has_file(self, filename):
|
||||
if not self.cache_dir:
|
||||
return False
|
||||
return os.path.exists(self._make_path(filename))
|
||||
|
||||
def put(self, data, filename, binary=True):
|
||||
if not self.cache_dir:
|
||||
return
|
||||
assert self.lock_path is not None
|
||||
filepath = self._make_path(filename)
|
||||
with FileLock(self.lock_path):
|
||||
# use tempfile to be robust against program interruptions
|
||||
mode = "wb" if binary else "w"
|
||||
with open(filepath + ".tmp", mode) as f:
|
||||
f.write(data)
|
||||
os.rename(filepath + ".tmp", filepath)
|
||||
|
||||
|
||||
# utilties for generating and compiling C wrappers
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def libcuda_dir():
|
||||
loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
|
||||
return os.path.dirname(loc)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quiet():
|
||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||
|
||||
|
||||
def _build(name, src, srcdir):
|
||||
cuda_lib_dir = libcuda_dir()
|
||||
cu_include_dir = "/usr/local/cuda/include"
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
|
||||
# try to avoid setuptools if possible
|
||||
cc = os.environ.get("CC")
|
||||
if cc is None:
|
||||
# TODO: support more things here.
|
||||
clang = shutil.which("clang")
|
||||
gcc = shutil.which("gcc")
|
||||
cc = gcc if gcc is not None else clang
|
||||
py_include_dir = get_paths()["include"]
|
||||
ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
|
||||
if ret == 0:
|
||||
return so
|
||||
# fallback on setuptools
|
||||
extra_compile_args = []
|
||||
library_dirs = [cuda_lib_dir]
|
||||
include_dirs = [srcdir, cu_include_dir]
|
||||
libraries = ['cuda']
|
||||
# extra arguments
|
||||
extra_link_args = []
|
||||
# create extension module
|
||||
ext = setuptools.Extension(
|
||||
name=name,
|
||||
language='c',
|
||||
sources=[src],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=extra_compile_args + ['-O3'],
|
||||
extra_link_args=extra_link_args,
|
||||
library_dirs=library_dirs,
|
||||
libraries=libraries,
|
||||
)
|
||||
# build extension module
|
||||
args = ['build_ext']
|
||||
args.append('--build-temp=' + srcdir)
|
||||
args.append('--build-lib=' + srcdir)
|
||||
args.append('-q')
|
||||
args = dict(
|
||||
name=name,
|
||||
ext_modules=[ext],
|
||||
script_args=args,
|
||||
)
|
||||
with quiet():
|
||||
setuptools.setup(**args)
|
||||
return so
|
||||
|
||||
|
||||
def make_so_cache_key(signature, constants):
|
||||
# Get unique key for the compiled code
|
||||
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
|
||||
key = f"{''.join(signature.values())}{constants}"
|
||||
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
return key
|
||||
|
||||
|
||||
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
|
||||
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
return key
|
||||
|
||||
|
||||
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
# we get the kernel, i.e. the first function generated in the module
|
||||
if configs is None:
|
||||
configs = [instance_descriptor()]
|
||||
assert len(configs) == 1
|
||||
# cache manager
|
||||
name = fn.__name__
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(signature, constants)
|
||||
so_cache_manager = CacheManager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
if not so_cache_manager.has_file(so_name):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
src = generate_launcher(name, constants, signature)
|
||||
src_path = os.path.join(tmpdir, "main.c")
|
||||
with open(src_path, "w") as f:
|
||||
f.write(src)
|
||||
so = _build(fn.__name__, src_path, tmpdir)
|
||||
with open(so, "rb") as f:
|
||||
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
|
||||
# retrieve cached shared object if it exists
|
||||
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||
fn_cache_manager = CacheManager(fn_cache_key)
|
||||
ptx_name = f"{name}.ptx"
|
||||
cubin_name = f"{name}.cubin"
|
||||
data_name = f"{name}.json"
|
||||
if not fn_cache_manager.has_file(cubin_name) or \
|
||||
not fn_cache_manager.has_file(data_name) or \
|
||||
not fn_cache_manager.has_file(ptx_name):
|
||||
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
fn_cache_manager.put(cubin, cubin_name)
|
||||
fn_cache_manager.put(ptx, ptx_name, binary=False)
|
||||
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||
|
||||
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
|
||||
def __init__(self, fn_name, so_path, cache_dir):
|
||||
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
# initialize asm dict
|
||||
self.asm = dict()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
|
||||
def __getitem__(self, grid):
|
||||
def runner(*args, stream=None):
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||
return
|
||||
|
@@ -1,2 +1,2 @@
|
||||
from .autotuner import Config, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, build_kernel, jit, launch_kernel # noqa: F401
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
||||
|
@@ -5,10 +5,11 @@ import time
|
||||
from typing import Dict
|
||||
|
||||
from ..testing import do_bench
|
||||
from .jit import KernelInterface
|
||||
|
||||
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
class Autotuner(KernelInterface):
|
||||
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
@@ -21,7 +22,6 @@ class Autotuner:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = dict()
|
||||
self.kernel = kernel
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
@@ -41,6 +41,7 @@ class Autotuner:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
@@ -58,25 +59,16 @@ class Autotuner:
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return do_bench(kernel_call)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
@@ -91,13 +83,41 @@ class Autotuner:
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||
num_warps=config.num_warps)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
for config in self.prune_configs(kwargs):
|
||||
self.fn.warmup(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
@@ -129,10 +149,8 @@ class Config:
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
@@ -143,12 +161,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
reset the value of the provided tensor to `zero` before running any configuration.
|
||||
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
@@ -161,43 +177,39 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Heuristics(KernelInterface):
|
||||
|
||||
def __init__(self, fn, arg_names, values) -> None:
|
||||
self.fn = fn
|
||||
self.values = values
|
||||
self.arg_names = arg_names
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
for v, heur in self.values.items():
|
||||
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
|
||||
return self.fn.run(*args, **kwargs)
|
||||
|
||||
|
||||
def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
.type values: dict[str, Callable[[list[Any]], Any]]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
def fun(*args, **meta):
|
||||
for v, heur in values.items():
|
||||
assert v not in meta
|
||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
||||
return kernel(*args, **meta)
|
||||
return fun
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
return Heuristics(fn, fn.arg_names, values)
|
||||
|
||||
return decorator
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from __future__ import annotations
|
||||
from __future__ import annotations, division
|
||||
|
||||
import ast
|
||||
import functools
|
||||
@@ -6,90 +6,24 @@ import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..compiler import compile
|
||||
from ..tools.disasm import extract
|
||||
from triton.utils import MockTensor
|
||||
|
||||
try:
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Binary
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
VALID_BACKENDS: List[str] = (
|
||||
_triton.runtime.backend.CUDA,
|
||||
)
|
||||
|
||||
|
||||
class Binary:
|
||||
def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
|
||||
assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
|
||||
self.backend = backend
|
||||
self.name = name
|
||||
self.asm = asm
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
|
||||
|
||||
class LoadedBinary:
|
||||
def __init__(self, device: int, bin: Binary):
|
||||
module, kernel = _triton.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
self.sass = ''
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.device = device
|
||||
self.shared_mem = bin.shared_mem
|
||||
|
||||
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
|
||||
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
|
||||
grid_0, grid_1, grid_2,
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
args, self.bin.shared_mem)
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if self.sass:
|
||||
return self.sass
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Kernel
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Kernel:
|
||||
|
||||
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
|
||||
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
|
||||
"""
|
||||
This AST visitor is used to find dependencies of a JITFunction. This can
|
||||
@@ -142,6 +76,8 @@ def version_key():
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
with open(triton.compiler.__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# backend
|
||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
@@ -158,26 +94,213 @@ def version_key():
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class JITFunction:
|
||||
class KernelInterface:
|
||||
|
||||
def __getitem__(self, grid):
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
def launcher(*args, **kwargs):
|
||||
return self.run(*args, grid=grid, **kwargs)
|
||||
return launcher
|
||||
|
||||
|
||||
class JITFunction(KernelInterface):
|
||||
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
|
||||
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
@staticmethod
|
||||
def _key_of(arg):
|
||||
if hasattr(arg, "dtype"):
|
||||
return arg.dtype
|
||||
elif isinstance(arg, bool):
|
||||
return "i1"
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**31 <= arg and arg <= 2**32 - 1:
|
||||
return "u32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
return "i64"
|
||||
elif isinstance(arg, float):
|
||||
return 'fp32'
|
||||
elif arg is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
return (arg.data_ptr() % JITFunction.divisibility == 0)
|
||||
elif isinstance(arg, int):
|
||||
return (arg % 16 == 0, arg == 1)
|
||||
return (arg is None, )
|
||||
|
||||
def _get_config(self, *args):
|
||||
def is_divisible_by_16(x):
|
||||
if hasattr(x, "data_ptr"):
|
||||
return x.data_ptr() % JITFunction.divisibility == 0
|
||||
elif isinstance(x, int):
|
||||
return x % JITFunction.divisibility == 0
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
|
||||
|
||||
@staticmethod
|
||||
def _type_of(key):
|
||||
if isinstance(key, (torch.dtype, triton.language.dtype)):
|
||||
ty = {
|
||||
torch.bool: 'i1',
|
||||
torch.float16: 'fp16',
|
||||
torch.bfloat16: 'bf16',
|
||||
torch.float32: 'fp32',
|
||||
torch.float64: 'fp64',
|
||||
torch.uint8: 'u8',
|
||||
torch.int8: 'i8',
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
|
||||
triton.language.uint8: 'u8',
|
||||
triton.language.uint16: 'u16',
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
triton.language.float8: 'fp8',
|
||||
}[key]
|
||||
return f'*{ty}'
|
||||
if key is None:
|
||||
return '*i8'
|
||||
assert isinstance(key, str)
|
||||
return key
|
||||
|
||||
def _make_signature(self, sig_key):
|
||||
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
|
||||
return signature
|
||||
|
||||
def _make_constants(self, constexpr_key):
|
||||
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
def __init__(self, module, name):
|
||||
self.module = module
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
|
||||
# cache key for constexpr argument values
|
||||
constexpr_keys = ', '.join(constexpr_args)
|
||||
# cache key for argument specialization
|
||||
specializations = []
|
||||
for i, arg in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
|
||||
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
|
||||
f'else (False,)']
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
|
||||
key = (version_key, sig_key, constexpr_key, spec_key)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
if callable(grid):
|
||||
grid = grid({{{grid_args}}})
|
||||
grid_size = len(grid)
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
device = torch.cuda.current_device()
|
||||
torch.cuda.set_device(device)
|
||||
if stream is None and not warmup:
|
||||
stream = get_cuda_stream(device)
|
||||
try:
|
||||
bin = cache[key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
except KeyError:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
configs = self._get_config(*args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
|
||||
constants.update({{i: 1 for i in configs[0].equal_to_1}})
|
||||
# build kernel signature -- doesn't include specialized arguments
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
|
||||
# build stub signature -- includes arguments that are specialized
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
device = 0
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
|
||||
self.cache[key] = bin
|
||||
return bin
|
||||
return None
|
||||
"""
|
||||
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
|
||||
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
|
||||
"cache": self.cache, "triton": triton, "torch": torch}
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.inline = inline
|
||||
self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
|
||||
# specialization hints
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
|
||||
# function source code (without decorators)
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
|
||||
# cache for callable driver objects (e.g. CUkernel)
|
||||
self.bin_cache = dict()
|
||||
# cache of just-in-time compiled kernels
|
||||
self.cache = dict()
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
@@ -186,16 +309,17 @@ class JITFunction:
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# constexprs
|
||||
# index of constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
# forward docs
|
||||
# launcher
|
||||
self.run = self._make_launcher()
|
||||
# re-use docs of wrapped function
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
self.__globals__ = fn.__globals__
|
||||
self.__module__ = fn.__module__
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def cache_key(self):
|
||||
# TODO : hash should be attribute of `self`
|
||||
if self.hash is None:
|
||||
@@ -204,9 +328,12 @@ class JITFunction:
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
# the user might want to monkey-patch self.src dynamically.
|
||||
# Some unit tests do this, for example.
|
||||
# Our unit tests do this, for example.
|
||||
def parse(self):
|
||||
tree = ast.parse(self.src)
|
||||
assert isinstance(tree, ast.Module)
|
||||
@@ -217,167 +344,21 @@ class JITFunction:
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
def __setattr__(self, name, value):
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
if name == 'src':
|
||||
self.hash = None
|
||||
JITFunction.cache_key.fget.cache_clear()
|
||||
|
||||
def _init_kernel(self):
|
||||
if self.kernel is None:
|
||||
self.kernel = Kernel(self)
|
||||
for decorator in reversed(self.kernel_decorators):
|
||||
self.kernel = decorator(self.kernel)
|
||||
return self.kernel
|
||||
|
||||
def __getitem__(self, grid):
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
class Launcher:
|
||||
def __init__(self, kernel, grid):
|
||||
self.kernel = kernel
|
||||
self.grid = grid
|
||||
|
||||
def __call__(self, *wargs, **kwargs):
|
||||
return self.kernel(*wargs, **kwargs, grid=self.grid)
|
||||
|
||||
return Launcher(self._init_kernel(), grid)
|
||||
|
||||
def __repr__(self):
|
||||
return f"JITFunction({self.module}:{self.fn.__name__})"
|
||||
|
||||
|
||||
def pow2_divisor(N):
|
||||
if N % 16 == 0:
|
||||
return 16
|
||||
if N % 8 == 0:
|
||||
return 8
|
||||
if N % 4 == 0:
|
||||
return 4
|
||||
if N % 2 == 0:
|
||||
return 2
|
||||
return 1
|
||||
|
||||
|
||||
class _KernelCache:
|
||||
def __init__(self,
|
||||
fn: JITFunction,
|
||||
fn_type: str,
|
||||
constants: Dict[str, Any],
|
||||
num_warps: int = 4,
|
||||
num_stages: int = 3):
|
||||
# hold the arguments for building a kernel
|
||||
self.fn = fn
|
||||
self.fn_type = fn_type
|
||||
self.constants = constants
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
|
||||
# kernel compilation cache
|
||||
self._binary_cache: Optional[LoadedBinary] = None
|
||||
|
||||
@property
|
||||
def binary_cache(self):
|
||||
return self._binary_cache
|
||||
|
||||
def set_binary_cache(self, binary: LoadedBinary):
|
||||
assert binary
|
||||
assert not self._binary_cache, "cannot set binary cache duplicately"
|
||||
self._binary_cache = binary
|
||||
|
||||
|
||||
def build_kernel(fn: JITFunction,
|
||||
fn_type: str,
|
||||
constants: Dict[str, Any],
|
||||
num_warps: int = 4,
|
||||
num_stages: int = 3,
|
||||
) -> _KernelCache:
|
||||
return _KernelCache(fn, fn_type, constants, num_warps, num_stages)
|
||||
|
||||
|
||||
torch_dtype_to_bytes = {
|
||||
torch.int8: 1,
|
||||
torch.uint8: 1,
|
||||
|
||||
torch.int16: 2,
|
||||
torch.short: 2,
|
||||
|
||||
torch.int: 4,
|
||||
torch.int32: 4,
|
||||
|
||||
torch.long: 8,
|
||||
torch.int64: 8,
|
||||
|
||||
torch.float32: 4,
|
||||
torch.float: 4,
|
||||
|
||||
torch.float16: 2,
|
||||
torch.half: 2,
|
||||
torch.bfloat16: 2,
|
||||
# free to extend
|
||||
}
|
||||
|
||||
|
||||
def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
|
||||
def is_tensor(arg):
|
||||
return hasattr(arg, 'data_ptr') # a torch.tensor
|
||||
|
||||
# prepare function args for compile
|
||||
kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
|
||||
wargs = list(wargs)
|
||||
for i, pos in enumerate(sorted(kwargs)):
|
||||
wargs.insert(pos + i, kwargs[pos])
|
||||
assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs))
|
||||
|
||||
if not kernel.binary_cache:
|
||||
# build the kernel cache
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
|
||||
attributes = dict()
|
||||
for i, arg in enumerate(wargs):
|
||||
if i in kernel.fn.do_not_specialize:
|
||||
continue
|
||||
if isinstance(arg, int):
|
||||
attributes[i] = pow2_divisor(arg)
|
||||
elif is_tensor(arg):
|
||||
assert arg.dtype in torch_dtype_to_bytes
|
||||
addr = arg.data_ptr()
|
||||
range_size = _triton.runtime.get_pointer_range_size(addr)
|
||||
divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
|
||||
attributes[i] = divisibility
|
||||
|
||||
attributes_ = dict()
|
||||
for i, value in attributes.items():
|
||||
attributes_[kernel.fn.arg_names[i]] = value
|
||||
|
||||
cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
|
||||
assert cubin
|
||||
assert kernel_name
|
||||
|
||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
|
||||
|
||||
asm = dict(cubin=cubin)
|
||||
binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
|
||||
loaded_binary = LoadedBinary(device, binary)
|
||||
kernel.set_binary_cache(loaded_binary)
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
stream = get_cuda_stream(device)
|
||||
|
||||
_triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
|
||||
stream, kernel.num_warps, kernel.num_stages, grid)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -386,16 +367,12 @@ def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
|
||||
def jit(*args, **kwargs):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
@@ -407,3 +384,32 @@ def jit(*args, **kwargs):
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
if dtype == tensor.base.dtype:
|
||||
# Reinterpreting to the original interpretation; return the base.
|
||||
return tensor.base
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
||||
|
@@ -19,6 +19,24 @@ def next_power_of_2(n):
|
||||
return n
|
||||
|
||||
|
||||
class MockTensor:
|
||||
"""
|
||||
Can be used in place of real tensors when calling:
|
||||
kernel.warmup(MockTensor(torch.float32), ...)
|
||||
"""
|
||||
@staticmethod
|
||||
def wrap_dtype(arg):
|
||||
if isinstance(arg, torch.dtype):
|
||||
return MockTensor(arg)
|
||||
return arg
|
||||
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def data_ptr(self):
|
||||
return 0 # optimistically assumes multiple of 16
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
|
Reference in New Issue
Block a user