[FRONTEND] add python e2e launch empty kernel test (#68)

This commit is contained in:
Yan Chunwei
2022-08-20 01:46:01 +08:00
committed by GitHub
parent 9aa00249a6
commit 10ba51c3bb
11 changed files with 311 additions and 69 deletions

View File

@@ -19,7 +19,9 @@ Type i8Ty(MLIRContext *ctx) {
Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Signless);
}
Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }
Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }

View File

@@ -2,8 +2,6 @@
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

View File

@@ -171,6 +171,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
return failure();
auto ctx = funcOp->getContext();
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr(NVVMMetadataField::Kernel,
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr(

View File

@@ -6,9 +6,11 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
//
void parse_args(py::list &args, py::list &arg_names, std::string &params,
size_t &params_size, py::dict constants) {
char *params_ptr = params.data();
size_t len = PyList_Size(args.ptr());
for (int i = 0; i < len; i++) {
py::object arg = args[i];
auto arg_ptr = arg.ptr();
if (PyLong_Check(arg_ptr)) {
int overflow{};
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
continue;
}
if (PyFloat_Check(arg_ptr)) {
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is torch.tensor, get data_ptr as memory address
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// udpate cache key
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
continue;
}
// argument is `LoadedBinary`
if (py::hasattr(arg, "get_sass")) {
// Do nothing, just a placeholder here to indicate validity.
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
void init_triton_runtime(py::module &&m) {
// m.def("current_stream", [](uint64_t device){
// return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
// });
// wrap backend_t
py::enum_<backend_t>(m, "backend")
.value("HOST", HOST)
.value("CUDA", CUDA)
.value("ROCM", ROCM)
// .value("ROCM", ROCM)
.export_values();
// enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
return -1;
});
m.def("launch_binary", [](py::object binary, py::list args,
py::list do_not_specialize, py::list arg_names,
py::int_ stream, py::int_ num_warps,
py::int_ num_stages, py::object grid) {
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
// get grid
py::sequence seq;
py::dict constants;
std::string params;
size_t params_size{};
parse_args(args, arg_names, params, params_size, constants);
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
const int numGrids = grid_0 * grid_1 * grid_2;
if (numGrids) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return binary;
});
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
});
}
void init_translation(py::module &m) {
void init_triton_translation(py::module &m) {
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
llvm::LLVMContext llvmContext;
auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
});
m.def("translate_triton_gpu_to_ptx",
[](mlir::ModuleOp module, uint64_t device) -> std::string {
[](mlir::ModuleOp module, uint64_t device)
-> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
auto [ptxCode, cc, version, ptxasPath] =
triton::translateTritonGPUToPTX(module, device);
return ptxCode;
mlir::PassManager pm(module->getContext());
auto pass = std::make_unique<mlir::Allocation>(module);
size_t size = pass->getSharedMemorySize();
return std::make_tuple(ptxCode, size);
});
m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
py::bytes bytes(cubin);
return bytes;
});
m.def(
"load_binary",
[](backend_t backend, const std::string &name, asm_map_t &asm_map,
size_t n_shared_bytes, uint64_t dev) {
py::gil_scoped_release allow_threads;
assert(backend == CUDA); // Only CUDA is supported now.
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
},
py::return_value_policy::take_ownership);
}
void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@ void init_triton(py::module &m) {
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_translation(subm);
init_triton_translation(subm);
}

View File

