From 10ba51c3bbcd19b18fe8b34998a76f16d2c3b38e Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Sat, 20 Aug 2022 01:46:01 +0800
Subject: [PATCH] [FRONTEND] add python e2e launch empty kernel test (#68)

---
 include/triton/Conversion/MLIRTypes.h    |   4 +-
 lib/Analysis/Allocation.cpp              |   2 -
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp  |   5 +
 python/src/triton.cc                     | 169 ++++++++++++++++--
 python/test/vecadd_no_scf.py             |  30 ----
 python/tests/test_compiler.py            |  33 +++-
 python/tests/test_vecadd_no_scf.py       |  27 +++
 python/triton/compiler.py                |  32 +++-
 python/triton/runtime/__init__.py        |   2 +-
 python/triton/runtime/jit.py             |  64 ++++++-
 test/Conversion/tritongpu_to_llvm.mlir   |  12 +-
 11 files changed, 311 insertions(+), 69 deletions(-)
 delete mode 100644 python/test/vecadd_no_scf.py
 create mode 100644 python/tests/test_vecadd_no_scf.py

diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index 5f33ced4b..78c1bea33 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -19,7 +19,9 @@ Type i8Ty(MLIRContext *ctx) {
 Type u32Ty(MLIRContext *ctx) {
   return IntegerType::get(ctx, 32, IntegerType::Signless);
 }
-Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }
+Type u1Ty(MLIRContext *ctx) {
+  return IntegerType::get(ctx, 1, IntegerType::Unsigned);
+}

 // Float types
 Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index efe13f054..80266ff29 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -2,8 +2,6 @@
 #include "mlir/Analysis/Liveness.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallVector.h"

 #include
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 89bff6851..0d8df10ba 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -171,6 +171,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
       return failure();

     auto ctx = funcOp->getContext();
+
+    // Set an attribute to indicate this function is a kernel entry.
+    newFuncOp->setAttr(NVVMMetadataField::Kernel,
+                       rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
+
     // Set an attribute for maxntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(
diff --git a/python/src/triton.cc b/python/src/triton.cc
index a3d472c23..c2ef6b4ba 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -6,9 +6,11 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Verifier.h"

+#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"

+#include "triton/Analysis/Allocation.h"
 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
   params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
 }

-//
+void parse_args(py::list &args, py::list &arg_names, std::string &params,
+                size_t &params_size, py::dict constants) {
+  char *params_ptr = params.data();
+
+  size_t len = PyList_Size(args.ptr());
+  for (int i = 0; i < len; i++) {
+    py::object arg = args[i];
+    auto arg_ptr = arg.ptr();
+
+    if (PyLong_Check(arg_ptr)) {
+      int overflow{};
+      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
+
+      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      } else if (!overflow && 0x8000'0000LL <= value &&
+                 value <= 0xFFFF'FFFFLL) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      } else if (!overflow) {
+        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+        std::memcpy(params_ptr, &value, 8);
+        params_ptr += 8;
+      } else {
+        if (PyErr_Occurred()) {
+          throw std::logic_error("An error occurred?");
+        }
+        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
+        if (PyErr_Occurred()) {
+          throw std::runtime_error("integer overflow in argument: " +
+                                   std::string(py::str(arg)));
+        }
+        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+        std::memcpy(params_ptr, &unsigned_value, 8);
+        params_ptr += 8;
+      }
+      continue;
+    }
+
+    if (PyFloat_Check(arg_ptr)) {
+      float value = PyFloat_AsDouble(arg_ptr);
+      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
+      std::memcpy(params_ptr, &value, 4);
+      params_ptr += 4;
+      continue;
+    }
+
+    // argument is `bool`
+    if (PyBool_Check(arg_ptr)) {
+      bool value = arg_ptr == Py_True ? true : false;
+      std::memcpy(params_ptr, &value, 1);
+      params_ptr += 1;
+      continue;
+    }
+    // argument is torch.tensor, get data_ptr as memory address
+    if (py::hasattr(arg, "data_ptr")) {
+      py::object data_ptr = arg.attr("data_ptr")();
+      long value = data_ptr.cast<long>();
+      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
+      // copy param
+      std::memcpy(params_ptr, &value, 8);
+      params_ptr += 8;
+      // update cache key
+      continue;
+    }
+    // argument is `constexpr`
+    if (py::hasattr(arg, "value")) {
+      py::object value = arg.attr("value");
+      py::object name = arg_names[i];
+      constants[name] = value;
+      continue;
+    }
+    // argument is `LoadedBinary`
+    if (py::hasattr(arg, "get_sass")) {
+      // Do nothing, just a placeholder here to indicate validity.
+      continue;
+    }
+
+    std::string ty_str =
+        arg.attr("__class__").attr("__name__").cast<std::string>();
+    std::string err_msg = "Received type '" + ty_str + "' for argument " +
+                          std::to_string(i) + "." +
+                          " Only int, float, bool, torch.Tensor, and "
+                          "triton.language.constexpr are supported.";
+    throw std::runtime_error(err_msg);
+  }
+
+  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
+}

 void init_triton_runtime(py::module &&m) {
-
-  // m.def("current_stream", [](uint64_t device){
-  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
-  // });
-
   // wrap backend_t
   py::enum_<backend_t>(m, "backend")
       .value("HOST", HOST)
       .value("CUDA", CUDA)
-      .value("ROCM", ROCM)
+      // .value("ROCM", ROCM)
       .export_values();

   // enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
     return -1;
   });

