3 Commits

Author SHA1 Message Date
Jokeren
4037f3b921 Add comment 2023-01-05 16:09:44 -05:00
Jokeren
fcff1a6e75 Add comment 2023-01-05 16:09:03 -05:00
Jokeren
2920f6f50f Simple assert 2023-01-05 15:04:08 -05:00
18 changed files with 274 additions and 178 deletions

View File

@@ -222,10 +222,8 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif()

View File

@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
// Make PrintfOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,14 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
}];
}
//
// Make AssertOp
//
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let summary = "Device-side assert, as in CUDA for debugging";
let description = [{}];
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}
#endif // Triton_OPS

View File

@@ -31,6 +31,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
} // namespace triton
} // namespace mlir

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"

View File

@@ -270,6 +270,48 @@ struct PrintfOpConversion
}
};
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto voidTy = void_ty(ctx);
auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
Value ret;
for (auto elem : elems) {
auto type = elem.getType();
Value condition;
if (type.isIntOrFloat()) {
if (type.isSignedInteger() || type.isSignlessInteger()) {
condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
} else {
condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, type, rewriter.getZeroAttr(type)));
}
} else {
assert(false && "Unsupported type for assert");
return failure();
}
// MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
// handled by ptxas
// We should call __assertfail here
// Delete the definition of triton.assert, using mlir.assert instead
PTXBuilder builder;
auto &trapOp = *builder.create<PTXInstr>("trap");
trapOp().predicate(condition);
ret = builder.launch(rewriter, loc, voidTy);
}
rewriter.replaceOp(op, ret);
return success();
}
};
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +566,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}

View File

@@ -45,6 +45,9 @@
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
@@ -77,6 +80,7 @@
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)

View File

@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
};
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern;
using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
matchAndRewrite(triton::PrintfOp op,
typename triton::PrintfOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
}
};
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op,
typename triton::AssertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
op.messageAttr());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context);
TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
}
//

View File

@@ -18,7 +18,6 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <filesystem>
namespace mlir {
namespace triton {
@@ -27,18 +26,19 @@ namespace triton {
// information from mlir module.
struct NVVMMetadata {
int maxntidx{-1};
bool isKernel{};
bool is_kernel{};
// Free to extend with other information.
};
// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();
if (metadata.maxntidx > 0) {
auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
llvm::APInt(32, metadata.maxntidx));
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
llvm::MDString::get(ctx, "maxntidx"),
@@ -48,19 +48,18 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
->addOperand(llvm::MDNode::get(ctx, md_args));
}
if (metadata.isKernel) {
llvm::Metadata *mdArgs[] = {
if (metadata.is_kernel) {
llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, mdArgs));
->addOperand(llvm::MDNode::get(ctx, md_args));
}
}
static void
extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
void extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
NVVMMetadata meta;
@@ -75,7 +74,7 @@ extractNVVMMetadata(mlir::ModuleOp module,
// kernel
if (op->hasAttr("nvvm.kernel")) {
meta.isKernel = true;
meta.is_kernel = true;
hasMetadata = true;
}
@@ -84,109 +83,13 @@ extractNVVMMetadata(mlir::ModuleOp module,
}
}
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string libName = name.str();
externLibs[libName] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
if (!funcs.empty()) {
// When using the Math Dialect, it is possible that some ops (e.g., log) are
// lowered to a function call. In this case, we need to link libdevice
// using its default path:
// [triton root dir]/python/triton/language/libdevice.10.bc
// TODO(Keren): handle external linkage other than libdevice?
namespace fs = std::filesystem;
static const std::string libdevice = "libdevice";
static const std::filesystem::path path = std::filesystem::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
externLibs.try_emplace(libdevice, path.string());
}
return externLibs;
}
static void linkLibdevice(llvm::Module &module) {
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
if (name == "libdevice") {
linkLibdevice(module);
} else {
assert(false && "unknown extern lib: ");
}
return false;
}
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
module->getContext()->appendDialectRegistry(registry);
context->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
extractNVVMMetadata(module, &nvvmMetadata);
@@ -197,20 +100,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return nullptr;
}
// Link external libraries before perform optimizations
// Note from libdevice users guide:
// https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
// The standard process for linking with libdevice is to first link it with
// the target module, then run the standard LLVM optimization and code
// generation passes. This allows the optimizers to inline and perform
// analyses on the used library functions, and eliminate any used functions as
// dead code.
auto externLibs = getExternLibs(module);
for (auto &lib : externLibs) {
if (linkExternLib(*llvmModule, lib.first, lib.second))
return nullptr;
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
@@ -258,12 +147,49 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmIR) {
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string lib_name = name.str();
externLibs[lib_name] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;
}
return llvmIR;
llvm::SMDiagnostic err;
for (auto &lib : externLibs) {
if (linkExternLib(*llvmir, lib.second))
return nullptr;
}
return llvmir;
}
void addExternalLibs(mlir::ModuleOp &module,
@@ -285,5 +211,27 @@ void addExternalLibs(mlir::ModuleOp &module,
module.getOperation()->setAttr("triton_gpu.externs", dict);
}
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
return false;
}
} // namespace triton
} // namespace mlir

