[FRONTEND] add python e2e launch empty kernel test (#68)

Yan Chunwei
2022-08-20 01:46:01 +08:00
committed by GitHub
parent 9aa00249a6
commit 10ba51c3bb
11 changed files with 311 additions and 69 deletions

View File

@@ -19,7 +19,9 @@ Type i8Ty(MLIRContext *ctx) {
 Type u32Ty(MLIRContext *ctx) {
   return IntegerType::get(ctx, 32, IntegerType::Signless);
 }
-Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }
+Type u1Ty(MLIRContext *ctx) {
+  return IntegerType::get(ctx, 1, IntegerType::Unsigned);
+}
 // Float types
 Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }

View File

@@ -2,8 +2,6 @@
#include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>

View File

@@ -171,6 +171,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
     return failure();
   auto ctx = funcOp->getContext();
 
+  // Set an attribute to indicate this function is a kernel entry.
+  newFuncOp->setAttr(NVVMMetadataField::Kernel,
+                     rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
+
   // Set an attribute for maxntidx, it could be used in latter LLVM codegen
   // for `nvvm.annotation` metadata.
   newFuncOp->setAttr(

View File

@@ -6,9 +6,11 @@
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
   params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
 }
 
+//
+void parse_args(py::list &args, py::list &arg_names, std::string &params,
+                size_t &params_size, py::dict constants) {
+  char *params_ptr = params.data();
+  size_t len = PyList_Size(args.ptr());
+  for (int i = 0; i < len; i++) {
+    py::object arg = args[i];
+    auto arg_ptr = arg.ptr();
+
+    if (PyLong_Check(arg_ptr)) {
+      int overflow{};
+      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
+      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      } else if (!overflow && 0x8000'0000LL <= value &&
+                 value <= 0xFFFF'FFFFLL) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      } else if (!overflow) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+        std::memcpy(params_ptr, &value, 8);
+        params_ptr += 8;
+      } else {
+        if (PyErr_Occurred()) {
+          throw std::logic_error("An error occurred?");
+        }
+        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
+        if (PyErr_Occurred()) {
+          throw std::runtime_error("integer overflow in argument: " +
+                                   std::string(py::str(arg)));
+        }
+        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+        std::memcpy(params_ptr, &unsigned_value, 8);
+        params_ptr += 8;
+      }
+      continue;
+    }
+
+    if (PyFloat_Check(arg_ptr)) {
+      float value = PyFloat_AsDouble(arg_ptr);
+      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+      std::memcpy(params_ptr, &value, 4);
+      params_ptr += 4;
+      continue;
+    }
+
+    // argument is `bool`
+    if (PyBool_Check(arg_ptr)) {
+      bool value = arg_ptr == Py_True ? true : false;
+      std::memcpy(params_ptr, &value, 1);
+      params_ptr += 1;
+      continue;
+    }
+
+    // argument is torch.tensor, get data_ptr as memory address
+    if (py::hasattr(arg, "data_ptr")) {
+      py::object data_ptr = arg.attr("data_ptr")();
+      long value = data_ptr.cast<long>();
+      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+      // copy param
+      std::memcpy(params_ptr, &value, 8);
+      params_ptr += 8;
+      // update cache key
+      continue;
+    }
+
+    // argument is `constexpr`
+    if (py::hasattr(arg, "value")) {
+      py::object value = arg.attr("value");
+      py::object name = arg_names[i];
+      constants[name] = value;
+      continue;
+    }
+
+    // argument is `LoadedBinary`
+    if (py::hasattr(arg, "get_sass")) {
+      // Do nothing, just a placeholder here to indicate validity.
+      continue;
+    }
+
+    std::string ty_str =
+        arg.attr("__class__").attr("__name__").cast<std::string>();
+    std::string err_msg = "Received type '" + ty_str + "' for argument " +
+                          std::to_string(i) + "." +
+                          " Only int, float, bool, torch.Tensor, and "
+                          "triton.language.constexpr are supported.";
+    throw std::runtime_error(err_msg);
+  }
+
+  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
+}
 void init_triton_runtime(py::module &&m) {
+  // m.def("current_stream", [](uint64_t device){
+  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
+  // });
+
   // wrap backend_t
   py::enum_<backend_t>(m, "backend")
       .value("HOST", HOST)
       .value("CUDA", CUDA)
-      .value("ROCM", ROCM)
+      // .value("ROCM", ROCM)
       .export_values();
 
   // enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
     return -1;
   });
 
+  m.def("launch_binary", [](py::object binary, py::list args,
+                            py::list do_not_specialize, py::list arg_names,
+                            py::int_ stream, py::int_ num_warps,
+                            py::int_ num_stages, py::object grid) {
+    long _num_warps = PyLong_AsLong(num_warps.ptr());
+    long _num_stages = PyLong_AsLong(num_stages.ptr());
+
+    // get grid
+    py::sequence seq;
+    py::dict constants;
+    std::string params;
+    size_t params_size{};
+    parse_args(args, arg_names, params, params_size, constants);
+    if (!PySequence_Check(grid.ptr()))
+      seq = grid(constants);
+    else
+      seq = grid;
+    int size = seq.size();
+    int grid_0 = py::cast<int>(seq[0]);
+    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
+    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
+
+    uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
+    uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
+
+    // actually launch
+    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
+                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
+                      CU_LAUNCH_PARAM_END};
+    uint64_t _stream = PyLong_AsLong(stream.ptr());
+    const int numGrids = grid_0 * grid_1 * grid_2;
+    if (numGrids) {
+      // release the gil in case the enqueue blocks
+      // cuda will block if too many ops are enqueued
+      py::gil_scoped_release allow_threads;
+      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                    _num_warps * 32, 1, 1, shared_mem,
+                                    (CUstream)_stream, nullptr, config);
+    }
+    return binary;
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
   });
 }
 
-void init_translation(py::module &m) {
+void init_triton_translation(py::module &m) {
   m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
     llvm::LLVMContext llvmContext;
     auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
       });
 
   m.def("translate_triton_gpu_to_ptx",
-        [](mlir::ModuleOp module, uint64_t device) -> std::string {
+        [](mlir::ModuleOp module, uint64_t device)
+            -> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
           auto [ptxCode, cc, version, ptxasPath] =
               triton::translateTritonGPUToPTX(module, device);
-          return ptxCode;
+
+          mlir::PassManager pm(module->getContext());
+          auto pass = std::make_unique<mlir::Allocation>(module);
+          size_t size = pass->getSharedMemorySize();
+
+          return std::make_tuple(ptxCode, size);
         });
 
   m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
         py::bytes bytes(cubin);
         return bytes;
       });
 
+  m.def(
+      "load_binary",
+      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
+         size_t n_shared_bytes, uint64_t dev) {
+        py::gil_scoped_release allow_threads;
+        assert(backend == CUDA); // Only CUDA is supported now.
+        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+      },
+      py::return_value_policy::take_ownership);
 }
 
 void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@ void init_triton(py::module &m) {
   // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
   init_triton_runtime(std::move(subm.def_submodule("runtime")));
   init_triton_ir(std::move(subm.def_submodule("ir")));
-  init_translation(subm);
+  init_triton_translation(subm);
 }
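A rough, non-authoritative sketch (not part of the commit) of the packing rule the new parse_args overload above implements: each scalar is aligned to its own width (1 byte for bool, 4 bytes for i32/fp32, 8 bytes for i64 and pointers) before being appended to the buffer that launch_binary later passes to cuLaunchKernel via CU_LAUNCH_PARAM_BUFFER_POINTER.

# Illustrative Python sketch of the alignment/packing rule (not part of the commit).
import struct

def pack_args(values):
    buf = bytearray()
    for v in values:
        if isinstance(v, bool):                       # 1 byte, no alignment
            buf += struct.pack("?", v)
        elif isinstance(v, int) and -2**31 <= v < 2**32:
            while len(buf) % 4:                       # align to 4 bytes
                buf += b"\x00"
            buf += struct.pack("i" if v < 2**31 else "I", v)
        elif isinstance(v, int):                      # 64-bit ints and pointers
            while len(buf) % 8:                       # align to 8 bytes
                buf += b"\x00"
            buf += struct.pack("q" if v < 2**63 else "Q", v)
        elif isinstance(v, float):                    # packed as 32-bit float
            while len(buf) % 4:
                buf += b"\x00"
            buf += struct.pack("f", v)
        else:
            raise TypeError(f"unsupported argument type: {type(v)}")
    return bytes(buf)

print(len(pack_args([True, 3, 1 << 40, 1.5])))        # 1 + 3 pad + 4 + 8 + 4 = 20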

View File

@@ -1,30 +0,0 @@
-import triton
-import triton.language as tl
-
-NUM_WARPS = 4
-
-# triton kernel
-@triton.jit
-def kernel(x_ptr, stride_xn,
-           y_ptr, stride_yn,
-           z_ptr, stride_zn,
-           BLOCK_SIZE_N: tl.constexpr):
-    pid = tl.program_id(axis=0)
-    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    x_ptrs = x_ptr + offset
-    y_ptrs = y_ptr + offset
-    x = tl.load(x_ptrs)
-    y = tl.load(y_ptrs)
-    z = x + y
-    z_ptrs = z_ptr + offset
-    tl.store(z_ptrs, z)
-
-ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
-print(ret)
-
-# TODO: base class for python end2end tests,
-# runtime execution, correctness comparison etc.

View File

@@ -2,18 +2,21 @@ import torch
 import triton
 import triton.language as tl
+import triton.runtime as runtime
 
 # trigger the torch.device implicitly to ensure cuda context initialization
 torch.zeros([10], device=torch.device('cuda'))
 
+@triton.jit
+def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
+    pass
+
 
 def test_empty_kernel_cubin_compile():
-    @triton.jit
-    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
-        pass
-
     device = torch.cuda.current_device()
-    cubin = triton.compile(kernel,
+    cubin = triton.compile(empty_kernel,
                            "*fp32,i32,i32",
                            device=device,
                            constants={"BLOCK": 256},
@@ -21,3 +24,25 @@ def test_empty_kernel_cubin_compile():
     print('cubin size:', len(cubin))
     assert len(cubin) > 0
+
+
+def test_empty_kernel_launch():
+    device = torch.cuda.current_device()
+    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
+                                  device=device,
+                                  constants={"BLOCK": 256},
+                                  num_warps=4,
+                                  num_stages=3)
+    grid = lambda META: (
+        triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
+    )
+
+    A = torch.zeros([1024], device="cuda")
+    runtime.launch_kernel(fn=empty_kernel,
+                          binary=binary,
+                          grid=grid,
+                          num_warps=4,
+                          num_stages=3,
+                          X=A,
+                          stride_xm=256,
+                          BLOCK=tl.constexpr(256))

View File

@@ -0,0 +1,27 @@
+import triton
+import triton.language as tl
+
+NUM_WARPS = 4
+
+# triton kernel
+def test_vecadd_no_scf():
+    @triton.jit
+    def kernel(x_ptr, stride_xn,
+               y_ptr, stride_yn,
+               z_ptr, stride_zn,
+               BLOCK_SIZE_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+        x = tl.load(x_ptrs)
+        y = tl.load(y_ptrs)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z)
+
+    ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
+    print(ret)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import ast
 import sys
 import warnings
-from typing import Dict, Union
+from typing import Any, Dict, Tuple, Union
 
 import triton
 import triton._C.libtriton.triton as _triton
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
     return mod
 
 
-def make_ptx(mod, device):
+def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
-    :return: str
+    :return:
+        - PTX code
+        - shared memory allocation size
     '''
     return _triton.translate_triton_gpu_to_ptx(mod, device)
@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
     return _triton.compile_ptx_to_cubin(ptx, device)
 
 
-def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
+def ptx_get_kernel_name(ptx: str) -> str:
+    '''
+    Get kernel name from PTX code.
+    This kernel name is required when launching the kernel.
+    '''
+    # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
+    assert ptx
+    for line in ptx.split('\n'):
+        line = line.strip()
+        if line.startswith('// .globl'):
+            return line.split()[-1]
+
+
+def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
 
     # triton-ir
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
     # tritongpu-ir
     module = make_tritongpu_ir(module, num_warps)
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
 
     assert device >= 0, "device should be provided."
-    ptx = make_ptx(module, device)
+    ptx, shem_size = make_ptx(module, device)
+    kernel_name = ptx_get_kernel_name(ptx)
     if output == "ptx":
-        return ptx
+        return ptx, shem_size, kernel_name
     cubin = make_cubin(ptx, device)
     if output == "cubin":
-        return cubin
+        return cubin, shem_size, kernel_name
     assert False
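To make the PTX-name scan in ptx_get_kernel_name above concrete, here is a small self-contained check; the PTX header text and the mangled name are illustrative samples I made up, not output captured from this commit.

# Illustrative check of the '// .globl' scan; the sample PTX and mangled name are hypothetical.
sample_ptx = """
//
// Generated by LLVM NVPTX Back-End
//
.version 7.0
.target sm_80
.address_size 64

	// .globl	empty_kernel_mangled_0d1d2
"""

def get_kernel_name(ptx):
    # Same scan as ptx_get_kernel_name: the first '// .globl' comment carries the mangled name.
    for line in ptx.split('\n'):
        line = line.strip()
        if line.startswith('// .globl'):
            return line.split()[-1]

assert get_kernel_name(sample_ptx) == "empty_kernel_mangled_0d1d2"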

View File

@@ -1,2 +1,2 @@
 from .autotuner import Config, autotune, heuristics  # noqa: F401
-from .jit import JITFunction, jit  # noqa: F401
+from .jit import JITFunction, build_kernel, jit, launch_kernel  # noqa: F401

View File

@@ -8,18 +8,32 @@ import os
 import subprocess
 import tempfile
 import textwrap
+from typing import Any, Dict, List
+
+import torch
 
 import triton
 import triton._C.libtriton.triton as _triton
+
+from ..compiler import compile
 from ..tools.disasm import extract
 
+try:
+    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+except ImportError:
+    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+
 # -----------------------------------------------------------------------------
 # Binary
 # -----------------------------------------------------------------------------
+
+VALID_BACKENDS: List[str] = (
+    _triton.runtime.backend.CUDA,
+)
+
 
 class Binary:
-    def __init__(self, backend, name, asm, shared_mem, num_warps):
+    def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
+        assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
         self.backend = backend
         self.name = name
         self.asm = asm
@@ -29,11 +43,11 @@ class Binary:
 class LoadedBinary:
     def __init__(self, device: int, bin: Binary):
-        module, kernel = _triton.code_gen.load_binary(bin.backend,
-                                                      bin.name,
-                                                      bin.asm,
-                                                      bin.shared_mem,
-                                                      device)
+        module, kernel = _triton.load_binary(bin.backend,
+                                             bin.name,
+                                             bin.asm,
+                                             bin.shared_mem,
+                                             device)
         self.bin = bin
         self.asm = bin.asm
         self.sass = ''
@@ -241,6 +255,44 @@ class JITFunction:
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"
 
+
+def build_kernel(fn: JITFunction,
+                 fn_type: str,
+                 device: int,
+                 constants: Dict[str, Any],
+                 num_warps: int = 4,
+                 num_stages: int = 3,
+                 ) -> LoadedBinary:
+    cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
+    assert cubin
+    assert kernel_name
+
+    backend = _triton.runtime.backend.CUDA
+
+    max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+    assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
+
+    asm = dict(cubin=cubin)
+    binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
+    loaded_binary = LoadedBinary(device, binary)
+    return loaded_binary
+
+
+def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
+    kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
+    wargs = list(wargs)
+    for i, pos in enumerate(sorted(kwargs)):
+        wargs.insert(pos + i, kwargs[pos])
+    assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs))
+
+    device = torch.cuda.current_device()
+    torch.cuda.set_device(device)
+    stream = get_cuda_stream(device)
+
+    _triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
+                                  stream, num_warps, num_stages, grid)
+
+
 # -----------------------------------------------------------------------------
 # `jit` decorator
 # -----------------------------------------------------------------------------
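launch_kernel above folds keyword arguments back into positional order before handing everything to _triton.runtime.launch_binary. The standalone trace below is illustrative only: the argument names follow the new empty_kernel test, and the tensor placeholder is made up.

# Illustrative trace of the kwargs-to-positional merge in launch_kernel (not part of the commit).
arg_names = ["X", "stride_xm", "BLOCK"]            # mirrors empty_kernel(X, stride_xm, BLOCK)
wargs = []                                         # no positional arguments were passed
kwargs = {"X": "tensor_A", "stride_xm": 256, "BLOCK": 256}

kwargs = {arg_names.index(name): value for name, value in kwargs.items()}
for i, pos in enumerate(sorted(kwargs)):
    wargs.insert(pos + i, kwargs[pos])

assert wargs == ["tensor_A", 256, 256]             # now in arg_names order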

View File

@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
 // Here the 128 comes from the 4 in module attribute multiples 32
-// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
+// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
 func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
   // CHECK: llvm.return