[TESTING] Added infrastructure for executing TTGIR program and test for layout conversions (#885)

This commit is contained in:
Philippe Tillet
2022-11-18 07:46:45 +01:00
committed by GitHub
parent 9ea6135eb5
commit dab4855bdf
6 changed files with 243 additions and 67 deletions

View File

@@ -1399,37 +1399,35 @@ struct BroadcastOpConversion
Value result = op.result();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto resultTy = result.getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(srcLayout && (srcLayout == resultLayout) &&
"Unexpected layout of BroadcastOp");
auto srcLayout = srcTy.getEncoding();
auto resultLayout = resultTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
assert(rank == resultTy.getRank());
auto order = srcLayout.getOrder();
auto order = triton::gpu::getOrder(srcLayout);
SmallVector<int64_t> srcLogicalShape(2 * rank);
SmallVector<unsigned> srcLogicalOrder(2 * rank);
SmallVector<int64_t> resultLogicalShape(2 * rank);
SmallVector<unsigned> broadcastDims;
for (unsigned d = 0; d < rank; ++d) {
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d];
unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
triton::gpu::getThreadsPerWarp(resultLayout)[d] *
triton::gpu::getWarpsPerCTA(resultLayout)[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] =
std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
}
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
srcLogicalOrder[d] = order[d] + rank;
srcLogicalOrder[d + rank] = order[d];

View File

@@ -163,7 +163,19 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16);
.def("is_fp16", &mlir::Type::isF16)
.def("__str__", [](mlir::Type &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});
py::class_<mlir::FunctionType>(m, "function_type")
.def("param_types", [](mlir::FunctionType &self) {
return std::vector<mlir::Type>(self.getInputs().begin(),
self.getInputs().end());
});
py::class_<mlir::Value>(m, "value")
.def("set_attr",
@@ -314,6 +326,13 @@ void init_triton_ir(py::module &&m) {
.def("get_function",
[](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
.def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
llvm::SmallVector<mlir::FuncOp> funcs;
self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
});
m.def(
@@ -363,6 +382,7 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def_property_readonly("type", &mlir::FuncOp::getType)
.def("reset_type", &mlir::FuncOp::setType);
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -1274,8 +1294,8 @@ void init_triton_ir(py::module &&m) {
void init_triton_translation(py::module &m) {
using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
m.def("get_shared_memory_size", [](mlir::ModuleOp mod) {
auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
return shared.getInt();
});

View File

@@ -0,0 +1,91 @@
import triton
import triton.language as tl
import torch
import pytest
from .test_core import numpy_random, to_triton
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
def __str__(self):
return f"#triton_gpu.mma<{{version={self.version}, warpsPerCTA={self.warps_per_cta}}}>"
class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=2, warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=2, warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}
}
"""
x = to_triton(numpy_random(shape, dtype_str=dtype))
z = torch.empty_like(x)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1,1,1)](x.data_ptr(), z.data_ptr())
assert torch.equal(z, x)

View File

@@ -16,7 +16,7 @@ def test_empty_kernel_cubin_compile():
device = torch.cuda.current_device()
kernel = triton.compile(empty_kernel,
"*fp32,i32,i32",
signature="*fp32,i32,i32",
device=device,
constants={"BLOCK": 256})

View File

@@ -1028,7 +1028,7 @@ def binary_name_to_header_name(name):
return f"{name}.h"
def generate_launcher(identifier, constants, signature):
def generate_launcher(constants, signature):
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
def _extracted_type(ty):
@@ -1184,6 +1184,9 @@ class CacheManager:
def put(self, data, filename, binary=True):
if not self.cache_dir:
return
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
with FileLock(self.lock_path):
@@ -1296,16 +1299,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
return module, md5, True, False
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
# we get the kernel, i.e. the first function generated in the module
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
# cache manager
name = fn.__name__
#
def make_stub(name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(signature, constants)
so_cache_manager = CacheManager(so_cache_key)
@@ -1313,57 +1308,129 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
# retrieve stub from cache if it exists
if not so_cache_manager.has_file(so_name):
with tempfile.TemporaryDirectory() as tmpdir:
src = generate_launcher(name, constants, signature)
src = generate_launcher(constants, signature)
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(fn.__name__, src_path, tmpdir)
so = _build(name, src_path, tmpdir)
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
so_path = so_cache_manager._make_path(so_name)
return so_cache_manager._make_path(so_name)
def convert_type_repr(x):
match = re.search('!tt\.ptr<(.*)>', x)
if match is not None:
return '*' + convert_type_repr(match.group(1))
return x
def make_hash(fn, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
# has to be a path to a file
context = _triton.ir.context()
asm, md5 = dict(), dict()
constants = kwargs.get("constants", dict())
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
if configs is None:
configs = [instance_descriptor()]
assert len(configs) == 1
kwargs["configs"] = configs
name = fn.__name__
first_stage = 0
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
name, ir = os.path.basename(fn).split(".")
assert ir == "ttgir"
asm[ir] = _triton.ir.parse_mlir_module(fn, context)
function = asm[ir].get_single_function()
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = 2
# cache manager
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
# initialize compilation params
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
# load metadata if any
metadata = None
if fn_cache_manager.has_file(f'{name}.json'):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
context = _triton.ir.context()
force_compile = False
# ast -> triton-ir (or read from cache)
ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
# triton-ir -> triton-gpu-ir (or read from cache)
ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
# triton-gpu-ir -> llvm-ir (or read from cache)
llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
if llvm_cached:
shmem_size = metadata["shared"]
else:
shmem_size = _triton.get_shared_memory_size(ttgir)
# llvm-ir -> ptx (or read from cache)
ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
run_if_found = lambda path: Path(path).read_text(),
run_if_not_found = lambda: llir_to_ptx(llir))
# ptx -> cubin (or read from cache)
cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found= lambda: ptx_to_cubin(ptx, device))
# dump new metadata
kernel_name = ptx_get_kernel_name(ptx)
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
"md5": { "cubin": cubin_md5, "ptx": ptx_md5,
"llir": llir_md5,
"ttir": ttir_md5, "ttgir": ttgir_md5 }}
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
# build compilation stages
stages = {
"ast" : (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
"llir": (lambda path: Path(path).read_bytes(),
lambda src: ttgir_to_llir(src, extern_libs)),
"ptx": (lambda path: Path(path).read_text(),
llir_to_ptx),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, device))
}
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
# return handle to compiled kernel
return CompiledKernel(so_path, metadata, asm)
@@ -1395,7 +1462,7 @@ class CompiledKernel:
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
return runner
def get_sass(self, fun=None):
if 'sass' in self.asm:

View File

@@ -275,7 +275,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
raise TypeError(f"Callable constexpr at index {i} is not supported")
device = 0
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
self.cache[key] = bin