+  m.def("launch_binary", [](py::object binary, py::list args,
+                            py::list do_not_specialize, py::list arg_names,
+                            py::int_ stream, py::int_ num_warps,
+                            py::int_ num_stages, py::object grid) {
+    long _num_warps = PyLong_AsLong(num_warps.ptr());
+    long _num_stages = PyLong_AsLong(num_stages.ptr());
+
+    // get grid
+    py::sequence seq;
+    py::dict constants;
+    std::string params;
+    size_t params_size{};
+    parse_args(args, arg_names, params, params_size, constants);
+    if (!PySequence_Check(grid.ptr()))
+      seq = grid(constants);
+    else
+      seq = grid;
+
+    int size = seq.size();
+    int grid_0 = py::cast<int>(seq[0]);
+    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
+    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
+
+    uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
+    uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
+
+    // actually launch
+    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
+                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
+                      CU_LAUNCH_PARAM_END};
+    uint64_t _stream = PyLong_AsLong(stream.ptr());
+    const int numGrids = grid_0 * grid_1 * grid_2;
+    if (numGrids) {
+      // release the gil in case the enqueue blocks
+      // cuda will block if too many ops are enqueued
+      py::gil_scoped_release allow_threads;
+      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                    _num_warps * 32, 1, 1, shared_mem,
+                                    (CUstream)_stream, nullptr, config);
+    }
+    return binary;
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
     });
 }

-void init_translation(py::module &m) {
+void init_triton_translation(py::module &m) {
   m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
     llvm::LLVMContext llvmContext;
     auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
   });

   m.def("translate_triton_gpu_to_ptx",
-        [](mlir::ModuleOp module, uint64_t device) -> std::string {
+        [](mlir::ModuleOp module, uint64_t device)
+            -> std::tuple<std::string, size_t> {
          auto [ptxCode, cc, version, ptxasPath] =
              triton::translateTritonGPUToPTX(module, device);
-         return ptxCode;
+
+         mlir::PassManager pm(module->getContext());
+         auto pass = std::make_unique<mlir::Allocation>(module);
+         size_t size = pass->getSharedMemorySize();
+
+         return std::make_tuple(ptxCode, size);
        });

   m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
         py::bytes bytes(cubin);
         return bytes;
       });
+
+  m.def(
+      "load_binary",
+      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
+         size_t n_shared_bytes, uint64_t dev) {
+        py::gil_scoped_release allow_threads;
+        assert(backend == CUDA); // Only CUDA is supported now.
+        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
+      },
+      py::return_value_policy::take_ownership);
 }

 void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@
   // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
   init_triton_runtime(std::move(subm.def_submodule("runtime")));
   init_triton_ir(std::move(subm.def_submodule("ir")));
-  init_translation(subm);
+  init_triton_translation(subm);
 }
diff --git a/python/test/vecadd_no_scf.py b/python/test/vecadd_no_scf.py
deleted file mode 100644
index 573194a59..000000000
--- a/python/test/vecadd_no_scf.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import triton
-import triton.language as tl
-
-NUM_WARPS = 4
-
-# triton kernel
-
-
-@triton.jit
-def kernel(x_ptr, stride_xn,
-           y_ptr, stride_yn,
-           z_ptr, stride_zn,
-           BLOCK_SIZE_N: tl.constexpr):
-    pid = tl.program_id(axis=0)
-    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    x_ptrs = x_ptr + offset
-    y_ptrs = y_ptr + offset
-    x = tl.load(x_ptrs)
-    y = tl.load(y_ptrs)
-    z = x + y
-    z_ptrs = z_ptr + offset
-    tl.store(z_ptrs, z)
-
-
-ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
-
-print(ret)
-
-# TODO: base class for python end2end tests,
-# runtime execution, correctness comparison etc.
diff --git a/python/tests/test_compiler.py b/python/tests/test_compiler.py
index 1ef75ccdc..452d33f63 100644
--- a/python/tests/test_compiler.py
+++ b/python/tests/test_compiler.py
@@ -2,18 +2,21 @@ import torch

 import triton
 import triton.language as tl