View File

@@ -8,6 +8,7 @@
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include <filesystem>
namespace triton {
@@ -30,29 +31,68 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
int maxPTX = std::min(75, version);
int maxCC = std::min(86, cc);
linkExternal(module);
// LLVM version in use may not officially support target hardware
int maxNNVMCC = 75;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
std::string sm = "sm_" + std::to_string(maxCC);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = maxPTX / 10;
int ptxMinor = maxPTX % 10;
int ptxMajor = version / 10;
int ptxMinor = version % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(maxCC);
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(maxPTX);
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;

View File

@@ -15,5 +15,5 @@ def kernel(X, stride_xm,
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -173,7 +173,7 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",

View File

@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
llvm::StringRef(prefix)),
values);
})
.def("create_assert",
[](mlir::OpBuilder &self, mlir::Value &condition,
const std::string &message) -> void {
auto loc = self.getUnknownLoc();
auto messageAttr = mlir::StringAttr::get(self.getContext(),
llvm::StringRef(message));
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {

View File

@@ -52,5 +52,21 @@ def printf(data_type):
assert_close(y, x)
printf("float16")
printf("int8")
def assert2(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.assert2(x == 0, "x > 0")
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
#printf("float16")
#printf("int8")
assert2("float16")

View File

@@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = GENERATE_TEST_HERE
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1,)](input, output, input_size, output_size)
_kernel[(1,)](input, output, input_size, output_size)
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.

View File

@@ -734,6 +734,10 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.values) == 2
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.And: 'logical_and',
@@ -963,12 +967,23 @@ def ptx_get_version(cuda_version) -> int:
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
if major == 12:
return 80 + minor
if major == 11:
return 70 + minor
if major == 10:
return 63 + minor
version = major * 1000 + minor * 10
if version >= 11040:
return 74
if version >= 11030:
return 73
if version >= 11020:
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
raise RuntimeError("Triton only support CUDA 10.0 or higher")

View File

@@ -11,6 +11,7 @@ from .core import (
arange,
argmin,
argmax,
assert2,
atomic_add,
atomic_and,
atomic_cas,
@@ -98,6 +99,7 @@ __all__ = [
"arange",
"argmin",
"argmax",
"assert2",
"atomic_add",
"atomic_and",
"atomic_cas",

View File

@@ -403,18 +403,6 @@ class constexpr:
def __neg__(self):
return constexpr(-self.value)
def __and__(self, other):
return constexpr(self.value & other.value)
def logical_and(self, other):
return constexpr(self.value and other.value)
def __or__(self, other):
return constexpr(self.value | other.value)
def logical_or(self, other):
return constexpr(self.value or other.value)
def __pos__(self):
return constexpr(+self.value)
@@ -830,9 +818,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
'type cache_modifier: str, optional
"""
# mask, other can be constexpr
if _constexpr_to_value(mask) is not None:
if mask is not None:
mask = _to_tensor(mask, _builder)
if _constexpr_to_value(other) is not None:
if other is not None:
other = _to_tensor(other, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
@@ -856,7 +844,7 @@ def store(pointer, value, mask=None, _builder=None):
"""
# value can be constexpr
value = _to_tensor(value, _builder)
if _constexpr_to_value(mask) is not None:
if mask is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, _builder)
@@ -1265,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
@builtin
def assert2(cond, msg="", _builder=None):
msg = _constexpr_to_value(msg)
return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)

View File

@@ -1116,13 +1116,11 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.floor(x, _builder=builder)
@@ -1172,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)