@@ -1,30 +0,0 @@
import triton
import triton.language as tl
NUM_WARPS = 4
# triton kernel
@triton.jit
def kernel(x_ptr, stride_xn,
y_ptr, stride_yn,
z_ptr, stride_zn,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
print(ret)
# TODO: base class for python end2end tests,
# runtime execution, correctness comparison etc.

View File

@@ -2,18 +2,21 @@ import torch
import triton
import triton.language as tl
import triton.runtime as runtime
# trigger the torch.device implicitly to ensure cuda context initialization
torch.zeros([10], device=torch.device('cuda'))
@triton.jit
def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
pass
def test_empty_kernel_cubin_compile():
@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass
device = torch.cuda.current_device()
cubin = triton.compile(kernel,
cubin = triton.compile(empty_kernel,
"*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
@@ -21,3 +24,25 @@ def test_empty_kernel_cubin_compile():
print('cubin size:', len(cubin))
assert len(cubin) > 0
def test_empty_kernel_launch():
device = torch.cuda.current_device()
binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
num_warps=4,
num_stages=3)
grid = lambda META: (
triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
)
A = torch.zeros([1024], device="cuda")
runtime.launch_kernel(fn=empty_kernel,
binary=binary,
grid=grid,
num_warps=4,
num_stages=3,
X=A,
stride_xm=256,
BLOCK=tl.constexpr(256))

View File

@@ -0,0 +1,27 @@
import triton
import triton.language as tl
NUM_WARPS = 4
# triton kernel
def test_vecadd_no_scf():
@triton.jit
def kernel(x_ptr, stride_xn,
y_ptr, stride_yn,
z_ptr, stride_zn,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
print(ret)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import ast
import sys
import warnings
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
import triton
import triton._C.libtriton.triton as _triton
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
return mod
def make_ptx(mod, device):
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
:return: str
:return:
- PTX code
- shared memory alloaction size
'''
return _triton.translate_triton_gpu_to_ptx(mod, device)
@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
return _triton.compile_ptx_to_cubin(ptx, device)
def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
def ptx_get_kernel_name(ptx: str) -> str:
'''
Get kernel name from PTX code.
This Kernel name is required when launching the kernel.
'''
# There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
assert ptx
for line in ptx.split('\n'):
line = line.strip()
if line.startswith('// .globl'):
return line.split()[-1]
def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
assert device >= 0, "device should be provided."
ptx = make_ptx(module, device)
ptx, shem_size = make_ptx(module, device)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx
return ptx, shem_size, kernel_name
cubin = make_cubin(ptx, device)
if output == "cubin":
return cubin
return cubin, shem_size, kernel_name
assert False

View File

@@ -1,2 +1,2 @@
from .autotuner import Config, autotune, heuristics # noqa: F401
from .jit import JITFunction, jit # noqa: F401
from .jit import JITFunction, build_kernel, jit, launch_kernel # noqa: F401

View File

@@ -8,18 +8,32 @@ import os
import subprocess
import tempfile
import textwrap
from typing import Any, Dict, List
import torch
import triton
import triton._C.libtriton.triton as _triton
from ..compiler import compile
from ..tools.disasm import extract
try:
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
# -----------------------------------------------------------------------------
# Binary
# -----------------------------------------------------------------------------
VALID_BACKENDS: List[str] = (
_triton.runtime.backend.CUDA,
)
class Binary:
def __init__(self, backend, name, asm, shared_mem, num_warps):
def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
self.backend = backend
self.name = name
self.asm = asm
@@ -29,11 +43,11 @@ class Binary:
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
module, kernel = _triton.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
self.bin = bin
self.asm = bin.asm
self.sass = ''
@@ -241,6 +255,44 @@ class JITFunction:
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
def build_kernel(fn: JITFunction,
fn_type: str,
device: int,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3,
) -> LoadedBinary:
cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
assert cubin
assert kernel_name
backend = _triton.runtime.backend.CUDA
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
asm = dict(cubin=cubin)
binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
loaded_binary = LoadedBinary(device, binary)
return loaded_binary
def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs))
device = torch.cuda.current_device()
torch.cuda.set_device(device)
stream = get_cuda_stream(device)
_triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
stream, num_warps, num_stages, grid)
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------

View File

@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
@@ -58,7 +58,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
// TODO: Pending on the support of isSplat constant
// TODO: Pending on the support of isSplat constant
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
@@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
return
}
}
@@ -116,7 +116,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addf
// CHECK-LABEL: basic_addf
func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
// CHECK: llvm.fadd
// CHECK: llvm.fadd
@@ -141,7 +141,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
// CHECK-LABEL: basic_program_id
func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32