10 Commits

Author SHA1 Message Date
Da Yan
0f5c6e619c [BUILD] Add the missing triton/impl to setup.py (#1042) 2023-01-09 19:03:45 +00:00
Connor Baker
c20215dad1 [FRONTEND] Update PTX/SM support for LLVM14 (PR #1038 redux) (#1039)
=
2023-01-09 10:31:55 -08:00
Keren Zhou
733301ff31 [Backend] Rewrite code for linking external library to expose more inlining opportunities (#1037)
- Also make it cleaner. 
- And mark out the code needs to be fixed in `semantic.py`.
2023-01-08 13:44:29 -08:00
Shintaro Iwasaki
ff399fbc20 [Build] Support GCC 8.x to build Triton (#1036) 2023-01-06 19:36:14 -08:00
Keren Zhou
4023149ee3 [Frontend] Convert constexpr to value for store and load ops (#1030)
Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-05 14:40:16 -05:00
Gregory Axler
2193bee94e [Example] Fix the compile function in copy_strided.py (#1029) 2023-01-05 10:37:41 -08:00
Sophia Wisdom
411bacb2a8 [FRONTEND] Add logical operations on constexprs (#1033) 2023-01-04 18:06:32 -08:00
Sharad Vikram
bc73bbb12c [FRONTEND] Fix argmin/max output type (#1012)
Currently Triton returns tensors with the input types rather than i32
when doing reduce argmax/argmin.
2023-01-03 23:12:16 -08:00
Keren Zhou
8460ea3df1 [Frontend] Fix import for libdevice (#1028)
This is a hotfix for issue 1 in
https://github.com/openai/triton/issues/1017
2023-01-03 15:48:05 -08:00
Keren Zhou
678b9f53a2 [Backend] Use post-order traversal for liveness numbering (#1027)
Also add tests for `tt.trans`.
2023-01-03 15:11:54 -08:00
17 changed files with 247 additions and 169 deletions

View File

@@ -222,8 +222,10 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else()
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
endif()

View File

@@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";

View File

@@ -31,8 +31,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
} // namespace triton
} // namespace mlir

View File

@@ -25,13 +25,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (maybeSharedAllocationOp(op)) {
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
// XXX(Keren): the following ops are always aliasing for now
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());

View File

@@ -298,10 +298,24 @@ private:
/// Resolves liveness of all values involved under the root operation.
void resolveLiveness() {
// In the SCF dialect, we always have a sequentially nested structure of
// blocks
// Assign an ID to each operation using post-order traversal.
// To achieve the correct liveness range, the parent operation's ID
// should be greater than each of its child operation's ID .
// Example:
// ...
// %5 = triton.convert_layout %4
// %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
// %2 = triton.convert_layout %5
// ...
// scf.yield %arg0
// }
// For example, %5 is defined in the parent region and used in
// the child region, and is not passed as a block argument.
// %6 should should have an ID greater than its child operations,
// otherwise %5 liveness range ends before the child operation's liveness
// range ends.
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PreOrder>(
operation->walk<WalkOrder::PostOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
// Analyze liveness of explicit buffers

View File

@@ -18,6 +18,7 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <filesystem>
namespace mlir {
namespace triton {
@@ -26,19 +27,18 @@ namespace triton {
// information from mlir module.
struct NVVMMetadata {
int maxntidx{-1};
bool is_kernel{};
bool isKernel{};
// Free to extend with other information.
};
// Add the nvvm related metadata to LLVM IR.
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();
if (metadata.maxntidx > 0) {
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
llvm::APInt(32, metadata.maxntidx));
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
llvm::MDString::get(ctx, "maxntidx"),
@@ -48,18 +48,19 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
->addOperand(llvm::MDNode::get(ctx, md_args));
}
if (metadata.is_kernel) {
llvm::Metadata *md_args[] = {
if (metadata.isKernel) {
llvm::Metadata *mdArgs[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
->addOperand(llvm::MDNode::get(ctx, mdArgs));
}
}
void extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
static void
extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
NVVMMetadata meta;
@@ -74,7 +75,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
// kernel
if (op->hasAttr("nvvm.kernel")) {
meta.is_kernel = true;
meta.isKernel = true;
hasMetadata = true;
}
@@ -83,13 +84,109 @@ void extractNVVMMetadata(mlir::ModuleOp module,
}
}
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string libName = name.str();
externLibs[libName] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
if (!funcs.empty()) {
// When using the Math Dialect, it is possible that some ops (e.g., log) are
// lowered to a function call. In this case, we need to link libdevice
// using its default path:
// [triton root dir]/python/triton/language/libdevice.10.bc
// TODO(Keren): handle external linkage other than libdevice?
namespace fs = std::filesystem;
static const std::string libdevice = "libdevice";
static const std::filesystem::path path = std::filesystem::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
externLibs.try_emplace(libdevice, path.string());
}
return externLibs;
}
static void linkLibdevice(llvm::Module &module) {
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
if (name == "libdevice") {
linkLibdevice(module);
} else {
assert(false && "unknown extern lib: ");
}
return false;
}
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
context->appendDialectRegistry(registry);
module->getContext()->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
extractNVVMMetadata(module, &nvvmMetadata);
@@ -100,6 +197,20 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return nullptr;
}
// Link external libraries before perform optimizations
// Note from libdevice users guide:
// https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
// The standard process for linking with libdevice is to first link it with
// the target module, then run the standard LLVM optimization and code
// generation passes. This allows the optimizers to inline and perform
// analyses on the used library functions, and eliminate any used functions as
// dead code.
auto externLibs = getExternLibs(module);
for (auto &lib : externLibs) {
if (linkExternLib(*llvmModule, lib.first, lib.second))
return nullptr;
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
@@ -147,49 +258,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string lib_name = name.str();
externLibs[lib_name] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;
}
llvm::SMDiagnostic err;
for (auto &lib : externLibs) {
if (linkExternLib(*llvmir, lib.second))
return nullptr;
}
return llvmir;
return llvmIR;
}
void addExternalLibs(mlir::ModuleOp &module,
@@ -211,27 +285,5 @@ void addExternalLibs(mlir::ModuleOp &module,
module.getOperation()->setAttr("triton_gpu.externs", dict);
}
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
return false;
}
} // namespace triton
} // namespace mlir

View File

@@ -8,7 +8,6 @@
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include <filesystem>
namespace triton {
@@ -31,68 +30,29 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
linkExternal(module);
// LLVM version in use may not officially support target hardware
int maxNNVMCC = 75;
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
int maxPTX = std::min(75, version);
int maxCC = std::min(86, cc);
// options
auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
std::string sm = "sm_" + std::to_string(maxCC);
// max PTX version
int ptxMajor = version / 10;
int ptxMinor = version % 10;
int ptxMajor = maxPTX / 10;
int ptxMinor = maxPTX % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string proc = "sm_" + std::to_string(maxCC);
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
// std::string features = "+ptx" + std::to_string(maxPTX);
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;

View File

@@ -15,5 +15,5 @@ def kernel(X, stride_xm,
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -173,7 +173,7 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",

View File

@@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
x = GENERATE_TEST_HERE
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
_kernel[(1,)](input, output, input_size, output_size)
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1,)](input, output, input_size, output_size)
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.

View File

@@ -734,10 +734,6 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.values) == 2
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.And: 'logical_and',
@@ -967,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
version = major * 1000 + minor * 10
if version >= 11040:
return 74
if version >= 11030:
return 73
if version >= 11020:
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
if major == 12:
return 80 + minor
if major == 11:
return 70 + minor
if major == 10:
return 63 + minor
raise RuntimeError("Triton only support CUDA 10.0 or higher")

View File

@@ -5,6 +5,7 @@ from ..impl import (
ir,
builtin,
)
from . import libdevice
from .core import (
abs,
arange,
@@ -130,6 +131,7 @@ __all__ = [
"int64",
"int8",
"ir",
"libdevice",
"load",
"log",
"max",

View File

@@ -403,6 +403,18 @@ class constexpr:
def __neg__(self):
return constexpr(-self.value)
def __and__(self, other):
return constexpr(self.value & other.value)
def logical_and(self, other):
return constexpr(self.value and other.value)
def __or__(self, other):
return constexpr(self.value | other.value)
def logical_or(self, other):
return constexpr(self.value or other.value)
def __pos__(self):
return constexpr(+self.value)
@@ -818,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
'type cache_modifier: str, optional
"""
# mask, other can be constexpr
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
if other is not None:
if _constexpr_to_value(other) is not None:
other = _to_tensor(other, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
@@ -844,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None):
"""
# value can be constexpr
value = _to_tensor(value, _builder)
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, _builder)

View File

@@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# If we are doing an argmin or argmax we want to use an int32 output type
out_scalar_ty = scalar_ty
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
ret_shape = []
@@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if i != axis:
ret_shape.append(s)
if ret_shape:
res_ty = tl.block_type(scalar_ty, ret_shape)
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = scalar_ty
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
@@ -1109,11 +1116,13 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.floor(x, _builder=builder)

View File

@@ -52,6 +52,15 @@ func @convert(%A : !tt.ptr<f16>) {
return
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK: %0 -> %cst
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

View File

@@ -174,6 +174,14 @@ func @scratch() {
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
@@ -285,6 +293,25 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
// CHECK-NEXT: size = 24576
}
// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 16384, size = 8192
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c0 = tt.trans %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<32x128xf16, #A_SHARED>
// CHECK-NEXT: offset = 24576, size = 8192
%c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
// CHECK-NEXT: size = 32768
}
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for

View File

@@ -111,6 +111,13 @@ func @extract_slice() {
return
}
// CHECK-LABEL: trans
func @trans() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>