[FRONTEND] add python e2e launch empty kernel test (#68)
@@ -19,7 +19,9 @@ Type i8Ty(MLIRContext *ctx) {
Type u32Ty(MLIRContext *ctx) {
  return IntegerType::get(ctx, 32, IntegerType::Signless);
}
Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }
Type u1Ty(MLIRContext *ctx) {
  return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}

// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
@@ -2,8 +2,6 @@
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
@@ -171,6 +171,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
      return failure();

    auto ctx = funcOp->getContext();

    // Set an attribute to indicate this function is a kernel entry.
    newFuncOp->setAttr(NVVMMetadataField::Kernel,
                       rewriter.getIntegerAttr(type::u1Ty(ctx), 1));

    // Set an attribute for maxntidx; it can be used in later LLVM codegen
    // for the `nvvm.annotation` metadata.
    newFuncOp->setAttr(
@@ -6,9 +6,11 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}

//
void parse_args(py::list &args, py::list &arg_names, std::string &params,
                size_t &params_size, py::dict constants) {
  char *params_ptr = params.data();

  size_t len = PyList_Size(args.ptr());
  for (int i = 0; i < len; i++) {
    py::object arg = args[i];
    auto arg_ptr = arg.ptr();

    if (PyLong_Check(arg_ptr)) {
      int overflow{};
      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);

      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow && 0x8000'0000LL <= value &&
                 value <= 0xFFFF'FFFFLL) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow) {
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &value, 8);
        params_ptr += 8;
      } else {
        if (PyErr_Occurred()) {
          throw std::logic_error("An error occurred?");
        }
        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
        if (PyErr_Occurred()) {
          throw std::runtime_error("integer overflow in argument: " +
                                   std::string(py::str(arg)));
        }
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &unsigned_value, 8);
        params_ptr += 8;
      }
      continue;
    }

    if (PyFloat_Check(arg_ptr)) {
      float value = PyFloat_AsDouble(arg_ptr);
      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
      continue;
    }

    // argument is `bool`
    if (PyBool_Check(arg_ptr)) {
      bool value = arg_ptr == Py_True ? true : false;
      std::memcpy(params_ptr, &value, 1);
      params_ptr += 1;
      continue;
    }
    // argument is torch.tensor, get data_ptr as memory address
    if (py::hasattr(arg, "data_ptr")) {
      py::object data_ptr = arg.attr("data_ptr")();
      long value = data_ptr.cast<long>();
      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
      // copy param
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      // update cache key
      continue;
    }
    // argument is `constexpr`
    if (py::hasattr(arg, "value")) {
      py::object value = arg.attr("value");
      py::object name = arg_names[i];
      constants[name] = value;
      continue;
    }
    // argument is `LoadedBinary`
    if (py::hasattr(arg, "get_sass")) {
      // Do nothing, just a placeholder here to indicate validity.
      continue;
    }

    std::string ty_str =
        arg.attr("__class__").attr("__name__").cast<std::string>();
    std::string err_msg = "Received type '" + ty_str + "' for argument " +
                          std::to_string(i) + "." +
                          " Only int, float, bool, torch.Tensor, and "
                          "triton.language.constexpr are supported.";
    throw std::runtime_error(err_msg);
  }

  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}

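The new parse_args overload packs the Python arguments into a flat byte buffer that is later handed to cuLaunchKernel, rounding the write cursor up to a 4-byte boundary before 32-bit values and an 8-byte boundary before 64-bit values. A minimal Python sketch of the same layout rules (helper names are illustrative, not part of this commit):

    def align_up(offset, alignment):
        # same trick as `(ptr + a - 1) & -a` in the C++ above
        return (offset + alignment - 1) & -alignment

    def packed_layout(args):
        # returns (offset, size) per argument, mirroring parse_args' packing rules
        offset, layout = 0, []
        for a in args:
            if isinstance(a, bool):                    # bool: 1 byte, no realignment
                size, align = 1, 1
            elif isinstance(a, int):                   # 32-bit if it fits, else 64-bit
                size = 4 if -2**31 <= a <= 2**32 - 1 else 8
                align = size
            elif isinstance(a, float):                 # passed to the kernel as fp32
                size, align = 4, 4
            else:                                      # tensor-like: 64-bit device pointer
                size, align = 8, 8
            offset = align_up(offset, align)
            layout.append((offset, size))
            offset += size
        return layout

    # e.g. packed_layout([True, 3, 2.0]) -> [(0, 1), (4, 4), (8, 4)]
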
void init_triton_runtime(py::module &&m) {

  // m.def("current_stream", [](uint64_t device){
  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
  // });

  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
      .value("ROCM", ROCM)
      // .value("ROCM", ROCM)
      .export_values();

  // enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
    return -1;
  });

  m.def("launch_binary", [](py::object binary, py::list args,
                            py::list do_not_specialize, py::list arg_names,
                            py::int_ stream, py::int_ num_warps,
                            py::int_ num_stages, py::object grid) {
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());

    // get grid
    py::sequence seq;
    py::dict constants;
    std::string params;
    size_t params_size{};
    parse_args(args, arg_names, params, params_size, constants);
    if (!PySequence_Check(grid.ptr()))
      seq = grid(constants);
    else
      seq = grid;

    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));

    // actually launch
    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
                      CU_LAUNCH_PARAM_END};
    uint64_t _stream = PyLong_AsLong(stream.ptr());
    const int numGrids = grid_0 * grid_1 * grid_2;
    if (numGrids) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps * 32, 1, 1, shared_mem,
                                    (CUstream)_stream, nullptr, config);
    }
    return binary;
  });

  // query maximum shared memory
  m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
    if (backend == HOST)
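launch_binary accepts the grid either as a ready-made sequence or as a callable that is invoked with the constants dict (the kernel's meta-parameters) and returns the grid dimensions; missing dimensions default to 1. A rough Python rendering of that normalization, for illustration only:

    def normalize_grid(grid, constants):
        # a non-sequence grid is treated as a callable over the constants/meta-parameters
        seq = grid if isinstance(grid, (list, tuple)) else grid(constants)
        grid_0 = seq[0]
        grid_1 = seq[1] if len(seq) > 1 else 1
        grid_2 = seq[2] if len(seq) > 2 else 1
        return grid_0, grid_1, grid_2

    # e.g. normalize_grid(lambda META: (triton.cdiv(1024, META['BLOCK']),), {'BLOCK': 256})
    # evaluates to (4, 1, 1)
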
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
  });
}

void init_translation(py::module &m) {
void init_triton_translation(py::module &m) {
  m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
    llvm::LLVMContext llvmContext;
    auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
  });

  m.def("translate_triton_gpu_to_ptx",
        [](mlir::ModuleOp module, uint64_t device) -> std::string {
        [](mlir::ModuleOp module, uint64_t device)
            -> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
          auto [ptxCode, cc, version, ptxasPath] =
              triton::translateTritonGPUToPTX(module, device);
          return ptxCode;

          mlir::PassManager pm(module->getContext());
          auto pass = std::make_unique<mlir::Allocation>(module);
          size_t size = pass->getSharedMemorySize();

          return std::make_tuple(ptxCode, size);
        });

  m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
      py::bytes bytes(cubin);
      return bytes;
    });

  m.def(
      "load_binary",
      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
         size_t n_shared_bytes, uint64_t dev) {
        py::gil_scoped_release allow_threads;
        assert(backend == CUDA); // Only CUDA is supported now.
        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
      },
      py::return_value_policy::take_ownership);
}

void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@ void init_triton(py::module &m) {
  // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
  init_translation(subm);
  init_triton_translation(subm);
}
@@ -1,30 +0,0 @@
import triton
import triton.language as tl

NUM_WARPS = 4

# triton kernel


@triton.jit
def kernel(x_ptr, stride_xn,
           y_ptr, stride_yn,
           z_ptr, stride_zn,
           BLOCK_SIZE_N: tl.constexpr):
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    x_ptrs = x_ptr + offset
    y_ptrs = y_ptr + offset
    x = tl.load(x_ptrs)
    y = tl.load(y_ptrs)
    z = x + y
    z_ptrs = z_ptr + offset
    tl.store(z_ptrs, z)


ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")

print(ret)

# TODO: base class for python end2end tests,
# runtime execution, correctness comparison etc.
@@ -2,18 +2,21 @@ import torch

import triton
import triton.language as tl
import triton.runtime as runtime

# trigger torch.device implicitly to ensure CUDA context initialization
torch.zeros([10], device=torch.device('cuda'))


@triton.jit
def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
    pass


def test_empty_kernel_cubin_compile():
    @triton.jit
    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
        pass

    device = torch.cuda.current_device()
    cubin = triton.compile(kernel,
    cubin = triton.compile(empty_kernel,
                           "*fp32,i32,i32",
                           device=device,
                           constants={"BLOCK": 256},
@@ -21,3 +24,25 @@ def test_empty_kernel_cubin_compile():

    print('cubin size:', len(cubin))
    assert len(cubin) > 0


def test_empty_kernel_launch():
    device = torch.cuda.current_device()
    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
                                  device=device,
                                  constants={"BLOCK": 256},
                                  num_warps=4,
                                  num_stages=3)
    grid = lambda META: (
        triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
    )

    A = torch.zeros([1024], device="cuda")
    runtime.launch_kernel(fn=empty_kernel,
                          binary=binary,
                          grid=grid,
                          num_warps=4,
                          num_stages=3,
                          X=A,
                          stride_xm=256,
                          BLOCK=tl.constexpr(256))
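For reference, the grid callable in test_empty_kernel_launch is evaluated against the kernel's constants, so with BLOCK = 256 it resolves to a one-dimensional launch; a quick check outside the test itself:

    META = {'BLOCK': 256}
    grid_size = triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK'])
    assert grid_size == 16   # the remaining two grid dimensions default to 1
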
python/tests/test_vecadd_no_scf.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import triton
import triton.language as tl

NUM_WARPS = 4

# triton kernel


def test_vecadd_no_scf():
    @triton.jit
    def kernel(x_ptr, stride_xn,
               y_ptr, stride_yn,
               z_ptr, stride_zn,
               BLOCK_SIZE_N: tl.constexpr):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        x_ptrs = x_ptr + offset
        y_ptrs = y_ptr + offset
        x = tl.load(x_ptrs)
        y = tl.load(y_ptrs)
        z = x + y
        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z)

    ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")

    print(ret)
@@ -3,7 +3,7 @@ from __future__ import annotations

import ast
import sys
import warnings
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union

import triton
import triton._C.libtriton.triton as _triton
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
    return mod


def make_ptx(mod, device):
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
    :return: str
    :return:
        - PTX code
        - shared memory allocation size
    '''
    return _triton.translate_triton_gpu_to_ptx(mod, device)
@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
    return _triton.compile_ptx_to_cubin(ptx, device)


def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
def ptx_get_kernel_name(ptx: str) -> str:
    '''
    Get kernel name from PTX code.
    This kernel name is required when launching the kernel.
    '''
    # PTX codegen mangles names, so the original kernel names in Triton IR are
    # not available in the PTX/cubin.
    assert ptx
    for line in ptx.split('\n'):
        line = line.strip()
        if line.startswith('// .globl'):
            return line.split()[-1]


def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
    # triton-ir
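To make the name lookup concrete, here is a hedged sketch of ptx_get_kernel_name on a hand-written PTX fragment (both the snippet and the mangled name are illustrative, not output from this commit):

    sample_ptx = '''
    .version 7.4
    .target sm_80
    // .globl  empty_kernel_0d1d2
    .visible .entry empty_kernel_0d1d2(
    '''

    # the helper scans for the '// .globl' comment emitted by the PTX backend
    # and returns the last token on that line
    assert ptx_get_kernel_name(sample_ptx) == 'empty_kernel_0d1d2'
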
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
    # tritongpu-ir
    module = make_tritongpu_ir(module, num_warps)
    module = optimize_tritongpu_ir(module, num_stages)

    if output == "ttgir":
        return module.str()

    assert device >= 0, "device should be provided."

    ptx = make_ptx(module, device)
    ptx, shem_size = make_ptx(module, device)
    kernel_name = ptx_get_kernel_name(ptx)
    if output == "ptx":
        return ptx
        return ptx, shem_size, kernel_name

    cubin = make_cubin(ptx, device)
    if output == "cubin":
        return cubin
        return cubin, shem_size, kernel_name

    assert False
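With this change the shape of compile's return value depends on the requested output stage; a hedged sketch of the two call shapes (argument values taken from the tests in this commit, device index assumed):

    # ttir/ttgir still return a single IR string
    ttgir = compile(empty_kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")

    # ptx/cubin now also return the shared-memory size and the mangled kernel
    # name, which build_kernel consumes further down
    ptx, shem_size, kernel_name = compile(empty_kernel, "*fp32,i32,i32", device=0,
                                          constants={"BLOCK": 256}, output="ptx")
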
@@ -1,2 +1,2 @@
from .autotuner import Config, autotune, heuristics  # noqa: F401
from .jit import JITFunction, jit  # noqa: F401
from .jit import JITFunction, build_kernel, jit, launch_kernel  # noqa: F401
@@ -8,18 +8,32 @@ import os
import subprocess
import tempfile
import textwrap
from typing import Any, Dict, List

import torch

import triton
import triton._C.libtriton.triton as _triton
from ..compiler import compile
from ..tools.disasm import extract

try:
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream

# -----------------------------------------------------------------------------
# Binary
# -----------------------------------------------------------------------------

VALID_BACKENDS: List[str] = (
    _triton.runtime.backend.CUDA,
)


class Binary:
    def __init__(self, backend, name, asm, shared_mem, num_warps):
    def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
        assert backend in VALID_BACKENDS, "backend should within [%s], but get a \"%s\"" % (', '.join(VALID_BACKENDS), backend)
        self.backend = backend
        self.name = name
        self.asm = asm
@@ -29,11 +43,11 @@ class Binary:

class LoadedBinary:
    def __init__(self, device: int, bin: Binary):
        module, kernel = _triton.code_gen.load_binary(bin.backend,
                                                      bin.name,
                                                      bin.asm,
                                                      bin.shared_mem,
                                                      device)
        module, kernel = _triton.load_binary(bin.backend,
                                             bin.name,
                                             bin.asm,
                                             bin.shared_mem,
                                             device)
        self.bin = bin
        self.asm = bin.asm
        self.sass = ''
@@ -241,6 +255,44 @@ class JITFunction:
    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"


def build_kernel(fn: JITFunction,
                 fn_type: str,
                 device: int,
                 constants: Dict[str, Any],
                 num_warps: int = 4,
                 num_stages: int = 3,
                 ) -> LoadedBinary:
    cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
    assert cubin
    assert kernel_name

    backend = _triton.runtime.backend.CUDA

    max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
    assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)

    asm = dict(cubin=cubin)
    binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
    loaded_binary = LoadedBinary(device, binary)
    return loaded_binary


def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
    kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
    wargs = list(wargs)
    for i, pos in enumerate(sorted(kwargs)):
        wargs.insert(pos + i, kwargs[pos])
    assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs))

    device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    stream = get_cuda_stream(device)

    _triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
                                  stream, num_warps, num_stages, grid)


# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
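launch_kernel lets callers pass kernel arguments positionally or by keyword; the loop above maps keywords back to their positional slots before forwarding everything to _triton.runtime.launch_binary. A standalone re-run of that merge logic with empty_kernel's argument names (the 'A' value is a stand-in for the tensor):

    arg_names = ['X', 'stride_xm', 'BLOCK']   # empty_kernel's parameters
    wargs = []                                # nothing passed positionally
    kwargs = {'X': 'A', 'stride_xm': 256, 'BLOCK': 256}

    merged = {arg_names.index(name): value for name, value in kwargs.items()}
    for i, pos in enumerate(sorted(merged)):
        wargs.insert(pos + i, merged[pos])

    assert wargs == ['A', 256, 256]           # positional order restored
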
@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {

// CHECK: llvm.return
@@ -58,7 +58,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// -----

// TODO: Pending on the support of isSplat constant
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
@@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
return
}
}
@@ -116,7 +116,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addf
func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
// CHECK: llvm.fadd
// CHECK: llvm.fadd
@@ -141,7 +141,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----

module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32