diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index b90a8d933..137512c69 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1399,37 +1399,35 @@ struct BroadcastOpConversion
     Value result = op.result();
     auto srcTy = op.src().getType().cast<RankedTensorType>();
     auto resultTy = result.getType().cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    assert(srcLayout && (srcLayout == resultLayout) &&
-           "Unexpected layout of BroadcastOp");
+    auto srcLayout = srcTy.getEncoding();
+    auto resultLayout = resultTy.getEncoding();
     auto srcShape = srcTy.getShape();
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
-    auto order = srcLayout.getOrder();
+    auto order = triton::gpu::getOrder(srcLayout);
     SmallVector<int64_t> srcLogicalShape(2 * rank);
     SmallVector<unsigned> srcLogicalOrder(2 * rank);
     SmallVector<int64_t> resultLogicalShape(2 * rank);
     SmallVector<unsigned> broadcastDims;
     for (unsigned d = 0; d < rank; ++d) {
-      unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
-                                   resultLayout.getThreadsPerWarp()[d] *
-                                   resultLayout.getWarpsPerCTA()[d];
+      unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
+                                   triton::gpu::getThreadsPerWarp(resultLayout)[d] *
+                                   triton::gpu::getWarpsPerCTA(resultLayout)[d];
       int64_t numCtas = ceil(resultShape[d], resultShapePerCTA);
       if (srcShape[d] != resultShape[d]) {
         assert(srcShape[d] == 1);
         broadcastDims.push_back(d);
         srcLogicalShape[d] = 1;
         srcLogicalShape[d + rank] =
-            std::max(1, srcLayout.getSizePerThread()[d]);
+            std::max(1, triton::gpu::getSizePerThread(srcLayout)[d]);
       } else {
         srcLogicalShape[d] = numCtas;
-        srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+        srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
       }
       resultLogicalShape[d] = numCtas;
-      resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+      resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
       srcLogicalOrder[d] = order[d] + rank;
       srcLogicalOrder[d + rank] = order[d];
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 5a0d25732..f07d4cfee 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -163,7 +163,19 @@ void init_triton_ir(py::module &&m) {
 
   py::class_<mlir::Type>(m, "type")
       .def("is_integer", &mlir::Type::isInteger)
-      .def("is_fp16", &mlir::Type::isF16);
+      .def("is_fp16", &mlir::Type::isF16)
+      .def("__str__", [](mlir::Type &self) {
+        std::string str;
+        llvm::raw_string_ostream os(str);
+        self.print(os);
+        return os.str();
+      });
+
+  py::class_<mlir::FunctionType>(m, "function_type")
+      .def("param_types", [](mlir::FunctionType &self) {
+        return std::vector<mlir::Type>(self.getInputs().begin(),
+                                       self.getInputs().end());
+      });
 
   py::class_<mlir::Value>(m, "value")
       .def("set_attr",
@@ -314,7 +326,14 @@ void init_triton_ir(py::module &&m) {
      .def("get_function",
           [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
             return self.lookupSymbol<mlir::FuncOp>(funcName);
-          });
+          })
+      .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
+        llvm::SmallVector<mlir::FuncOp> funcs;
+        self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
+        if (funcs.size() != 1)
+          throw std::runtime_error("Expected a single function");
+        return funcs[0];
+      });
 
   m.def(
       "parse_mlir_module",
@@ -363,6 +382,7 @@ void init_triton_ir(py::module &&m) {
             self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
           },
           ret::reference)
+      .def_property_readonly("type", &mlir::FuncOp::getType)
      .def("reset_type", &mlir::FuncOp::setType);
 
  py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -1274,8 +1294,8 @@ void init_triton_ir(py::module &&m) {
 
 void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;
 
-  m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
+  m.def("get_shared_memory_size", [](mlir::ModuleOp mod) {
+    auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
     return shared.getInt();
   });
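The new pybind11 hooks above (type.__str__, function_type.param_types, module.get_single_function, and the FuncOp "type" property) exist so the Python side can recover a kernel signature from an already-lowered module. A minimal sketch of how they compose; the import path of the libtriton extension and the kernel.ttgir path are assumptions for illustration only:

    import triton._C.libtriton.triton as _triton  # assumed alias for the libtriton extension

    context = _triton.ir.context()
    module = _triton.ir.parse_mlir_module("kernel.ttgir", context)   # placeholder path
    func = module.get_single_function()          # raises if the module holds != 1 function
    param_tys = [str(ty) for ty in func.type.param_types()]
    # e.g. ['!tt.ptr<f16>', '!tt.ptr<f16>'] for the test kernel introduced below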
diff --git a/python/tests/test_backend.py b/python/tests/test_backend.py
new file mode 100644
index 000000000..06f36e43b
--- /dev/null
+++ b/python/tests/test_backend.py
@@ -0,0 +1,91 @@
+import triton
+import triton.language as tl
+import torch
+import pytest
+from .test_core import numpy_random, to_triton
+
+
+class MmaLayout:
+    def __init__(self, version, warps_per_cta):
+        self.version = version
+        self.warps_per_cta = str(warps_per_cta)
+
+    def __str__(self):
+        return f"#triton_gpu.mma<{{version={self.version}, warpsPerCTA={self.warps_per_cta}}}>"
+
+
+class BlockedLayout:
+    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
+        self.sz_per_thread = str(size_per_thread)
+        self.threads_per_warp = str(threads_per_warp)
+        self.warps_per_cta = str(warps_per_cta)
+        self.order = str(order)
+
+    def __str__(self):
+        return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
+
+
+layouts = [
+    # MmaLayout(version=1, warps_per_cta=[1, 4]),
+    MmaLayout(version=2, warps_per_cta=[1, 4]),
+    # MmaLayout(version=1, warps_per_cta=[4, 1]),
+    MmaLayout(version=2, warps_per_cta=[4, 1]),
+    BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
+    BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
+    BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
+    BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
+    BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1])
+]
+
+
+@pytest.mark.parametrize("shape", [(128, 128)])
+@pytest.mark.parametrize("dtype", ['float16'])
+@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("dst_layout", layouts)
+def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
+    if str(src_layout) == str(dst_layout):
+        pytest.skip()
+    if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
+        pytest.skip()
+
+    ir = f"""
+#src = {src_layout}
+#dst = {dst_layout}
+""" + """
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
+    %cst = arith.constant dense<128> : tensor<128x1xi32, #src>
+    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
+    %2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
+    %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
+    %5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
+    %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
+    %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
+    %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
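For reference, the layout helpers in the new test simply pretty-print TritonGPU encoding attributes that get substituted into the #src/#dst placeholders of the IR template above; for example:

    str(BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]))
    # -> '#triton_gpu.blocked<{sizePerThread=[1, 8], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0]}>'
    str(MmaLayout(version=2, warps_per_cta=[1, 4]))
    # -> '#triton_gpu.mma<{version=2, warpsPerCTA=[1, 4]}>'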
+    %9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
+    %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
+    %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
+    %3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
+    %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
+    %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
+    %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
+    tt.store %14, %13 : tensor<128x128xf16, #dst>
+    return
+  }
+}
+"""
+
+    x = to_triton(numpy_random(shape, dtype_str=dtype))
+    z = torch.empty_like(x)
+
+    # write the IR to a temporary file using mkstemp
+    import tempfile
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(ir)
+        f.flush()
+        kernel = triton.compile(f.name)
+        kernel[(1,1,1)](x.data_ptr(), z.data_ptr())
+
+    assert torch.equal(z, x)
diff --git a/python/tests/test_compiler.py b/python/tests/test_compiler.py
index 3c34eae93..ebbcac1e2 100644
--- a/python/tests/test_compiler.py
+++ b/python/tests/test_compiler.py
@@ -16,7 +16,7 @@ def test_empty_kernel_cubin_compile():
 
     device = torch.cuda.current_device()
     kernel = triton.compile(empty_kernel,
-                            "*fp32,i32,i32",
+                            signature="*fp32,i32,i32",
                             device=device,
                             constants={"BLOCK": 256})
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index b31273814..8406be4c8 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1028,7 +1028,7 @@ def binary_name_to_header_name(name):
     return f"{name}.h"
 
 
-def generate_launcher(identifier, constants, signature):
+def generate_launcher(constants, signature):
     arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     def _extracted_type(ty):
@@ -1184,6 +1184,9 @@ class CacheManager:
     def put(self, data, filename, binary=True):
         if not self.cache_dir:
             return
+        binary = isinstance(data, bytes)
+        if not binary:
+            data = str(data)
         assert self.lock_path is not None
         filepath = self._make_path(filename)
         with FileLock(self.lock_path):
@@ -1296,16 +1299,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
         cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
     return module, md5, True, False
 
-
-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-    # we get the kernel, i.e. the first function generated in the module
-    if configs is None:
-        configs = [instance_descriptor()]
-    assert len(configs) == 1
-    # cache manager
-    name = fn.__name__
+#
+def make_stub(name, signature, constants):
     # name of files that are cached
     so_cache_key = make_so_cache_key(signature, constants)
     so_cache_manager = CacheManager(so_cache_key)
@@ -1313,57 +1308,129 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     # retrieve stub from cache if it exists
     if not so_cache_manager.has_file(so_name):
         with tempfile.TemporaryDirectory() as tmpdir:
-            src = generate_launcher(name, constants, signature)
+            src = generate_launcher(constants, signature)
             src_path = os.path.join(tmpdir, "main.c")
             with open(src_path, "w") as f:
                 f.write(src)
-            so = _build(fn.__name__, src_path, tmpdir)
+            so = _build(name, src_path, tmpdir)
             with open(so, "rb") as f:
                 so_cache_manager.put(f.read(), so_name, binary=True)
-    so_path = so_cache_manager._make_path(so_name)
+    return so_cache_manager._make_path(so_name)
+
+
+def convert_type_repr(x):
+    match = re.search('!tt\.ptr<(.*)>', x)
+    if match is not None:
+        return '*' + convert_type_repr(match.group(1))
+    return x
+
+
+def make_hash(fn, **kwargs):
+    if isinstance(fn, triton.runtime.JITFunction):
+        configs = kwargs["configs"]
+        signature = kwargs["signature"]
+        constants = kwargs.get("constants", dict())
+        num_warps = kwargs.get("num_warps", 4)
+        num_stages = kwargs.get("num_stages", 3)
+        # Get unique key for the compiled code
+        get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
+        configs_key = [get_conf_key(conf) for conf in configs]
+        key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
+        return hashlib.md5(key.encode("utf-8")).hexdigest()
+    assert isinstance(fn, str)
+    return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
+
+
+# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+def compile(fn, **kwargs):
+    # we get the kernel, i.e. the first function generated in the module
+    # if fn is not a JITFunction, then it
+    # has to be a path to a file
+    context = _triton.ir.context()
+    asm, md5 = dict(), dict()
+    constants = kwargs.get("constants", dict())
+    if isinstance(fn, triton.runtime.JITFunction):
+        configs = kwargs.get("configs", None)
+        signature = kwargs["signature"]
+        if configs is None:
+            configs = [instance_descriptor()]
+        assert len(configs) == 1
+        kwargs["configs"] = configs
+        name = fn.__name__
+        first_stage = 0
+        if isinstance(signature, str):
+            signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
+        kwargs["signature"] = signature
+    else:
+        assert isinstance(fn, str)
+        name, ir = os.path.basename(fn).split(".")
+        assert ir == "ttgir"
+        asm[ir] = _triton.ir.parse_mlir_module(fn, context)
+        function = asm[ir].get_single_function()
+        param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
+        signature = {k: v for k, v in enumerate(param_tys)}
+        first_stage = 2
+
+    # cache manager
+    so_path = make_stub(name, signature, constants)
     # create cache manager
-    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
-    fn_cache_manager = CacheManager(fn_cache_key)
+    fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
+    # determine name and extension type of provided function
+    if isinstance(fn, triton.runtime.JITFunction):
+        name, ext = fn.__name__, "ast"
+    else:
+        name, ext = os.path.basename(fn).split(".")
+    # initialize compilation params
+    num_warps = kwargs.get("num_warps", 4)
+    num_stages = kwargs.get("num_stages", 3)
+    extern_libs = kwargs.get("extern_libs", dict())
+    device = kwargs.get("device", torch.cuda.current_device())
     # load metadata if any
     metadata = None
     if fn_cache_manager.has_file(f'{name}.json'):
         with open(fn_cache_manager._make_path(f"{name}.json")) as f:
             metadata = json.load(f)
-    context = _triton.ir.context()
-    force_compile = False
-    # ast -> triton-ir (or read from cache)
-    ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
-                                                       run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
-                                                       run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
-    # triton-ir -> triton-gpu-ir (or read from cache)
-    ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
-                                                         run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
-                                                         run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
-    # triton-gpu-ir -> llvm-ir (or read from cache)
-    llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
-                                                                 run_if_found = lambda path: Path(path).read_bytes(),
-                                                                 run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
-    if llvm_cached:
-        shmem_size = metadata["shared"]
     else:
-        shmem_size = _triton.get_shared_memory_size(ttgir)
-    # llvm-ir -> ptx (or read from cache)
-    ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
-                                                     run_if_found = lambda path: Path(path).read_text(),
-                                                     run_if_not_found = lambda: llir_to_ptx(llir))
-    # ptx -> cubin (or read from cache)
-    cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
-                                                         run_if_found = lambda path: Path(path).read_bytes(),
-                                                         run_if_not_found= lambda: ptx_to_cubin(ptx, device))
-    # dump new metadata
-    kernel_name = ptx_get_kernel_name(ptx)
-    metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
-                "md5": { "cubin": cubin_md5, "ptx": ptx_md5,
-                         "llir": llir_md5,
-                         "ttir": ttir_md5, "ttgir": ttgir_md5 }}
+        metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
+    # build compilation stages
+    stages = {
+        "ast" : (lambda path: fn, None),
+        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
+        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
+        "llir": (lambda path: Path(path).read_bytes(),
+                 lambda src: ttgir_to_llir(src, extern_libs)),
+        "ptx": (lambda path: Path(path).read_text(),
+                llir_to_ptx),
+        "cubin": (lambda path: Path(path).read_bytes(),
+                  lambda src: ptx_to_cubin(src, device))
+    }
+    first_stage = list(stages.keys()).index(ext)
+    asm = dict()
+    module = fn
+    # run compilation pipeline and populate metadata
+    for ir, (parse, compile) in list(stages.items())[first_stage:]:
+        path = fn_cache_manager._make_path(f"{name}.{ir}")
+        if ir == ext:
+            next_module = parse(fn)
+        elif os.path.exists(path) and\
+                os.path.getctime(path) == metadata["ctime"][ir]:
+            next_module = parse(path)
+        else:
+            next_module = compile(module)
+            fn_cache_manager.put(next_module, f"{name}.{ir}")
+        if os.path.exists(path):
+            metadata["ctime"][ir] = os.path.getctime(path)
+        asm[ir] = next_module if ir == "cubin" else str(next_module)
+        if ir == "llir" and "shared" not in metadata:
+            metadata["shared"] = _triton.get_shared_memory_size(module)
+        if ir == "ptx":
+            metadata["name"] = ptx_get_kernel_name(next_module)
+        module = next_module
+    # write-back metadata
     fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
-
-    asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
+    # return handle to compiled kernel
     return CompiledKernel(so_path, metadata, asm)
@@ -1395,7 +1462,7 @@ class CompiledKernel:
             if stream is None:
                 stream = torch.cuda.current_stream().cuda_stream
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
-            return
+        return runner
 
     def get_sass(self, fun=None):
         if 'sass' in self.asm:
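The convert_type_repr helper above maps ttgir-level argument types back to the frontend's signature notation by recursing through pointer types; for example:

    convert_type_repr("i32")            # -> 'i32'
    convert_type_repr("!tt.ptr<f16>")   # -> '*f16'
    convert_type_repr("!tt.ptr<!tt.ptr<f32>>")  # -> '**f32' (hypothetical nested case)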
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages, - "md5": { "cubin": cubin_md5, "ptx": ptx_md5, - "llir": llir_md5, - "ttir": ttir_md5, "ttgir": ttgir_md5 }} + metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()} + # build compilation stages + stages = { + "ast" : (lambda path: fn, None), + "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), + lambda src: ast_to_ttir(src, signature, configs[0], constants)), + "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), + lambda src: ttir_to_ttgir(src, num_warps, num_stages)), + "llir": (lambda path: Path(path).read_bytes(), + lambda src: ttgir_to_llir(src, extern_libs)), + "ptx": (lambda path: Path(path).read_text(), + llir_to_ptx), + "cubin": (lambda path: Path(path).read_bytes(), + lambda src: ptx_to_cubin(src, device)) + } + first_stage = list(stages.keys()).index(ext) + asm = dict() + module = fn + # run compilation pipeline and populate metadata + for ir, (parse, compile) in list(stages.items())[first_stage:]: + path = fn_cache_manager._make_path(f"{name}.{ir}") + if ir == ext: + next_module = parse(fn) + elif os.path.exists(path) and\ + os.path.getctime(path) == metadata["ctime"][ir]: + next_module = parse(path) + else: + next_module = compile(module) + fn_cache_manager.put(next_module, f"{name}.{ir}") + if os.path.exists(path): + metadata["ctime"][ir] = os.path.getctime(path) + asm[ir] = next_module if ir == "cubin" else str(next_module) + if ir == "llir" and "shared" not in metadata: + metadata["shared"] = _triton.get_shared_memory_size(module) + if ir == "ptx": + metadata["name"] = ptx_get_kernel_name(next_module) + module = next_module + # write-back metadata fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False) - - asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin} + # return handle to compiled kernel return CompiledKernel(so_path, metadata, asm) @@ -1395,7 +1462,7 @@ class CompiledKernel: if stream is None: stream = torch.cuda.current_stream().cuda_stream self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) - return + return runner def get_sass(self, fun=None): if 'sass' in self.asm: diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index e23b0279f..8d3704fbf 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -275,7 +275,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage raise TypeError(f"Callable constexpr at index {i} is not supported") device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): - bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) + bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs) if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args) self.cache[key] = bin