Compare commits: master...phil/swizz
1 commit: 08366b2d59

@@ -222,10 +222,8 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
    target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
elseif(APPLE)
    target_link_libraries(triton ${LLVM_LIBRARIES} z)
else()
    target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
    target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif()

@@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
}

def TT_TransOp : TT_Op<"trans", [NoSideEffect,
                                 SameOperandsAndResultElementType]> {
                                 SameOperandsAndResultElementType]> {

  let summary = "transpose a tensor";

@@ -31,6 +31,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);

bool linkExternLib(llvm::Module &module, llvm::StringRef path);

} // namespace triton
} // namespace mlir

@@ -25,14 +25,13 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
  if (maybeSharedAllocationOp(op)) {
    // These ops may allocate a new shared memory buffer.
    auto result = op->getResult(0);
    // XXX(Keren): the following ops are always aliasing for now
    // FIXME(Keren): extract and insert are always alias for now
    if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
      // extract_slice %src
      // trans %src
      aliasInfo = AliasInfo(operands[0]->getValue());
      pessimistic = false;
    } else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
                   op)) {
    } else if (isa<tensor::InsertSliceOp>(op) ||
               isa<triton::gpu::InsertSliceAsyncOp>(op)) {
      // insert_slice_async %src, %dst, %index
      // insert_slice %src into %dst[%offsets]
      aliasInfo = AliasInfo(operands[1]->getValue());

@@ -298,24 +298,10 @@ private:

  /// Resolves liveness of all values involved under the root operation.
  void resolveLiveness() {
    // Assign an ID to each operation using post-order traversal.
    // To achieve the correct liveness range, the parent operation's ID
    // should be greater than each of its child operation's ID.
    // Example:
    //   ...
    //   %5 = triton.convert_layout %4
    //   %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
    //     %2 = triton.convert_layout %5
    //     ...
    //     scf.yield %arg0
    //   }
    // For example, %5 is defined in the parent region and used in
    // the child region, and is not passed as a block argument.
    // %6 should have an ID greater than its child operations,
    // otherwise %5's liveness range ends before the child operation's liveness
    // range ends.
    // In the SCF dialect, we always have a sequentially nested structure of
    // blocks
    DenseMap<Operation *, size_t> operationId;
    operation->walk<WalkOrder::PostOrder>(
    operation->walk<WalkOrder::PreOrder>(
        [&](Operation *op) { operationId[op] = operationId.size(); });

    // Analyze liveness of explicit buffers

@@ -18,7 +18,6 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <filesystem>

namespace mlir {
namespace triton {

@@ -27,18 +26,19 @@ namespace triton {
// information from mlir module.
struct NVVMMetadata {
  int maxntidx{-1};
  bool isKernel{};
  bool is_kernel{};
  // Free to extend with other information.
};

// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
  auto *module = func->getParent();
  auto &ctx = func->getContext();

  if (metadata.maxntidx > 0) {
    auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
                                        llvm::APInt(32, metadata.maxntidx));
    auto i32_ty = llvm::IntegerType::get(ctx, 32);
    auto warps =
        llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));

    llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
                                 llvm::MDString::get(ctx, "maxntidx"),

@@ -48,19 +48,18 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
        ->addOperand(llvm::MDNode::get(ctx, md_args));
  }

  if (metadata.isKernel) {
    llvm::Metadata *mdArgs[] = {
  if (metadata.is_kernel) {
    llvm::Metadata *md_args[] = {
        llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
        llvm::ValueAsMetadata::get(
            llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
    module->getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(llvm::MDNode::get(ctx, mdArgs));
        ->addOperand(llvm::MDNode::get(ctx, md_args));
  }
}

static void
extractNVVMMetadata(mlir::ModuleOp module,
                    llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
void extractNVVMMetadata(mlir::ModuleOp module,
                         llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
  for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
    NVVMMetadata meta;

@@ -75,7 +74,7 @@ extractNVVMMetadata(mlir::ModuleOp module,

    // kernel
    if (op->hasAttr("nvvm.kernel")) {
      meta.isKernel = true;
      meta.is_kernel = true;
      hasMetadata = true;
    }

@@ -84,109 +83,13 @@ extractNVVMMetadata(mlir::ModuleOp module,
  }
}

static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
  std::map<std::string, std::string> externLibs;
  SmallVector<LLVM::LLVMFuncOp> funcs;
  module.walk([&](LLVM::LLVMFuncOp func) {
    if (func.isExternal())
      funcs.push_back(func);
  });

  for (auto &func : funcs) {
    if (func.getOperation()->hasAttr("libname")) {
      auto name =
          func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
      auto path =
          func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
      if (name) {
        std::string libName = name.str();
        externLibs[libName] = path.str();
      }
    }
  }

  if (module.getOperation()->hasAttr("triton_gpu.externs")) {
    auto dict = module.getOperation()
                    ->getAttr("triton_gpu.externs")
                    .dyn_cast<DictionaryAttr>();
    for (auto &attr : dict) {
      externLibs[attr.getName().strref().trim().str()] =
          attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
    }
  }

  if (!funcs.empty()) {
    // When using the Math Dialect, it is possible that some ops (e.g., log) are
    // lowered to a function call. In this case, we need to link libdevice
    // using its default path:
    // [triton root dir]/python/triton/language/libdevice.10.bc
    // TODO(Keren): handle external linkage other than libdevice?
    namespace fs = std::filesystem;
    static const std::string libdevice = "libdevice";
    static const std::filesystem::path path = std::filesystem::path(__FILE__)
                                                  .parent_path()
                                                  .parent_path()
                                                  .parent_path()
                                                  .parent_path() /
                                              "python" / "triton" / "language" /
                                              "libdevice.10.bc";
    externLibs.try_emplace(libdevice, path.string());
  }

  return externLibs;
}

static void linkLibdevice(llvm::Module &module) {
  // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
  // this will enable fast math path in libdevice
  // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
  // sqrt.approx.ftz.f32
  auto &ctx = module.getContext();
  llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
  llvm::Metadata *mdFour =
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
  llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
  llvm::Metadata *mdOne =
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
  llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
  module.addModuleFlag(reflect);
}

static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
                          llvm::StringRef path) {
  llvm::SMDiagnostic err;
  auto &ctx = module.getContext();

  auto extMod = llvm::parseIRFile(path, err, ctx);
  if (!extMod) {
    llvm::errs() << "Failed to load " << path;
    return true;
  }

  extMod->setTargetTriple(module.getTargetTriple());
  extMod->setDataLayout(module.getDataLayout());

  if (llvm::Linker::linkModules(module, std::move(extMod),
                                llvm::Linker::Flags::LinkOnlyNeeded)) {
    llvm::errs() << "Failed to link " << path;
    return true;
  }

  if (name == "libdevice") {
    linkLibdevice(module);
  } else {
    assert(false && "unknown extern lib: ");
  }

  return false;
}

std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
  auto context = module->getContext();
  DialectRegistry registry;
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  module->getContext()->appendDialectRegistry(registry);
  context->appendDialectRegistry(registry);

  llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
  extractNVVMMetadata(module, &nvvmMetadata);

@@ -197,20 +100,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
    return nullptr;
  }

  // Link external libraries before perform optimizations
  // Note from libdevice users guide:
  // https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
  // The standard process for linking with libdevice is to first link it with
  // the target module, then run the standard LLVM optimization and code
  // generation passes. This allows the optimizers to inline and perform
  // analyses on the used library functions, and eliminate any used functions as
  // dead code.
  auto externLibs = getExternLibs(module);
  for (auto &lib : externLibs) {
    if (linkExternLib(*llvmModule, lib.first, lib.second))
      return nullptr;
  }

  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/3, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

@@ -258,12 +147,49 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
    return nullptr;
  }

  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
  if (!llvmIR) {
  std::map<std::string, std::string> externLibs;
  SmallVector<LLVM::LLVMFuncOp> funcs;
  module.walk([&](LLVM::LLVMFuncOp func) {
    if (func.isExternal())
      funcs.push_back(func);
  });

  for (auto &func : funcs) {
    if (func.getOperation()->hasAttr("libname")) {
      auto name =
          func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
      auto path =
          func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
      if (name) {
        std::string lib_name = name.str();
        externLibs[lib_name] = path.str();
      }
    }
  }

  if (module.getOperation()->hasAttr("triton_gpu.externs")) {
    auto dict = module.getOperation()
                    ->getAttr("triton_gpu.externs")
                    .dyn_cast<DictionaryAttr>();
    for (auto &attr : dict) {
      externLibs[attr.getName().strref().trim().str()] =
          attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
    }
  }

  auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
  if (!llvmir) {
    llvm::errs() << "Translate to LLVM IR failed";
    return nullptr;
  }
  return llvmIR;

  llvm::SMDiagnostic err;
  for (auto &lib : externLibs) {
    if (linkExternLib(*llvmir, lib.second))
      return nullptr;
  }

  return llvmir;
}

void addExternalLibs(mlir::ModuleOp &module,

@@ -285,5 +211,27 @@ void addExternalLibs(mlir::ModuleOp &module,
  module.getOperation()->setAttr("triton_gpu.externs", dict);
}

bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
  llvm::SMDiagnostic err;
  auto &ctx = module.getContext();

  auto extMod = llvm::parseIRFile(path, err, ctx);
  if (!extMod) {
    llvm::errs() << "Failed to load " << path;
    return true;
  }

  extMod->setTargetTriple(module.getTargetTriple());
  extMod->setDataLayout(module.getDataLayout());

  if (llvm::Linker::linkModules(module, std::move(extMod),
                                llvm::Linker::Flags::LinkOnlyNeeded)) {
    llvm::errs() << "Failed to link " << path;
    return true;
  }

  return false;
}

} // namespace triton
} // namespace mlir

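The getExternLibs hunk above derives a default libdevice location by calling parent_path() four times on __FILE__ and then descending into python/triton/language. A minimal Python sketch of that resolution, assuming a hypothetical source-file location four directory levels below the repository root (the path below is a placeholder, not taken from the diff):

# Sketch only; illustrates the four-parent walk, not the actual build layout.
from pathlib import Path

source_file = Path("/repo/lib/Target/LLVMIR/LLVMIRTranslation.cpp")  # hypothetical
repo_root = source_file.parents[3]  # four parent_path() steps: LLVMIR -> Target -> lib -> /repo
libdevice = repo_root / "python" / "triton" / "language" / "libdevice.10.bc"
print(libdevice)  # /repo/python/triton/language/libdevice.10.bc
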
@@ -8,6 +8,7 @@
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include <filesystem>

namespace triton {

@@ -30,29 +31,68 @@ static bool findAndReplace(std::string &str, const std::string &begin,
  return true;
}

static void linkExternal(llvm::Module &module) {
  bool hasExternal = false;
  for (auto &func : module) {
    if (func.hasExternalLinkage()) {
      hasExternal = true;
      break;
    }
  }

  if (hasExternal) {
    namespace fs = std::filesystem;
    // [triton root dir]/python/triton/language/libdevice.10.bc
    static const fs::path libdevice = fs::path(__FILE__)
                                          .parent_path()
                                          .parent_path()
                                          .parent_path()
                                          .parent_path() /
                                      "python" / "triton" / "language" /
                                      "libdevice.10.bc";
    if (mlir::triton::linkExternLib(module, libdevice.string()))
      llvm::errs() << "link failed for: " << libdevice.string();

    // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
    // this will enable fast math path in libdevice
    // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
    // sqrt.approx.ftz.f32
    auto &ctx = module.getContext();
    llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
    llvm::Metadata *mdFour =
        llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
    llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
    llvm::Metadata *mdOne =
        llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
    llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
    module.addModuleFlag(reflect);
  }
}

std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  // LLVM version in use may not officially support target hardware.
  // Supported versions for LLVM 14 are here:
  // https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
  int maxPTX = std::min(75, version);
  int maxCC = std::min(86, cc);
  linkExternal(module);

  // LLVM version in use may not officially support target hardware
  int maxNNVMCC = 75;
  // options
  auto options = llvm::cl::getRegisteredOptions();
  auto *shortPtr =
      static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
  assert(shortPtr);
  shortPtr->setValue(true);
  std::string sm = "sm_" + std::to_string(maxCC);
  // compute capability
  std::string sm = "sm_" + std::to_string(cc);
  // max PTX version
  int ptxMajor = maxPTX / 10;
  int ptxMinor = maxPTX % 10;
  int ptxMajor = version / 10;
  int ptxMinor = version % 10;
  // create
  llvm::SmallVector<char, 0> buffer;
  std::string triple = "nvptx64-nvidia-cuda";
  std::string proc = "sm_" + std::to_string(maxCC);
  std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
  std::string layout = "";
  std::string features = "";
  // std::string features = "+ptx" + std::to_string(maxPTX);
  // std::string features = "+ptx" + std::to_string(std::min(ptx,
  // max_nvvm_ptx));
  initLLVM();
  // verify and store llvm
  llvm::legacy::PassManager pm;

python/chain-dot.ttgir (new file, 55 lines)
@@ -0,0 +1,55 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d2c3d4d5c6d7d8c9d10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked0>
%3 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0>
%6 = tt.broadcast %5 : (tensor<1x64xi32, #blocked0>) -> tensor<64x64xi32, #blocked0>
%7 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%8 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%9 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked0>
%10 = tt.splat %arg4 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%11 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked0>
%12 = tt.splat %arg6 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%13 = arith.muli %1, %2 : tensor<64x1xi32, #blocked0>
%14 = tt.addptr %3, %13 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%15 = tt.broadcast %14 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
%16 = tt.addptr %15, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
%17 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
%18 = arith.muli %1, %7 : tensor<64x1xi32, #blocked0>
%19 = tt.addptr %8, %18 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%20 = tt.broadcast %19 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
%21 = tt.addptr %20, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
%22 = tt.load %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
%23 = triton_gpu.convert_layout %17 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%24 = triton_gpu.convert_layout %22 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%25 = tt.dot %23, %24, %cst {allowTF32 = false} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<64x64xf32, #mma0>
%27 = arith.muli %1, %9 : tensor<64x1xi32, #blocked0>
%28 = tt.addptr %10, %27 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%29 = tt.broadcast %28 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
%30 = tt.addptr %29, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
%31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked0>
%32 = arith.truncf %25 : tensor<64x64xf32, #mma0> to tensor<64x64xf16, #mma0>
%133 = triton_gpu.convert_layout %32 : (tensor<64x64xf16, #mma0>) -> tensor<64x64xf16, #shared>
%33 = triton_gpu.convert_layout %133 : (tensor<64x64xf16, #shared>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%34 = triton_gpu.convert_layout %31 : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%35 = tt.dot %33, %34, %cst_0 {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<64x64xf32, #mma1>
%36 = triton_gpu.convert_layout %35 : (tensor<64x64xf32, #mma1>) -> tensor<64x64xf32, #blocked0>
%37 = arith.muli %1, %11 : tensor<64x1xi32, #blocked0>
%38 = tt.addptr %12, %37 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%39 = tt.broadcast %38 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x64x!tt.ptr<f16>, #blocked0>
%40 = tt.addptr %39, %6 : tensor<64x64x!tt.ptr<f16>, #blocked0>, tensor<64x64xi32, #blocked0>
%41 = arith.truncf %36 : tensor<64x64xf32, #blocked0> to tensor<64x64xf16, #blocked0>
tt.store %40, %41 : tensor<64x64xf16, #blocked0>
return
}
}

@@ -15,5 +15,5 @@ def kernel(X, stride_xm,
    tl.store(Zs, tl.load(Xs))


ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

@@ -173,7 +173,7 @@ setup(
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
    install_requires=[
        "cmake",
        "filelock",

@@ -1177,19 +1177,25 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
    z_tri = to_triton(z, device=device)
    if epilogue == 'trans':
        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                         y_tri, y_tri.stride(0), y_tri.stride(1),
                         w_tri, w_tri.stride(0), w_tri.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         COL_A=col_a, COL_B=col_b,
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=epilogue == 'add-matrix',
                         ADD_ROWS=epilogue == 'add-rows',
                         ADD_COLS=epilogue == 'add-cols',
                         DO_SOFTMAX=epilogue == 'softmax',
                         CHAIN_DOT=epilogue == 'chain-dot',
                         ALLOW_TF32=allow_tf32,
                         num_warps=num_warps)
    # pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
    #                      y_tri, y_tri.stride(0), y_tri.stride(1),
    #                      w_tri, w_tri.stride(0), w_tri.stride(1),
    #                      z_tri, z_tri.stride(0), z_tri.stride(1),
    #                      COL_A=col_a, COL_B=col_b,
    #                      BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
    #                      ADD_MATRIX=epilogue == 'add-matrix',
    #                      ADD_ROWS=epilogue == 'add-rows',
    #                      ADD_COLS=epilogue == 'add-cols',
    #                      DO_SOFTMAX=epilogue == 'softmax',
    #                      CHAIN_DOT=epilogue == 'chain-dot',
    #                      ALLOW_TF32=allow_tf32,
    #                      num_warps=num_warps)
    kernel = triton.compile("./chain-dot.ttgir", num_warps=num_warps)
    pgm = kernel[(1, 1, 1)](x_tri.data_ptr(), x_tri.stride(0),
                            y_tri.data_ptr(), y_tri.stride(0),
                            w_tri.data_ptr(), w_tri.stride(0),
                            z_tri.data_ptr(), z_tri.stride(0))

    # torch result
    if dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32),

@@ -1217,15 +1223,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
    # ptx = pgm.asm['ptx']
    # assert 'ld.global.v4' in ptx
    # assert 'st.global.v4' in ptx
    # if dtype == 'float32' and allow_tf32:
    #     assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    # elif dtype == 'float32' and allow_tf32:
    #     assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    # elif dtype == 'int8':
    #     assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():

@@ -1267,7 +1273,7 @@ def test_arange(start, device='cuda'):
# ---------------


@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
    dtype = getattr(torch, dtype_str)
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

@@ -1286,18 +1292,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
        in_offsets = tl.arange(0, out_size)
        # Load inputs.
        x = GENERATE_TEST_HERE
        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
        # Store output
        output_offsets = tl.arange(0, out_size)
        tl.store(out_ptr + output_offsets, x)

    mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
    kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
    kernel[(1,)](input, output, input_size, output_size)
    _kernel[(1,)](input, output, input_size, output_size)

    reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
    reference_out = input
    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
    triton.testing.allclose(output, reference_out)

# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermediate copy to shared memory run.

@@ -734,6 +734,10 @@ class CodeGenerator(ast.NodeVisitor):
        assert len(node.values) == 2
        lhs = self.visit(node.values[0])
        rhs = self.visit(node.values[1])
        if isinstance(lhs, triton.language.constexpr):
            lhs = lhs.value
        if isinstance(rhs, triton.language.constexpr):
            rhs = rhs.value

        fn = {
            ast.And: 'logical_and',

@@ -963,12 +967,23 @@ def ptx_get_version(cuda_version) -> int:
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    if major == 12:
        return 80 + minor
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    version = major * 1000 + minor * 10
    if version >= 11040:
        return 74
    if version >= 11030:
        return 73
    if version >= 11020:
        return 72
    if version >= 11010:
        return 71
    if version >= 11000:
        return 70
    if version >= 10020:
        return 65
    if version >= 10010:
        return 64
    if version >= 10000:
        return 63
    raise RuntimeError("Triton only support CUDA 10.0 or higher")

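Both variants in the hunk above map a CUDA toolkit version string to an encoded PTX ISA version (for example, 74 means PTX 7.4). A standalone sketch of the threshold-based form, with the helper name ptx_version_from_cuda assumed purely for illustration:

# Sketch only; reproduces the threshold mapping shown above so it can be sanity-checked in isolation.
def ptx_version_from_cuda(cuda_version: str) -> int:
    major, minor = map(int, cuda_version.split('.'))
    version = major * 1000 + minor * 10
    thresholds = [(11040, 74), (11030, 73), (11020, 72), (11010, 71),
                  (11000, 70), (10020, 65), (10010, 64), (10000, 63)]
    for bound, ptx in thresholds:
        if version >= bound:
            return ptx
    raise RuntimeError("CUDA 10.0 or higher required")

print(ptx_version_from_cuda("11.4"))  # 74, i.e. PTX ISA 7.4
print(ptx_version_from_cuda("10.2"))  # 65, i.e. PTX ISA 6.5

Unlike the major/minor arithmetic in the other variant, the threshold form caps anything newer than CUDA 11.4 at PTX 7.4.
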
@@ -1454,9 +1469,7 @@ def compile(fn, **kwargs):
        import re
        match = re.search(prototype_pattern[ir], src, re.MULTILINE)
        name, signature = match.group(1), match.group(2)
        print(name, signature)
        types = re.findall(arg_type_pattern[ir], signature)
        print(types)
        param_tys = [convert_type_repr(ty) for ty in types]
        signature = {k: v for k, v in enumerate(param_tys)}
        first_stage = list(stages.keys()).index(ir)

@@ -5,7 +5,6 @@ from ..impl import (
    ir,
    builtin,
)
from . import libdevice
from .core import (
    abs,
    arange,

@@ -131,7 +130,6 @@
    "int64",
    "int8",
    "ir",
    "libdevice",
    "load",
    "log",
    "max",

@@ -403,18 +403,6 @@ class constexpr:
    def __neg__(self):
        return constexpr(-self.value)

    def __and__(self, other):
        return constexpr(self.value & other.value)

    def logical_and(self, other):
        return constexpr(self.value and other.value)

    def __or__(self, other):
        return constexpr(self.value | other.value)

    def logical_or(self, other):
        return constexpr(self.value or other.value)

    def __pos__(self):
        return constexpr(+self.value)

@@ -830,9 +818,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
|
||||
'type cache_modifier: str, optional
|
||||
"""
|
||||
# mask, other can be constexpr
|
||||
if _constexpr_to_value(mask) is not None:
|
||||
if mask is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
if _constexpr_to_value(other) is not None:
|
||||
if other is not None:
|
||||
other = _to_tensor(other, _builder)
|
||||
cache_modifier = _constexpr_to_value(cache_modifier)
|
||||
eviction_policy = _constexpr_to_value(eviction_policy)
|
||||
@@ -856,7 +844,7 @@ def store(pointer, value, mask=None, _builder=None):
|
||||
"""
|
||||
# value can be constexpr
|
||||
value = _to_tensor(value, _builder)
|
||||
if _constexpr_to_value(mask) is not None:
|
||||
if mask is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
return semantic.store(pointer, value, mask, _builder)
|
||||
|
||||
|
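The two hunks above switch between testing the mask/other arguments directly and unwrapping them with _constexpr_to_value first. A minimal sketch of why that matters when a constexpr wraps None; the wrapper class and helper below are simplified stand-ins for illustration, not the real Triton definitions:

# Sketch only: simplified stand-ins for triton.language.constexpr and the unwrap helper.
class constexpr:
    def __init__(self, value):
        self.value = value

def _constexpr_to_value(v):
    return v.value if isinstance(v, constexpr) else v

mask = constexpr(None)
print(mask is not None)                       # True  -> the bare check would treat the mask as present
print(_constexpr_to_value(mask) is not None)  # False -> the unwrapped check correctly sees no mask
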
@@ -1057,13 +1057,6 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
    if INT_OP in int_op_to_unit:
        INT_OP = int_op_to_unit[INT_OP]

    # If we are doing an argmin or argmax we want to use an int32 output type
    out_scalar_ty = scalar_ty
    if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
        out_scalar_ty = tl.int32
    elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
        out_scalar_ty = tl.int32

    # get result type
    shape = input.type.shape
    ret_shape = []

@@ -1071,10 +1064,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
        if i != axis:
            ret_shape.append(s)
    if ret_shape:
        res_ty = tl.block_type(out_scalar_ty, ret_shape)
        res_ty = tl.block_type(scalar_ty, ret_shape)
    else:
        # 0d-tensor -> scalar
        res_ty = out_scalar_ty
        res_ty = scalar_ty

    if scalar_ty.is_floating():
        return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)

@@ -1116,13 +1109,11 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:

def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
    x, y = binary_op_type_checking_impl(x, y, builder)
    # FIXME(Keren): not portable, should be fixed
    from . import libdevice
    return libdevice.mulhi(x, y, _builder=builder)


def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    # FIXME(Keren): not portable, should be fixed
    from . import libdevice
    return libdevice.floor(x, _builder=builder)

@@ -52,15 +52,6 @@ func @convert(%A : !tt.ptr<f16>) {
  return
}

// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
  // CHECK: %cst -> %cst
  %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
  // CHECK: %0 -> %cst
  %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
  return
}

// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

@@ -174,14 +174,6 @@ func @scratch() {
  // CHECK-NEXT: size = 512
}

// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
  // CHECK: offset = 0, size = 1024
  %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
  %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
  return
}

// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

@@ -293,25 +285,6 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
  // CHECK-NEXT: size = 24576
}

// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
  // CHECK: offset = 0, size = 8192
  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
  // CHECK-NEXT: offset = 8192, size = 8192
  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
  // CHECK-NEXT: offset = 16384, size = 8192
  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
  %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
    %c0 = tt.trans %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<32x128xf16, #A_SHARED>
    // CHECK-NEXT: offset = 24576, size = 8192
    %c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
    scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
  }
  return
  // CHECK-NEXT: size = 32768
}

// a_shared_init, b_shared_init, and c_shared_init's liveness ranges span the entire function before cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for

@@ -111,13 +111,6 @@ func @extract_slice() {
  return
}

// CHECK-LABEL: trans
func @trans() {
  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
  %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
  return
}

// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>