[FRONTEND] add python e2e launch empty kernel test (#68)
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import ast
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
return mod
|
||||
|
||||
|
||||
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
    '''
    Lower a TritonGPU module to PTX.

    This is a thin pass-through to the native translator; no processing
    happens on the Python side.

    :param mod: a TritonGPU dialect module
    :param device: device argument forwarded unchanged to the translator
    :return: the translator's result, documented as
        - PTX code
        - shared memory allocation size
    '''
    translated = _triton.translate_triton_gpu_to_ptx(mod, device)
    return translated
|
||||
|
||||
@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
|
||||
return _triton.compile_ptx_to_cubin(ptx, device)
|
||||
|
||||
|
||||
def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
|
||||
def ptx_get_kernel_name(ptx: str) -> str:
    '''
    Get kernel name from PTX code.

    The name is taken from the first line that, once surrounding whitespace
    is removed, starts with ``// .globl`` and carries a symbol after it; the
    last whitespace-separated token of that line is returned.
    This kernel name is required when launching the kernel.

    :param ptx: PTX source text; must be non-empty
    :return: the kernel name found in the PTX text
    :raises AssertionError: if ``ptx`` is empty
    :raises ValueError: if no ``// .globl <name>`` line is present
    '''
    # There is a name mangling in PTX codegen, so the original kernel names
    # in Triton IR are not available in PTX/cubin.
    assert ptx
    for raw_line in ptx.split('\n'):
        tokens = raw_line.strip().split()
        # Expect at least ['//', '.globl', '<name>']; a marker line without a
        # symbol would otherwise yield '.globl' itself as the kernel name.
        if len(tokens) > 2 and raw_line.strip().startswith('// .globl'):
            return tokens[-1]
    # Fail loudly instead of implicitly returning None from a `-> str` function.
    raise ValueError("no '// .globl' kernel name found in PTX code")
|
||||
|
||||
|
||||
def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
# triton-ir
|
||||
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
|
||||
assert device >= 0, "device should be provided."
|
||||
|
||||
ptx = make_ptx(module, device)
|
||||
ptx, shem_size = make_ptx(module, device)
|
||||
kernel_name = ptx_get_kernel_name(ptx)
|
||||
if output == "ptx":
|
||||
return ptx
|
||||
return ptx, shem_size, kernel_name
|
||||
|
||||
cubin = make_cubin(ptx, device)
|
||||
if output == "cubin":
|
||||
return cubin
|
||||
return cubin, shem_size, kernel_name
|
||||
|
||||
assert False
|
||||
|
Reference in New Issue
Block a user