[FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime back into the `triton-mlir` branch. It adds caching and just-in-time compilation to the triton-mlir project and paves the way for reusing tests from the master branch.
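
For context, here is a minimal sketch of the workflow this backport enables, pieced together from the updated tests in this diff (it assumes a CUDA device; the exact API is the one shown in the hunks below):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
        pass

    # ahead-of-time compilation: the result is cached on disk and exposes the
    # generated artifacts through the `asm` dict (e.g. "ptx", "cubin")
    kernel = triton.compile(empty_kernel, "*fp32,i32,i32",
                            device=torch.cuda.current_device(),
                            constants={"BLOCK": 256})
    assert len(kernel.asm["cubin"]) > 0

    # just-in-time launch: specialization, compilation and caching happen on
    # the first call for a given argument signature
    X = torch.zeros([1024], device="cuda")
    grid = lambda META: (triton.cdiv(1024, META['BLOCK']),)
    empty_kernel[grid](X=X, stride_xm=256, BLOCK=256)
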
@@ -1,3 +1,5 @@
+import torch
+
 import triton
 import triton.language as tl
 
@@ -7,4 +9,5 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
     pass
 
 
-ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
+X = torch.randn(1, device="cuda")
+pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
@@ -544,50 +544,6 @@ void init_triton_runtime(py::module &&m) {
 /*****************************************************************************/
 typedef std::map<std::string, py::object> asm_map_t;
 
-// ---------------------------------------
-// Load provided assembly code into driver
-// ---------------------------------------
-
-// CUDA
-std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
-                                              asm_map_t &asm_map,
-                                              size_t n_shared_bytes,
-                                              uint64_t dev) {
-  // load assembly
-  std::string assembly;
-  if (asm_map.find("cubin") != asm_map.end())
-    assembly = py::cast<std::string>(asm_map["cubin"]);
-  else
-    assembly = py::cast<std::string>(asm_map["ptx"]);
-  // create driver handles
-  CUfunction fun;
-  CUmodule mod;
-  drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
-  drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
-  // set dynamic shared memory if necessary
-  int shared_optin;
-  drv::dispatch::cuDeviceGetAttribute(
-      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
-      dev);
-  if (n_shared_bytes > 49152 && shared_optin > 49152) {
-    drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
-    int shared_total, shared_static;
-    int n_spills, n_reg;
-    drv::dispatch::cuDeviceGetAttribute(
-        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
-        dev);
-    drv::dispatch::cuFuncGetAttribute(&shared_static,
-                                      CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_spills,
-                                      CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
-    drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
-    drv::dispatch::cuFuncSetAttribute(
-        fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-        shared_optin - shared_static);
-  }
-  return std::make_tuple((uint64_t)mod, (uint64_t)fun);
-}
-
 /*****************************************************************************/
 /* Python bindings for triton::ir */
 /*****************************************************************************/
@@ -1728,11 +1684,41 @@ void init_triton_translation(py::module &m) {
 
   m.def(
       "load_binary",
-      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
-         size_t n_shared_bytes, uint64_t dev) {
+      [](const std::string &name, const std::string &data,
+         size_t n_shared_bytes, uint64_t device) {
         py::gil_scoped_release allow_threads;
-        assert(backend == CUDA); // Only CUDA is supported now.
-        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+        // create driver handles
+        CUfunction fun;
+        CUmodule mod;
+        drv::dispatch::cuModuleLoadData(&mod, data.c_str());
+        drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
+        // get allocated registers and spilled registers from the function
+        int n_regs = 0;
+        int n_spills = 0;
+        drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS,
+                                          fun);
+        drv::dispatch::cuFuncGetAttribute(
+            &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+        n_spills /= 4;
+        // set dynamic shared memory if necessary
+        int shared_optin;
+        drv::dispatch::cuDeviceGetAttribute(
+            &shared_optin,
+            CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
+        if (n_shared_bytes > 49152 && shared_optin > 49152) {
+          drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
+          int shared_total, shared_static;
+          drv::dispatch::cuDeviceGetAttribute(
+              &shared_total,
+              CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
+          drv::dispatch::cuFuncGetAttribute(
+              &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
+          drv::dispatch::cuFuncSetAttribute(
+              fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+              shared_optin - shared_static);
+        }
+        return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs,
+                               (uint64_t)n_spills);
       },
       py::return_value_policy::take_ownership);
 }
@@ -2,8 +2,9 @@ import triton
 import triton.language as tl
 
 
+# TODO: function with no arguments don't work
 @triton.jit
-def cast_check():
+def cast_check(X):
     zero_0d = tl.zeros([], dtype=tl.float32)
     zero_1d = tl.zeros([2], dtype=tl.float32)
     zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
@@ -48,9 +49,9 @@ def cast_check():
 
 
 def test_cast_check():
-    kernel = triton.compile(cast_check,
-                            signature="",
-                            device=0,
-                            output="ttir")
+    kernel = triton.compiler._compile(cast_check,
+                                      signature="*fp32",
+                                      device=0,
+                                      output="ttgir")
     assert (kernel)
     # TODO: Check types of the results
@@ -2,7 +2,6 @@ import torch
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 # trigger the torch.device implicitly to ensure cuda context initialization
 torch.zeros([10], device=torch.device('cuda'))
@@ -16,30 +15,18 @@ def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
 def test_empty_kernel_cubin_compile():
 
     device = torch.cuda.current_device()
-    cubin = triton.compile(empty_kernel,
-                           "*fp32,i32,i32",
-                           device=device,
-                           constants={"BLOCK": 256},
-                           output="cubin")
+    kernel = triton.compile(empty_kernel,
+                            "*fp32,i32,i32",
+                            device=device,
+                            constants={"BLOCK": 256})
 
-    print('cubin size:', len(cubin))
-    assert len(cubin) > 0
+    assert len(kernel.asm["cubin"]) > 0
 
 
 def test_empty_kernel_launch():
-    device = torch.cuda.current_device()
-    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
-                                  constants={"BLOCK": 256},
-                                  num_warps=4,
-                                  num_stages=3)
     grid = lambda META: (
         triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
     )
 
     A = torch.zeros([1024], device="cuda")
-    runtime.launch_kernel(kernel=binary,
-                          grid=grid,
-                          device=device,
-                          X=A,
-                          stride_xm=256,
-                          BLOCK=tl.constexpr(256))
+    empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)
@@ -23,7 +23,7 @@ def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
 
 
 def test_empty_kernel_cubin_compile():
-    kernel = triton.compile(math_kernel,
+    kernel = triton.compiler._compile(math_kernel,
                             "*fp32,*fp32,*fp32,*fp32,i32",
                             device=0,
                             constants={"BLOCK_SIZE": 256},
@@ -4,7 +4,6 @@ from torch.testing import assert_allclose
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 
 @triton.jit
@@ -40,29 +39,9 @@ def kernel(x_ptr, stride_xm,
     [2, 128, 64]
 ])
 def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
-    # TODO: this is to initialize the cuda context since it is not properly
-    # dealed with in the existing runtime, remove this when the runtime
-    # is updated
-    torch.zeros([10], device=torch.device('cuda'))
-    device = torch.cuda.current_device()
-    binary = runtime.build_kernel(kernel,
-                                  "*fp32,i32,*fp32,i32",
-                                  constants={"SIZE_M": SIZE_M,
-                                             "SIZE_N": SIZE_N},
-                                  num_warps=NUM_WARPS,
-                                  num_stages=3)
     grid = lambda META: (1, )
-
     x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
     z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
-    runtime.launch_kernel(kernel=binary,
-                          device=device,
-                          grid=grid,
-                          x_ptr=x,
-                          stride_xm=x.stride(0),
-                          z_ptr=z,
-                          stride_zn=z.stride(0),
-                          SIZE_M=tl.constexpr(SIZE_M),
-                          SIZE_N=tl.constexpr(SIZE_N))
+    kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
     golden_z = torch.t(x)
     assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
@@ -3,7 +3,6 @@ from torch.testing import assert_allclose
 
 import triton
 import triton.language as tl
-import triton.runtime as runtime
 
 
 def vecadd_no_scf_tester(num_warps, block_size):
@@ -22,27 +21,13 @@ def vecadd_no_scf_tester(num_warps, block_size):
         z_ptrs = z_ptr + offset
         tl.store(z_ptrs, z)
 
-    torch.zeros([10], device=torch.device('cuda'))
-    device = torch.cuda.current_device()
-    binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
-                                  constants={"BLOCK_SIZE_N": block_size},
-                                  num_warps=num_warps,
-                                  num_stages=3)
-
     x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
     y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
     z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
 
-    assert x.shape.numel() % block_size == 0, "Only test load without mask here"
     grid = lambda EA: (x.shape.numel() // block_size,)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
 
-    runtime.launch_kernel(kernel=binary,
-                          grid=grid,
-                          device=device,
-                          x_ptr=x,
-                          y_ptr=y,
-                          z_ptr=z,
-                          BLOCK_SIZE_N=tl.constexpr(block_size))
     golden_z = x + y
     assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
 
@@ -7,8 +7,9 @@ __version__ = '2.0.0'
 import torch
 # submodules
 from .utils import *
-from .runtime import jit, Config, autotune, heuristics
-from .compiler import compile
+from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
+from .runtime.jit import jit
+from .compiler import compile, CompilationError
 from . import language
 from . import testing
 from . import ops
@@ -1,10 +1,26 @@
 from __future__ import annotations
 
 import ast
+import contextlib
+import functools
+import hashlib
+import io
+import json
+import os
+import shutil
+import subprocess
 import sys
+import sysconfig
+import tempfile
 import warnings
+from collections import namedtuple
+from sysconfig import get_paths
 from typing import Any, Dict, Tuple, Union
 
+import setuptools
+import torch
+from filelock import FileLock
+
 import triton
 import triton._C.libtriton.triton as _triton
 
@@ -85,7 +101,7 @@ class enter_sub_region:
 
 
 class CodeGenerator(ast.NodeVisitor):
-    def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
+    def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
         self.builder = _triton.ir.builder(context)
         self.module = self.builder.create_module() if module is None else module
         self.function_ret_types = function_types
@@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor):
         self.lscope = dict()
         self.attributes = attributes
         self.constants = constants
+        self.function_name = function_name
         self.is_kernel = is_kernel
         self.last_node = None
         self.builtins = {
@@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor):
             init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
             self.visit(init_node)
         # initialize function
-        fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
-        fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
+        fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
         self.module.push_back(fn)
         entry = fn.add_entry_block()
         arg_values = []
@@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor):
                 if not isinstance(cst, triton.language.constexpr):
                     cst = triton.language.constexpr(self.constants[i])
                 arg_values.append(cst)
+                continue
             else:
-                pass
                 if i in self.attributes:
-                    fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
+                    fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
                 arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
                 idx += 1
 
@@ -746,21 +762,40 @@ class OutOfResources(Exception):
         return (type(self), (self.required, self.limit, self.name))
 
 
-def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
+def kernel_suffix(signature, specialization):
+    # suffix format:
+    # <argid><'c' if equal to 1><'d' if divisible by 16>
+    suffix = ''
+    for i, _ in enumerate(signature):
+        suffix += str(i)
+        if i in specialization.equal_to_1:
+            suffix += 'c'
+        if i in specialization.divisible_by_16:
+            suffix += 'd'
+    return suffix
+
+# ------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------
+
+
+def make_triton_ir(fn, signature, specialization, constants):
     context = _triton.ir.context()
     context.load_triton()
     # create kernel prototype
-    constants = {fn.arg_names.index(name): value for name, value in constants.items()}
-    attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
-    if signature.replace(' ', '') != '':
-        arg_types = signature.replace(' ', '').split(',')
-        arg_types = [str_to_ty(x) for x in arg_types]
-    else:
-        arg_types = []
-    prototype = triton.language.function_type([], arg_types)
+    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
+    constants = {cst_key(key): value for key, value in constants.items()}
     # visit kernel AST
     gscope = fn.__globals__.copy()
-    generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
+    function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
+    tys = list(signature.values())
+    new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
+    new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
+    all_constants = constants.copy()
+    all_constants.update(new_constants)
+    arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
+
+    prototype = triton.language.function_type([], arg_types)
+    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
     try:
         generator.visit(fn.parse())
     except Exception as e:
@@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
             raise e
         raise CompilationError(fn.src, node) from e
     ret = generator.module
-    # module takes ownership of the MLIR context
+    # module takes ownership of the context
    ret.context = context
-    return ret
+    return ret, generator
 
 
 def optimize_triton_ir(mod):
@@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str:
         return line.split()[-1]
 
 
-def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
+instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
+
+
+def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
 
     # triton-ir
-    module = make_triton_ir(fn, signature, constants, attributes)
+    module, _ = make_triton_ir(fn, signature, specialization, constants)
     module = optimize_triton_ir(module)
     if output == "ttir":
         return module.str()
 
     # tritongpu-ir
     module = make_tritongpu_ir(module, num_warps)
     module = optimize_tritongpu_ir(module, num_stages)
@@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d
|
|||||||
|
|
||||||
cubin = make_cubin(ptx, device)
|
cubin = make_cubin(ptx, device)
|
||||||
if output == "cubin":
|
if output == "cubin":
|
||||||
return cubin, shem_size, kernel_name
|
return cubin, ptx, shem_size, kernel_name
|
||||||
|
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# compiler
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def ty_to_cpp(ty):
|
||||||
|
if ty[0] == '*':
|
||||||
|
return "CUdeviceptr"
|
||||||
|
return {
|
||||||
|
"i1": "int32_t",
|
||||||
|
"i8": "int8_t",
|
||||||
|
"i16": "int16_t",
|
||||||
|
"i32": "int32_t",
|
||||||
|
"i64": "int64_t",
|
||||||
|
"u32": "uint32_t",
|
||||||
|
"u64": "uint64_t",
|
||||||
|
"fp32": "float",
|
||||||
|
}[ty]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_name_initializer(signature):
|
||||||
|
src = "int i = 0;\n"
|
||||||
|
tys = signature.split(',')
|
||||||
|
for i, ty in enumerate(tys):
|
||||||
|
src
|
||||||
|
|
||||||
|
|
||||||
|
def binary_name_to_header_name(name):
|
||||||
|
if len(name) > 128:
|
||||||
|
# avoid filename too long errors (filename limit is 255)
|
||||||
|
name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest()
|
||||||
|
return f"{name}.h"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launcher(identifier, constants, signature):
|
||||||
|
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||||
|
|
||||||
|
def _extracted_type(ty):
|
||||||
|
if ty[0] == '*':
|
||||||
|
return "PyObject*"
|
||||||
|
return {
|
||||||
|
'i1': 'int32_t',
|
||||||
|
'i32': 'int32_t',
|
||||||
|
'i64': 'int64_t',
|
||||||
|
'u32': 'uint32_t',
|
||||||
|
'u64': 'uint64_t',
|
||||||
|
'fp32': 'float',
|
||||||
|
'fp64': 'double',
|
||||||
|
}[ty]
|
||||||
|
|
||||||
|
def format_of(ty):
|
||||||
|
return {
|
||||||
|
"PyObject*": "O",
|
||||||
|
"float": "f",
|
||||||
|
"double": "d",
|
||||||
|
"long": "l",
|
||||||
|
"uint32_t": "I",
|
||||||
|
"int32_t": "i",
|
||||||
|
"uint64_t": "K",
|
||||||
|
"int64_t": "L",
|
||||||
|
}[ty]
|
||||||
|
|
||||||
|
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||||
|
|
||||||
|
# generate glue code
|
||||||
|
src = f"""
|
||||||
|
#include \"cuda.h\"
|
||||||
|
#include <Python.h>
|
||||||
|
static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||||
|
{{
|
||||||
|
if (code != CUDA_SUCCESS)
|
||||||
|
{{
|
||||||
|
const char* prefix = "Triton Error [CUDA]: ";
|
||||||
|
const char* str;
|
||||||
|
cuGetErrorString(code, &str);
|
||||||
|
char err[1024] = {{0}};
|
||||||
|
strcat(err, prefix);
|
||||||
|
strcat(err, str);
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, err);
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||||
|
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||||
|
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||||
|
if(gridX*gridY*gridZ > 0){{
|
||||||
|
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
|
||||||
|
if (PyLong_Check(obj)) {{
|
||||||
|
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
|
||||||
|
}}
|
||||||
|
if (obj == Py_None) {{
|
||||||
|
return (CUdeviceptr)0;
|
||||||
|
}}
|
||||||
|
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||||
|
if(ptr){{
|
||||||
|
PyObject *empty_tuple = PyTuple_New(0);
|
||||||
|
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||||
|
Py_DECREF(empty_tuple);
|
||||||
|
Py_DECREF(ptr);
|
||||||
|
if (!PyLong_Check(ret)) {{
|
||||||
|
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||||
|
}}
|
||||||
|
return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
|
||||||
|
}}
|
||||||
|
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||||
|
return (CUdeviceptr)0;
|
||||||
|
}}
|
||||||
|
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||||
|
int gridX, gridY, gridZ;
|
||||||
|
uint64_t _stream;
|
||||||
|
uint64_t _function;
|
||||||
|
int num_warps;
|
||||||
|
int shared_memory;
|
||||||
|
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||||
|
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||||
|
return NULL;
|
||||||
|
}}
|
||||||
|
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||||
|
if(PyErr_Occurred()) {{
|
||||||
|
return NULL;
|
||||||
|
}}
|
||||||
|
// return None
|
||||||
|
Py_INCREF(Py_None);
|
||||||
|
return Py_None;
|
||||||
|
}}
|
||||||
|
static PyMethodDef ModuleMethods[] = {{
|
||||||
|
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||||
|
{{NULL, NULL, 0, NULL}} // sentinel
|
||||||
|
}};
|
||||||
|
static struct PyModuleDef ModuleDef = {{
|
||||||
|
PyModuleDef_HEAD_INIT,
|
||||||
|
\"launcher\",
|
||||||
|
NULL, //documentation
|
||||||
|
-1, //size
|
||||||
|
ModuleMethods
|
||||||
|
}};
|
||||||
|
PyMODINIT_FUNC PyInit_launcher(void) {{
|
||||||
|
PyObject *m = PyModule_Create(&ModuleDef);
|
||||||
|
if(m == NULL) {{
|
||||||
|
return NULL;
|
||||||
|
}}
|
||||||
|
PyModule_AddFunctions(m, ModuleMethods);
|
||||||
|
return m;
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
def default_cache_dir():
|
||||||
|
return os.path.join(os.environ["HOME"], ".triton", "cache")
|
||||||
|
|
||||||
|
|
||||||
|
class CacheManager:
|
||||||
|
|
||||||
|
def __init__(self, key):
|
||||||
|
self.key = key
|
||||||
|
self.lock_path = None
|
||||||
|
# create cache directory if it doesn't exist
|
||||||
|
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||||
|
if self.cache_dir:
|
||||||
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||||
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def _make_path(self, filename):
|
||||||
|
return os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
|
def has_file(self, filename):
|
||||||
|
if not self.cache_dir:
|
||||||
|
return False
|
||||||
|
return os.path.exists(self._make_path(filename))
|
||||||
|
|
||||||
|
def put(self, data, filename, binary=True):
|
||||||
|
if not self.cache_dir:
|
||||||
|
return
|
||||||
|
assert self.lock_path is not None
|
||||||
|
filepath = self._make_path(filename)
|
||||||
|
with FileLock(self.lock_path):
|
||||||
|
# use tempfile to be robust against program interruptions
|
||||||
|
mode = "wb" if binary else "w"
|
||||||
|
with open(filepath + ".tmp", mode) as f:
|
||||||
|
f.write(data)
|
||||||
|
os.rename(filepath + ".tmp", filepath)
|
||||||
|
|
||||||
|
|
||||||
|
# utilties for generating and compiling C wrappers
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def libcuda_dir():
|
||||||
|
loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
|
||||||
|
return os.path.dirname(loc)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def quiet():
|
||||||
|
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||||
|
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||||
|
|
||||||
|
|
||||||
|
def _build(name, src, srcdir):
|
||||||
|
cuda_lib_dir = libcuda_dir()
|
||||||
|
cu_include_dir = "/usr/local/cuda/include"
|
||||||
|
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||||
|
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
|
||||||
|
# try to avoid setuptools if possible
|
||||||
|
cc = os.environ.get("CC")
|
||||||
|
if cc is None:
|
||||||
|
# TODO: support more things here.
|
||||||
|
clang = shutil.which("clang")
|
||||||
|
gcc = shutil.which("gcc")
|
||||||
|
cc = gcc if gcc is not None else clang
|
||||||
|
py_include_dir = get_paths()["include"]
|
||||||
|
ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
|
||||||
|
if ret == 0:
|
||||||
|
return so
|
||||||
|
# fallback on setuptools
|
||||||
|
extra_compile_args = []
|
||||||
|
library_dirs = [cuda_lib_dir]
|
||||||
|
include_dirs = [srcdir, cu_include_dir]
|
||||||
|
libraries = ['cuda']
|
||||||
|
# extra arguments
|
||||||
|
extra_link_args = []
|
||||||
|
# create extension module
|
||||||
|
ext = setuptools.Extension(
|
||||||
|
name=name,
|
||||||
|
language='c',
|
||||||
|
sources=[src],
|
||||||
|
include_dirs=include_dirs,
|
||||||
|
extra_compile_args=extra_compile_args + ['-O3'],
|
||||||
|
extra_link_args=extra_link_args,
|
||||||
|
library_dirs=library_dirs,
|
||||||
|
libraries=libraries,
|
||||||
|
)
|
||||||
|
# build extension module
|
||||||
|
args = ['build_ext']
|
||||||
|
args.append('--build-temp=' + srcdir)
|
||||||
|
args.append('--build-lib=' + srcdir)
|
||||||
|
args.append('-q')
|
||||||
|
args = dict(
|
||||||
|
name=name,
|
||||||
|
ext_modules=[ext],
|
||||||
|
script_args=args,
|
||||||
|
)
|
||||||
|
with quiet():
|
||||||
|
setuptools.setup(**args)
|
||||||
|
return so
|
||||||
|
|
||||||
|
|
||||||
|
def make_so_cache_key(signature, constants):
|
||||||
|
# Get unique key for the compiled code
|
||||||
|
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
|
||||||
|
key = f"{''.join(signature.values())}{constants}"
|
||||||
|
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
|
||||||
|
# Get unique key for the compiled code
|
||||||
|
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
||||||
|
configs_key = [get_conf_key(conf) for conf in configs]
|
||||||
|
key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
|
||||||
|
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||||
|
if isinstance(signature, str):
|
||||||
|
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||||
|
# we get the kernel, i.e. the first function generated in the module
|
||||||
|
if configs is None:
|
||||||
|
configs = [instance_descriptor()]
|
||||||
|
assert len(configs) == 1
|
||||||
|
# cache manager
|
||||||
|
name = fn.__name__
|
||||||
|
# name of files that are cached
|
||||||
|
so_cache_key = make_so_cache_key(signature, constants)
|
||||||
|
so_cache_manager = CacheManager(so_cache_key)
|
||||||
|
so_name = f"{name}.so"
|
||||||
|
# retrieve stub from cache if it exists
|
||||||
|
if not so_cache_manager.has_file(so_name):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
src = generate_launcher(name, constants, signature)
|
||||||
|
src_path = os.path.join(tmpdir, "main.c")
|
||||||
|
with open(src_path, "w") as f:
|
||||||
|
f.write(src)
|
||||||
|
so = _build(fn.__name__, src_path, tmpdir)
|
||||||
|
with open(so, "rb") as f:
|
||||||
|
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||||
|
|
||||||
|
# retrieve cached shared object if it exists
|
||||||
|
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||||
|
fn_cache_manager = CacheManager(fn_cache_key)
|
||||||
|
ptx_name = f"{name}.ptx"
|
||||||
|
cubin_name = f"{name}.cubin"
|
||||||
|
data_name = f"{name}.json"
|
||||||
|
if not fn_cache_manager.has_file(cubin_name) or \
|
||||||
|
not fn_cache_manager.has_file(data_name) or \
|
||||||
|
not fn_cache_manager.has_file(ptx_name):
|
||||||
|
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
|
||||||
|
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||||
|
fn_cache_manager.put(cubin, cubin_name)
|
||||||
|
fn_cache_manager.put(ptx, ptx_name, binary=False)
|
||||||
|
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||||
|
|
||||||
|
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class CompiledKernel:
|
||||||
|
|
||||||
|
def __init__(self, fn_name, so_path, cache_dir):
|
||||||
|
|
||||||
|
# initialize launcher
|
||||||
|
import importlib.util
|
||||||
|
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
self.c_wrapper = getattr(mod, "launch")
|
||||||
|
# initialize metadata
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
self.shared = metadata["shared"]
|
||||||
|
self.num_warps = metadata["num_warps"]
|
||||||
|
self.num_stages = metadata["num_stages"]
|
||||||
|
# initialize asm dict
|
||||||
|
self.asm = dict()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||||
|
self.asm["cubin"] = f.read()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||||
|
self.asm["ptx"] = f.read()
|
||||||
|
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||||
|
self.cu_module = mod
|
||||||
|
self.cu_function = func
|
||||||
|
|
||||||
|
def __getitem__(self, grid):
|
||||||
|
def runner(*args, stream=None):
|
||||||
|
if stream is None:
|
||||||
|
stream = torch.cuda.current_stream().cuda_stream
|
||||||
|
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||||
|
return
|
||||||
|
@@ -1,2 +1,2 @@
|
|||||||
from .autotuner import Config, autotune, heuristics # noqa: F401
|
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
||||||
from .jit import JITFunction, build_kernel, jit, launch_kernel # noqa: F401
|
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
||||||
|
@@ -5,10 +5,11 @@ import time
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from ..testing import do_bench
|
from ..testing import do_bench
|
||||||
|
from .jit import KernelInterface
|
||||||
|
|
||||||
|
|
||||||
class Autotuner:
|
class Autotuner(KernelInterface):
|
||||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||||
'''
|
'''
|
||||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||||
@@ -21,7 +22,6 @@ class Autotuner:
|
|||||||
self.configs = configs
|
self.configs = configs
|
||||||
self.key_idx = [arg_names.index(k) for k in key]
|
self.key_idx = [arg_names.index(k) for k in key]
|
||||||
self.cache = dict()
|
self.cache = dict()
|
||||||
self.kernel = kernel
|
|
||||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||||
self.hook = lambda args: 0
|
self.hook = lambda args: 0
|
||||||
if reset_to_zero is not None:
|
if reset_to_zero is not None:
|
||||||
@@ -41,6 +41,7 @@ class Autotuner:
|
|||||||
perf_model, top_k, early_config_prune = None, None, None
|
perf_model, top_k, early_config_prune = None, None, None
|
||||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||||
self.early_config_prune = early_config_prune
|
self.early_config_prune = early_config_prune
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
def _bench(self, *args, config, **meta):
|
def _bench(self, *args, config, **meta):
|
||||||
# check for conflicts, i.e. meta-parameters both provided
|
# check for conflicts, i.e. meta-parameters both provided
|
||||||
@@ -58,25 +59,16 @@ class Autotuner:
|
|||||||
if config.pre_hook:
|
if config.pre_hook:
|
||||||
config.pre_hook(self.nargs)
|
config.pre_hook(self.nargs)
|
||||||
self.hook(args)
|
self.hook(args)
|
||||||
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||||
return do_bench(kernel_call)
|
return do_bench(kernel_call)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
self.nargs = dict(zip(self.arg_names, args))
|
self.nargs = dict(zip(self.arg_names, args))
|
||||||
if len(self.configs) > 1:
|
if len(self.configs) > 1:
|
||||||
key = tuple([args[i] for i in self.key_idx])
|
key = tuple([args[i] for i in self.key_idx])
|
||||||
if key not in self.cache:
|
if key not in self.cache:
|
||||||
# prune configs
|
# prune configs
|
||||||
pruned_configs = self.configs
|
pruned_configs = self.prune_configs(kwargs)
|
||||||
if self.early_config_prune:
|
|
||||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
|
||||||
if self.perf_model:
|
|
||||||
top_k = self.configs_top_k
|
|
||||||
if isinstance(top_k, float) and top_k <= 1.0:
|
|
||||||
top_k = int(len(self.configs) * top_k)
|
|
||||||
if len(pruned_configs) > top_k:
|
|
||||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
|
||||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
|
||||||
bench_start = time.time()
|
bench_start = time.time()
|
||||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||||
for config in pruned_configs}
|
for config in pruned_configs}
|
||||||
@@ -91,13 +83,41 @@ class Autotuner:
|
|||||||
self.best_config = config
|
self.best_config = config
|
||||||
if config.pre_hook is not None:
|
if config.pre_hook is not None:
|
||||||
config.pre_hook(self.nargs)
|
config.pre_hook(self.nargs)
|
||||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||||
|
|
||||||
|
def prune_configs(self, kwargs):
|
||||||
|
pruned_configs = self.configs
|
||||||
|
if self.early_config_prune:
|
||||||
|
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||||
|
if self.perf_model:
|
||||||
|
top_k = self.configs_top_k
|
||||||
|
if isinstance(top_k, float) and top_k <= 1.0:
|
||||||
|
top_k = int(len(self.configs) * top_k)
|
||||||
|
if len(pruned_configs) > top_k:
|
||||||
|
est_timing = {
|
||||||
|
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||||
|
num_warps=config.num_warps)
|
||||||
|
for config in pruned_configs
|
||||||
|
}
|
||||||
|
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||||
|
return pruned_configs
|
||||||
|
|
||||||
|
def warmup(self, *args, **kwargs):
|
||||||
|
self.nargs = dict(zip(self.arg_names, args))
|
||||||
|
for config in self.prune_configs(kwargs):
|
||||||
|
self.fn.warmup(
|
||||||
|
*args,
|
||||||
|
num_warps=config.num_warps,
|
||||||
|
num_stages=config.num_stages,
|
||||||
|
**kwargs,
|
||||||
|
**config.kwargs,
|
||||||
|
)
|
||||||
|
self.nargs = None
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""
|
"""
|
||||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||||
|
|
||||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||||
:type meta: dict[Str, Any]
|
:type meta: dict[Str, Any]
|
||||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||||
@@ -129,10 +149,8 @@ class Config:
|
|||||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||||
"""
|
"""
|
||||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||||
|
|
||||||
.. highlight:: python
|
.. highlight:: python
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@triton.autotune(configs=[
|
@triton.autotune(configs=[
|
||||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||||
@@ -143,12 +161,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(x_ptr, x_size, **META):
|
def kernel(x_ptr, x_size, **META):
|
||||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||||
|
|
||||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||||
This means that whatever value the kernel updates will be updated multiple times.
|
This means that whatever value the kernel updates will be updated multiple times.
|
||||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||||
reset the value of the provided tensor to `zero` before running any configuration.
|
reset the value of the provided tensor to `zero` before running any configuration.
|
||||||
|
|
||||||
:param configs: a list of :code:`triton.Config` objects
|
:param configs: a list of :code:`triton.Config` objects
|
||||||
:type configs: list[triton.Config]
|
:type configs: list[triton.Config]
|
||||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||||
@@ -161,43 +177,39 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
|||||||
:type reset_to_zero: list[str]
|
:type reset_to_zero: list[str]
|
||||||
"""
|
"""
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
def wrapper(kernel):
|
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
|
||||||
|
|
||||||
fn.kernel_decorators.append(wrapper)
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Heuristics(KernelInterface):
|
||||||
|
|
||||||
|
def __init__(self, fn, arg_names, values) -> None:
|
||||||
|
self.fn = fn
|
||||||
|
self.values = values
|
||||||
|
self.arg_names = arg_names
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
for v, heur in self.values.items():
|
||||||
|
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
|
||||||
|
return self.fn.run(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def heuristics(values):
|
def heuristics(values):
|
||||||
"""
|
"""
|
||||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||||
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
||||||
|
|
||||||
.. highlight:: python
|
.. highlight:: python
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(x_ptr, x_size, **META):
|
def kernel(x_ptr, x_size, **META):
|
||||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||||
|
|
||||||
|
|
||||||
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||||
each such function takes a list of positional arguments as input.
|
each such function takes a list of positional arguments as input.
|
||||||
.type values: dict[str, Callable[[list[Any]], Any]]
|
.type values: dict[str, Callable[[list[Any]], Any]]
|
||||||
"""
|
"""
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
def wrapper(kernel):
|
return Heuristics(fn, fn.arg_names, values)
|
||||||
def fun(*args, **meta):
|
|
||||||
for v, heur in values.items():
|
|
||||||
assert v not in meta
|
|
||||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
|
||||||
return kernel(*args, **meta)
|
|
||||||
return fun
|
|
||||||
|
|
||||||
fn.kernel_decorators.append(wrapper)
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations, division
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import functools
|
import functools
|
||||||
@@ -6,90 +6,24 @@ import hashlib
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, Dict, List, Optional
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton._C.libtriton.triton as _triton
|
from triton.utils import MockTensor
|
||||||
from ..compiler import compile
|
|
||||||
from ..tools.disasm import extract
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||||
except ImportError:
|
except ImportError:
|
||||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Binary
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
VALID_BACKENDS: List[str] = (
|
|
||||||
_triton.runtime.backend.CUDA,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Binary:
|
|
||||||
def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
|
|
||||||
assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
|
|
||||||
self.backend = backend
|
|
||||||
self.name = name
|
|
||||||
self.asm = asm
|
|
||||||
self.shared_mem = shared_mem
|
|
||||||
self.num_warps = num_warps
|
|
||||||
|
|
||||||
|
|
||||||
class LoadedBinary:
|
|
||||||
def __init__(self, device: int, bin: Binary):
|
|
||||||
module, kernel = _triton.load_binary(bin.backend,
|
|
||||||
bin.name,
|
|
||||||
bin.asm,
|
|
||||||
bin.shared_mem,
|
|
||||||
device)
|
|
||||||
self.bin = bin
|
|
||||||
self.asm = bin.asm
|
|
||||||
self.sass = ''
|
|
||||||
self.module = module
|
|
||||||
self.kernel = kernel
|
|
||||||
self.device = device
|
|
||||||
self.shared_mem = bin.shared_mem
|
|
||||||
|
|
||||||
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
|
|
||||||
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
|
|
||||||
grid_0, grid_1, grid_2,
|
|
||||||
self.bin.num_warps * 32, 1, 1,
|
|
||||||
args, self.bin.shared_mem)
|
|
||||||
|
|
||||||
def get_sass(self, fun=None):
|
|
||||||
if self.sass:
|
|
||||||
return self.sass
|
|
||||||
fd, path = tempfile.mkstemp()
|
|
||||||
try:
|
|
||||||
with open(fd, 'wb') as cubin:
|
|
||||||
cubin.write(self.asm['cubin'])
|
|
||||||
self.sass = extract(path, fun)
|
|
||||||
finally:
|
|
||||||
os.remove(path)
|
|
||||||
self.asm['sass'] = self.sass
|
|
||||||
return self.sass
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Kernel
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class Kernel:
|
|
||||||
|
|
||||||
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
|
|
||||||
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Dependencies Finder
|
# Dependencies Finder
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class DependenciesFinder(ast.NodeVisitor):
|
class DependenciesFinder(ast.NodeVisitor):
|
||||||
"""
|
"""
|
||||||
This AST visitor is used to find dependencies of a JITFunction. This can
|
This AST visitor is used to find dependencies of a JITFunction. This can
|
||||||
@@ -142,6 +76,8 @@ def version_key():
|
|||||||
# frontend
|
# frontend
|
||||||
with open(__file__, "rb") as f:
|
with open(__file__, "rb") as f:
|
||||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||||
|
with open(triton.compiler.__file__, "rb") as f:
|
||||||
|
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||||
# backend
|
# backend
|
||||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||||
@@ -158,26 +94,213 @@ def version_key():
|
|||||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||||
|
|
||||||
|
|
||||||
class JITFunction:
|
class KernelInterface:
|
||||||
|
|
||||||
|
def __getitem__(self, grid):
|
||||||
|
"""
|
||||||
|
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||||
|
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||||
|
memorizes the grid.
|
||||||
|
"""
|
||||||
|
def launcher(*args, **kwargs):
|
||||||
|
return self.run(*args, grid=grid, **kwargs)
|
||||||
|
return launcher
|
||||||
|
|
||||||
|
|
||||||

+class JITFunction(KernelInterface):

     cache_hook = None
+    divisibility = 16

-    def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
-        # information of wrapped function
+    @staticmethod
+    def _key_of(arg):
+        if hasattr(arg, "dtype"):
+            return arg.dtype
+        elif isinstance(arg, bool):
+            return "i1"
+        elif isinstance(arg, int):
+            if -2**31 <= arg and arg <= 2**31 - 1:
+                return "i32"
+            elif 2**31 <= arg and arg <= 2**32 - 1:
+                return "u32"
+            elif 2**63 <= arg and arg <= 2**64 - 1:
+                return "u64"
+            else:
+                return "i64"
+        elif isinstance(arg, float):
+            return 'fp32'
+        elif arg is None:
+            return None
+        else:
+            raise TypeError(f'Unsupported type {type(arg)} for {arg}')
+
+    @staticmethod
+    def _spec_of(arg):
+        if hasattr(arg, "data_ptr"):
+            return (arg.data_ptr() % JITFunction.divisibility == 0)
+        elif isinstance(arg, int):
+            return (arg % 16 == 0, arg == 1)
+        return (arg is None, )
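Together, `_key_of` and `_spec_of` determine how a call is bucketed in the compilation cache: dtypes and integer widths form the signature part of the key, while pointer alignment and "equals 1" integers drive specialization. Roughly, assuming `x` is a CUDA tensor:

x = torch.empty(128, device="cuda", dtype=torch.float16)
JITFunction._key_of(x)       # torch.float16 -- tensors are keyed by dtype
JITFunction._key_of(7)       # "i32" (fits the signed 32-bit range)
JITFunction._key_of(2**40)   # "i64"
JITFunction._spec_of(x)      # True iff x.data_ptr() is divisible by 16
JITFunction._spec_of(1)      # (False, True) -- not divisible by 16, equal to 1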

+    def _get_config(self, *args):
+        def is_divisible_by_16(x):
+            if hasattr(x, "data_ptr"):
+                return x.data_ptr() % JITFunction.divisibility == 0
+            elif isinstance(x, int):
+                return x % JITFunction.divisibility == 0
+            if x is None:
+                return True
+            return False
+        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
+        equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
+        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
+        # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
+
+    @staticmethod
+    def _type_of(key):
+        if isinstance(key, (torch.dtype, triton.language.dtype)):
+            ty = {
+                torch.bool: 'i1',
+                torch.float16: 'fp16',
+                torch.bfloat16: 'bf16',
+                torch.float32: 'fp32',
+                torch.float64: 'fp64',
+                torch.uint8: 'u8',
+                torch.int8: 'i8',
+                torch.int16: 'i16',
+                torch.int32: 'i32',
+                torch.int64: 'i64',
+
+                triton.language.uint8: 'u8',
+                triton.language.uint16: 'u16',
+                triton.language.uint32: 'u32',
+                triton.language.uint64: 'u64',
+                triton.language.float8: 'fp8',
+            }[key]
+            return f'*{ty}'
+        if key is None:
+            return '*i8'
+        assert isinstance(key, str)
+        return key
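`_type_of` maps a cache-key entry back to the type string handed to the compiler: tensor dtypes become pointer types, scalar keys pass through unchanged. For instance:

JITFunction._type_of(torch.float16)           # '*fp16'
JITFunction._type_of(triton.language.uint32)  # '*u32'
JITFunction._type_of("i32")                   # 'i32' -- scalar keys are already type strings
JITFunction._type_of(None)                    # '*i8' -- None is treated as a byte pointer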

+    def _make_signature(self, sig_key):
+        signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
+        return signature
+
+    def _make_constants(self, constexpr_key):
+        constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
+        return constants
+
+    def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
+        if JITFunction.cache_hook is None:
+            return False
+        name = self.fn.__name__
+        module = self.fn.__module__
+        arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
+        repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
+        key = str(key)
+
+        class LegacyCompiler:
+            def __init__(self, module, name):
+                self.module = module
+                self.name = name
+                pass
+
+        kwargs = dict(signature=signature, device=device, constants=constants,
+                      num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
+                      configs=configs)
+
+        return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)
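`_call_hook` is the user-facing interception point: when `JITFunction.cache_hook` is set, it is invoked right before compilation with the cache key, a human-readable `repr`, and the compile options, and a truthy return value skips compilation. A hedged sketch of a hook that only logs (field names follow the keyword arguments passed above; assumes `JITFunction` is imported from `triton.runtime.jit`):

def log_compile(key, repr, fn, compile, is_manual_warmup, already_compiled, **kwargs):
    print(f"about to compile {repr} (cache key: {key})")
    return False  # returning True would suppress compilation

JITFunction.cache_hook = log_compile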

+    def _make_launcher(self):
+        regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
+        constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
+        args = ', '.join(regular_args)
+        # cache key for regular argument type
+        sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
+        # cache key for constexpr argument values
+        constexpr_keys = ', '.join(constexpr_args)
+        # cache key for argument specialization
+        specializations = []
+        for i, arg in enumerate(regular_args):
+            if i in self.do_not_specialize:
+                continue
+            specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
+                                f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
+                                f'else (False,)']
+        spec_keys = ', '.join(specializations)
+        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
+
+        src = f"""
+def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
+    sig_key = {sig_keys},
+    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
+    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
+    key = (version_key, sig_key, constexpr_key, spec_key)
+    if not extern_libs is None:
+        key = (key, tuple(extern_libs.items()))
+    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
+    if callable(grid):
+        grid = grid({{{grid_args}}})
+    grid_size = len(grid)
+    grid_0 = grid[0]
+    grid_1 = grid[1] if grid_size > 1 else 1
+    grid_2 = grid[2] if grid_size > 2 else 1
+    device = torch.cuda.current_device()
+    torch.cuda.set_device(device)
+    if stream is None and not warmup:
+        stream = get_cuda_stream(device)
+    try:
+        bin = cache[key]
+        if not warmup:
+            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
+        return bin
+    # kernel not cached -- compile
+    except KeyError:
+        # build dict of constant values
+        args = [{args}]
+        configs = self._get_config(*args),
+        constants = self._make_constants(constexpr_key)
+        constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
+        constants.update({{i: 1 for i in configs[0].equal_to_1}})
+        # build kernel signature -- doesn't include specialized arguments
+        all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
+        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
+        # build stub signature -- includes arguments that are specialized
+        for i, arg in constants.items():
+            if callable(arg):
+                raise TypeError(f"Callable constexpr at index {i} is not supported")
+        device = 0
+        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
+            bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
+            if not warmup:
+                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
+            self.cache[key] = bin
+            return bin
+        return None
+"""
+        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
+                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
+                 "cache": self.cache, "triton": triton, "torch": torch}
+        exec(src, scope)
+        return scope[self.fn.__name__]
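`_make_launcher` meta-programs a per-kernel launcher: it renders the f-string above into source text and `exec`s it, so the hot path is one specialized function rather than generic argument introspection. For a hypothetical kernel `def add(X, Y, N, BLOCK: tl.constexpr)` with tensors `X`, `Y` and an integer `N`, the generated `src` would look roughly like this (simplified sketch, not the exact output):

def add(X, Y, N, BLOCK, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
    sig_key = _key_of(X), _key_of(Y), _key_of(N),
    constexpr_key = BLOCK,
    spec_key = (X.data_ptr() % 16 == 0), (Y.data_ptr() % 16 == 0), (N % 16 == 0, N == 1),
    key = (version_key, sig_key, constexpr_key, spec_key)
    # ... resolve the grid, look `key` up in `cache`, launch the cached binary,
    # or fall back to triton.compile and store the result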

+    def __init__(self, fn, version=None, do_not_specialize=None):
         self.fn = fn
         self.module = fn.__module__
+        self.version = version
+        # function signature information
         signature = inspect.signature(fn)
         self.arg_names = [v.name for v in signature.parameters.values()]
-        self.arg_defaults = [v.default for v in signature.parameters.values()]
+        self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
-        self.version = version
-        self.inline = inline
+        # specialization hints
+        self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
+        self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
+        # function source code (without decorators)
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.src = self.src[self.src.find("def"):]
-        self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
-        self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
-        # cache for callable driver objects (e.g. CUkernel)
-        self.bin_cache = dict()
+        # cache of just-in-time compiled kernels
+        self.cache = dict()
         self.hash = None
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
@@ -186,16 +309,17 @@ class JITFunction:
         # annotations
         self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
         self.__annotations__ = fn.__annotations__
-        # constexprs
+        # index of constexprs
         self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
-        # forward docs
+        # launcher
+        self.run = self._make_launcher()
+        # re-use docs of wrapped function
         self.__doc__ = fn.__doc__
         self.__name__ = fn.__name__
         self.__globals__ = fn.__globals__
         self.__module__ = fn.__module__

     @property
-    @functools.lru_cache()
     def cache_key(self):
         # TODO : hash should be attribute of `self`
         if self.hash is None:
@@ -204,9 +328,12 @@ class JITFunction:
             self.hash = dependencies_finder.ret + version_key()
         return self.hash

+    def warmup(self, *args, **kwargs):
+        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
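Two details of the interface defined above are easy to miss: `do_not_specialize` accepts argument names or indices (the constructor normalizes both to indices), and `warmup` compiles without launching, wrapping plain dtypes in `MockTensor` (added near the end of this diff) so no real buffer is needed. A hedged sketch, assuming a CUDA device is available:

@triton.jit(do_not_specialize=["N"])   # N's value won't be folded into the specialization key
def scale(X, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    tl.store(X + offs, x * 2, mask=mask)

# compile ahead of time without launching; torch.float32 stands in for a real pointer
scale.warmup(torch.float32, 4096, grid=(4,), BLOCK=1024)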
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
-    # Some unit tests do this, for example.
+    # Our unit tests do this, for example.
     def parse(self):
         tree = ast.parse(self.src)
         assert isinstance(tree, ast.Module)
@@ -217,167 +344,21 @@ class JITFunction:
     def __call__(self, *args, **kwargs):
         raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

-    # - when `.src` attribute is set, cache path needs
-    # to be reinitialized
+    def __setattr__(self, name, value):
     # - when kernel decorators change, cached kernel
     # needs to be cleared
-    def __setattr__(self, name, value):
         if name == 'kernel_decorators':
             self.kernel = None
         super(JITFunction, self).__setattr__(name, value)
+        # - when `.src` attribute is set, cache path needs
+        # to be reinitialized
         if name == 'src':
             self.hash = None
-            JITFunction.cache_key.fget.cache_clear()
-
-    def _init_kernel(self):
-        if self.kernel is None:
-            self.kernel = Kernel(self)
-            for decorator in reversed(self.kernel_decorators):
-                self.kernel = decorator(self.kernel)
-        return self.kernel
-
-    def __getitem__(self, grid):
-        """
-        A JIT function is launched with: fn[grid](*args, **kwargs).
-        Hence JITFunction.__getitem__ returns a callable proxy that
-        memorizes the grid.
-        """
-        class Launcher:
-            def __init__(self, kernel, grid):
-                self.kernel = kernel
-                self.grid = grid
-
-            def __call__(self, *wargs, **kwargs):
-                return self.kernel(*wargs, **kwargs, grid=self.grid)
-
-        return Launcher(self._init_kernel(), grid)

     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"
-
-
-def pow2_divisor(N):
-    if N % 16 == 0:
-        return 16
-    if N % 8 == 0:
-        return 8
-    if N % 4 == 0:
-        return 4
-    if N % 2 == 0:
-        return 2
-    return 1
-
-
-class _KernelCache:
-    def __init__(self,
-                 fn: JITFunction,
-                 fn_type: str,
-                 constants: Dict[str, Any],
-                 num_warps: int = 4,
-                 num_stages: int = 3):
-        # hold the arguments for building a kernel
-        self.fn = fn
-        self.fn_type = fn_type
-        self.constants = constants
-        self.num_warps = num_warps
-        self.num_stages = num_stages
-
-        # kernel compilation cache
-        self._binary_cache: Optional[LoadedBinary] = None
-
-    @property
-    def binary_cache(self):
-        return self._binary_cache
-
-    def set_binary_cache(self, binary: LoadedBinary):
-        assert binary
-        assert not self._binary_cache, "cannot set binary cache duplicately"
-        self._binary_cache = binary
-
-
-def build_kernel(fn: JITFunction,
-                 fn_type: str,
-                 constants: Dict[str, Any],
-                 num_warps: int = 4,
-                 num_stages: int = 3,
-                 ) -> _KernelCache:
-    return _KernelCache(fn, fn_type, constants, num_warps, num_stages)
-
-
-torch_dtype_to_bytes = {
-    torch.int8: 1,
-    torch.uint8: 1,
-
-    torch.int16: 2,
-    torch.short: 2,
-
-    torch.int: 4,
-    torch.int32: 4,
-
-    torch.long: 8,
-    torch.int64: 8,
-
-    torch.float32: 4,
-    torch.float: 4,
-
-    torch.float16: 2,
-    torch.half: 2,
-    torch.bfloat16: 2,
-    # free to extend
-}
-
-
-def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
-    def is_tensor(arg):
-        return hasattr(arg, 'data_ptr')  # a torch.tensor
-
-    # prepare function args for compile
-    kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
-    wargs = list(wargs)
-    for i, pos in enumerate(sorted(kwargs)):
-        wargs.insert(pos + i, kwargs[pos])
-    assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs))
-
-    if not kernel.binary_cache:
-        # build the kernel cache
-        backend = _triton.runtime.backend.CUDA
-
-        attributes = dict()
-        for i, arg in enumerate(wargs):
-            if i in kernel.fn.do_not_specialize:
-                continue
-            if isinstance(arg, int):
-                attributes[i] = pow2_divisor(arg)
-            elif is_tensor(arg):
-                assert arg.dtype in torch_dtype_to_bytes
-                addr = arg.data_ptr()
-                range_size = _triton.runtime.get_pointer_range_size(addr)
-                divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
-                attributes[i] = divisibility
-
-        attributes_ = dict()
-        for i, value in attributes.items():
-            attributes_[kernel.fn.arg_names[i]] = value
-
-        cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
-        assert cubin
-        assert kernel_name
-
-        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
-        assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
-
-        asm = dict(cubin=cubin)
-        binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
-        loaded_binary = LoadedBinary(device, binary)
-        kernel.set_binary_cache(loaded_binary)
-
-    torch.cuda.set_device(device)
-    stream = get_cuda_stream(device)
-
-    _triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
-                                  stream, kernel.num_warps, kernel.num_stages, grid)


 # -----------------------------------------------------------------------------
 # `jit` decorator
 # -----------------------------------------------------------------------------
@@ -386,16 +367,12 @@ def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
 def jit(*args, **kwargs):
     """
     Decorator for JIT-compiling a function using the Triton compiler.

     :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.

     :note: This function will be compiled and run on the GPU. It will only have access to:

         * python primitives,
         * objects within the triton.language package,
         * arguments to this function,
         * other jit'd functions

     :param fn: the function to be jit-compiled
     :type fn: Callable
     """
@@ -407,3 +384,32 @@ def jit(*args, **kwargs):
     def decorator(fn):
         return JITFunction(fn, **kwargs)
     return decorator
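An end-to-end sketch of the decorator in use, consistent with the docstring above (the kernel body only touches its arguments and `triton.language`); names and sizes here are illustrative:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(X, Y, Z, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    tl.store(Z + offs, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
z = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(4096, meta["BLOCK"]),)
add_kernel[grid](x, y, z, 4096, BLOCK=1024)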

+
+class TensorWrapper:
+    def __init__(self, base, dtype):
+        self.dtype = dtype
+        self.base = base
+        self.is_cuda = base.is_cuda
+        self.device = base.device
+
+    def data_ptr(self):
+        return self.base.data_ptr()
+
+    def __str__(self) -> str:
+        return f'TensorWrapper[{self.dtype}]({self.base})'
+
+
+def reinterpret(tensor, dtype):
+    if isinstance(tensor, TensorWrapper):
+        if dtype == tensor.base.dtype:
+            # Reinterpreting to the original interpretation; return the base.
+            return tensor.base
+        else:
+            # Reinterpreting a wrapped tensor to a different type.
+            return TensorWrapper(tensor.base, dtype)
+    elif isinstance(tensor, torch.Tensor):
+        # A new wrapper is needed around an unwrapped tensor.
+        return TensorWrapper(tensor, dtype)
+    else:
+        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
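`reinterpret` gives a kernel a different element-type view of an existing buffer without copying: the wrapper only changes the dtype that `_key_of`/`_type_of` see, while `data_ptr()` still returns the original address. A sketch (assuming `reinterpret` is re-exported as `triton.reinterpret`, as the tests use it):

x = torch.randint(0, 255, (1024,), device="cuda", dtype=torch.int32)
x_u8 = triton.reinterpret(x, tl.uint8)   # same storage, now seen as a *u8 pointer
assert x_u8.data_ptr() == x.data_ptr()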
@@ -19,6 +19,24 @@ def next_power_of_2(n):
     return n


+class MockTensor:
+    """
+    Can be used in place of real tensors when calling:
+        kernel.warmup(MockTensor(torch.float32), ...)
+    """
+    @staticmethod
+    def wrap_dtype(arg):
+        if isinstance(arg, torch.dtype):
+            return MockTensor(arg)
+        return arg
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return 0  # optimistically assumes multiple of 16
+
+
 class TensorWrapper:
     def __init__(self, base, dtype):
         self.dtype = dtype