+import triton.runtime as runtime

 # trigger the torch.device implicitly to ensure cuda context initialization
 torch.zeros([10], device=torch.device('cuda'))


+@triton.jit
+def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
+    pass
+
+
 def test_empty_kernel_cubin_compile():
-    @triton.jit
-    def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
-        pass

     device = torch.cuda.current_device()
-    cubin = triton.compile(kernel,
+    cubin = triton.compile(empty_kernel,
                            "*fp32,i32,i32",
                            device=device,
                            constants={"BLOCK": 256},
@@ -21,3 +24,25 @@ def test_empty_kernel_cubin_compile():

     print('cubin size:', len(cubin))
     assert len(cubin) > 0
+
+
+def test_empty_kernel_launch():
+    device = torch.cuda.current_device()
+    binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
+                                  device=device,
+                                  constants={"BLOCK": 256},
+                                  num_warps=4,
+                                  num_stages=3)
+    grid = lambda META: (
+        triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
+    )
+
+    A = torch.zeros([1024], device="cuda")
+    runtime.launch_kernel(fn=empty_kernel,
+                          binary=binary,
+                          grid=grid,
+                          num_warps=4,
+                          num_stages=3,
+                          X=A,
+                          stride_xm=256,
+                          BLOCK=tl.constexpr(256))
diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py
new file mode 100644
index 000000000..995ef5fae
--- /dev/null
+++ b/python/tests/test_vecadd_no_scf.py
@@ -0,0 +1,27 @@
+import triton
+import triton.language as tl
+
+NUM_WARPS = 4
+
+# triton kernel
+
+
+def test_vecadd_no_scf():
+    @triton.jit
+    def kernel(x_ptr, stride_xn,
+               y_ptr, stride_yn,
+               z_ptr, stride_zn,
+               BLOCK_SIZE_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+        x = tl.load(x_ptrs)
+        y = tl.load(y_ptrs)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z)
+
+    ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
+
+    print(ret)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index b70348d7f..45213d84a 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import ast
 import sys
 import warnings
-from typing import Dict, Union
+from typing import Any, Dict, Tuple, Union

 import triton
 import triton._C.libtriton.triton as _triton
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
     return mod


-def make_ptx(mod, device):
+def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
-    :return: str
+    :return:
+        - PTX code
+        - shared memory allocation size
     '''
     return _triton.translate_triton_gpu_to_ptx(mod, device)

@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
     return _triton.compile_ptx_to_cubin(ptx, device)


-def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
+def ptx_get_kernel_name(ptx: str) -> str:
+    '''
+    Get kernel name from PTX code.
+    This kernel name is required when launching the kernel.
+    '''
+    # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
+    assert ptx
+    for line in ptx.split('\n'):
+        line = line.strip()
+        if line.startswith('// .globl'):
+            return line.split()[-1]
+
+
+def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
     # triton-ir
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
     # tritongpu-ir
     module = make_tritongpu_ir(module, num_warps)
     module = optimize_tritongpu_ir(module, num_stages)
+
     if output == "ttgir":
         return module.str()

     assert device >= 0, "device should be provided."
-
-    ptx = make_ptx(module, device)
+    ptx, shem_size = make_ptx(module, device)
+    kernel_name = ptx_get_kernel_name(ptx)
     if output == "ptx":
-        return ptx
+        return ptx, shem_size, kernel_name

     cubin = make_cubin(ptx, device)
     if output == "cubin":
-        return cubin
+        return cubin, shem_size, kernel_name

     assert False
diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py
index 296d750ee..3be401163 100644
--- a/python/triton/runtime/__init__.py
+++ b/python/triton/runtime/__init__.py
@@ -1,2 +1,2 @@
 from .autotuner import Config, autotune, heuristics  # noqa: F401
-from .jit import JITFunction, jit  # noqa: F401
+from .jit import JITFunction, build_kernel, jit, launch_kernel  # noqa: F401
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 6c9547b8f..f99cf9a77 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -8,18 +8,32 @@ import os
 import subprocess
 import tempfile
 import textwrap
+from typing import Any, Dict, List
+
+import torch

 import triton
 import triton._C.libtriton.triton as _triton
+from ..compiler import compile
 from ..tools.disasm import extract

+try:
+    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+except ImportError:
+    get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+
 # -----------------------------------------------------------------------------
 # Binary
 # -----------------------------------------------------------------------------

+VALID_BACKENDS: List[str] = (
+    _triton.runtime.backend.CUDA,
+)
+

 class Binary:
-    def __init__(self, backend, name, asm, shared_mem, num_warps):
+    def __init__(self, backend: str, name: str, asm: Dict[str, str], shared_mem: int, num_warps: int):
+        assert backend in VALID_BACKENDS, "backend should be within [%s], but got \"%s\"" % (', '.join(VALID_BACKENDS), backend)
         self.backend = backend
         self.name = name
         self.asm = asm
@@ -29,11 +43,11 @@ class Binary:

 class LoadedBinary:
     def __init__(self, device: int, bin: Binary):
-        module, kernel = _triton.code_gen.load_binary(bin.backend,
-                                                      bin.name,
-                                                      bin.asm,
-                                                      bin.shared_mem,
-                                                      device)
+        module, kernel = _triton.load_binary(bin.backend,
+                                             bin.name,
+                                             bin.asm,
+                                             bin.shared_mem,
+                                             device)
         self.bin = bin
         self.asm = bin.asm
         self.sass = ''
@@ -241,6 +255,44 @@ class JITFunction:
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"

+
+def build_kernel(fn: JITFunction,
+                 fn_type: str,
+                 device: int,
+                 constants: Dict[str, Any],
+                 num_warps: int = 4,
+                 num_stages: int = 3,
+                 ) -> LoadedBinary:
+    cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
+    assert cubin
+    assert kernel_name
+
+    backend = _triton.runtime.backend.CUDA
+
+    max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+    assert shem_size <= max_shared_memory, "shared memory out of resource: max size is %d, but %d is required" % (max_shared_memory, shem_size)
+
+    asm = dict(cubin=cubin)
+    binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
+    loaded_binary = LoadedBinary(device, binary)
+    return loaded_binary
+
+
+def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
+    kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
+    wargs = list(wargs)
+    for i, pos in enumerate(sorted(kwargs)):
+        wargs.insert(pos + i, kwargs[pos])
+    assert len(wargs) == len(fn.arg_names), "Function argument list does not match: need %d but got %d args" % (len(fn.arg_names), len(wargs))
+
+    device = torch.cuda.current_device()
+    torch.cuda.set_device(device)
+    stream = get_cuda_stream(device)
+
+    _triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
+                                  stream, num_warps, num_stages, grid)
+
+
 # -----------------------------------------------------------------------------
 # `jit` decorator
 # -----------------------------------------------------------------------------
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 63bb5e89d..473607bc3 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

 // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr)
 // Here the 128 comes from the 4 in module attribute multiples 32
-// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
+// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
 func @test_empty_kernel(%lb : index, %A : !tt.ptr) {

   // CHECK: llvm.return
@@ -58,7 +58,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

 // -----

-// TODO: Pending on the support of isSplat constant 
+// TODO: Pending on the support of isSplat constant
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: masked_load_const_other
@@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     // CHECK: llvm.mlir.undef
     // CHECK: %[[T0:.*]] = llvm.extractvalue
     // CHECK: %[[T1:.*]] = llvm.extractvalue
-    %0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2> 
+    %0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
     // CHECK: llvm.mlir.undef
     // CHECK: llvm.insertvalue %[[T0]]
     // CHECK: llvm.insertvalue %[[T0]]
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     // CHECK: llvm.insertvalue %[[T1]]
     // CHECK: llvm.insertvalue %[[T1]]
     // CHECK: llvm.insertvalue %[[T1]]
-    %1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2> 
+    %1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
     return
   }
 }
@@ -116,7 +116,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: basic_addf 
+  // CHECK-LABEL: basic_addf
   func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
     // CHECK: llvm.fadd
     // CHECK: llvm.fadd
@@ -141,7 +141,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // -----

 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: basic_program_id 
+  // CHECK-LABEL: basic_program_id
   func @basic_program_id() {
     // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
     %0 = tt.get_program_id {axis = 0 : i32} : i32