[FRONTEND] add python e2e launch empty kernel test (#68)

This commit is contained in:
Yan Chunwei
2022-08-20 01:46:01 +08:00
committed by GitHub
parent 9aa00249a6
commit 10ba51c3bb
11 changed files with 311 additions and 69 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import ast
import sys
import warnings
from typing import Dict, Union
from typing import Any, Dict, Tuple, Union
import triton
import triton._C.libtriton.triton as _triton
@@ -800,11 +800,13 @@ def optimize_tritongpu_ir(mod, num_stages):
return mod
def make_ptx(mod, device):
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
:return: str
:return:
- PTX code
- shared memory alloaction size
'''
return _triton.translate_triton_gpu_to_ptx(mod, device)
@@ -819,7 +821,20 @@ def make_cubin(ptx, device):
return _triton.compile_ptx_to_cubin(ptx, device)
def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
def ptx_get_kernel_name(ptx: str) -> str:
'''
Get kernel name from PTX code.
This Kernel name is required when launching the kernel.
'''
# There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin.
assert ptx
for line in ptx.split('\n'):
line = line.strip()
if line.startswith('// .globl'):
return line.split()[-1]
def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
@@ -830,17 +845,18 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
assert device >= 0, "device should be provided."
ptx = make_ptx(module, device)
ptx, shem_size = make_ptx(module, device)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx
return ptx, shem_size, kernel_name
cubin = make_cubin(ptx, device)
if output == "cubin":
return cubin
return cubin, shem_size, kernel_name
assert False