Compare commits
7 Commits
keren/asse
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
0f5c6e619c | ||
|
c20215dad1 | ||
|
733301ff31 | ||
|
ff399fbc20 | ||
|
4023149ee3 | ||
|
2193bee94e | ||
|
411bacb2a8 |
@@ -222,8 +222,10 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
|||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
||||||
else()
|
elseif(APPLE)
|
||||||
target_link_libraries(triton ${LLVM_LIBRARIES} z)
|
target_link_libraries(triton ${LLVM_LIBRARIES} z)
|
||||||
|
else()
|
||||||
|
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
@@ -31,8 +31,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
|||||||
std::unique_ptr<llvm::Module>
|
std::unique_ptr<llvm::Module>
|
||||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||||
|
|
||||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
|
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@@ -18,6 +18,7 @@
|
|||||||
#include "llvm/IRReader/IRReader.h"
|
#include "llvm/IRReader/IRReader.h"
|
||||||
#include "llvm/Linker/Linker.h"
|
#include "llvm/Linker/Linker.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace triton {
|
namespace triton {
|
||||||
@@ -26,19 +27,18 @@ namespace triton {
|
|||||||
// information from mlir module.
|
// information from mlir module.
|
||||||
struct NVVMMetadata {
|
struct NVVMMetadata {
|
||||||
int maxntidx{-1};
|
int maxntidx{-1};
|
||||||
bool is_kernel{};
|
bool isKernel{};
|
||||||
// Free to extend with other information.
|
// Free to extend with other information.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add the nvvm related metadata to LLVM IR.
|
// Add the nvvm related metadata to LLVM IR.
|
||||||
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
||||||
auto *module = func->getParent();
|
auto *module = func->getParent();
|
||||||
auto &ctx = func->getContext();
|
auto &ctx = func->getContext();
|
||||||
|
|
||||||
if (metadata.maxntidx > 0) {
|
if (metadata.maxntidx > 0) {
|
||||||
auto i32_ty = llvm::IntegerType::get(ctx, 32);
|
auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
|
||||||
auto warps =
|
llvm::APInt(32, metadata.maxntidx));
|
||||||
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
|
|
||||||
|
|
||||||
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
|
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
|
||||||
llvm::MDString::get(ctx, "maxntidx"),
|
llvm::MDString::get(ctx, "maxntidx"),
|
||||||
@@ -48,18 +48,19 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
|||||||
->addOperand(llvm::MDNode::get(ctx, md_args));
|
->addOperand(llvm::MDNode::get(ctx, md_args));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (metadata.is_kernel) {
|
if (metadata.isKernel) {
|
||||||
llvm::Metadata *md_args[] = {
|
llvm::Metadata *mdArgs[] = {
|
||||||
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
||||||
llvm::ValueAsMetadata::get(
|
llvm::ValueAsMetadata::get(
|
||||||
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
||||||
module->getOrInsertNamedMetadata("nvvm.annotations")
|
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||||
->addOperand(llvm::MDNode::get(ctx, md_args));
|
->addOperand(llvm::MDNode::get(ctx, mdArgs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void extractNVVMMetadata(mlir::ModuleOp module,
|
static void
|
||||||
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
|
extractNVVMMetadata(mlir::ModuleOp module,
|
||||||
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
|
||||||
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
|
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
|
||||||
NVVMMetadata meta;
|
NVVMMetadata meta;
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
|||||||
|
|
||||||
// kernel
|
// kernel
|
||||||
if (op->hasAttr("nvvm.kernel")) {
|
if (op->hasAttr("nvvm.kernel")) {
|
||||||
meta.is_kernel = true;
|
meta.isKernel = true;
|
||||||
hasMetadata = true;
|
hasMetadata = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,13 +84,109 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||||
|
std::map<std::string, std::string> externLibs;
|
||||||
|
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||||
|
module.walk([&](LLVM::LLVMFuncOp func) {
|
||||||
|
if (func.isExternal())
|
||||||
|
funcs.push_back(func);
|
||||||
|
});
|
||||||
|
|
||||||
|
for (auto &func : funcs) {
|
||||||
|
if (func.getOperation()->hasAttr("libname")) {
|
||||||
|
auto name =
|
||||||
|
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
||||||
|
auto path =
|
||||||
|
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
||||||
|
if (name) {
|
||||||
|
std::string libName = name.str();
|
||||||
|
externLibs[libName] = path.str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
||||||
|
auto dict = module.getOperation()
|
||||||
|
->getAttr("triton_gpu.externs")
|
||||||
|
.dyn_cast<DictionaryAttr>();
|
||||||
|
for (auto &attr : dict) {
|
||||||
|
externLibs[attr.getName().strref().trim().str()] =
|
||||||
|
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!funcs.empty()) {
|
||||||
|
// When using the Math Dialect, it is possible that some ops (e.g., log) are
|
||||||
|
// lowered to a function call. In this case, we need to link libdevice
|
||||||
|
// using its default path:
|
||||||
|
// [triton root dir]/python/triton/language/libdevice.10.bc
|
||||||
|
// TODO(Keren): handle external linkage other than libdevice?
|
||||||
|
namespace fs = std::filesystem;
|
||||||
|
static const std::string libdevice = "libdevice";
|
||||||
|
static const std::filesystem::path path = std::filesystem::path(__FILE__)
|
||||||
|
.parent_path()
|
||||||
|
.parent_path()
|
||||||
|
.parent_path()
|
||||||
|
.parent_path() /
|
||||||
|
"python" / "triton" / "language" /
|
||||||
|
"libdevice.10.bc";
|
||||||
|
externLibs.try_emplace(libdevice, path.string());
|
||||||
|
}
|
||||||
|
|
||||||
|
return externLibs;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void linkLibdevice(llvm::Module &module) {
|
||||||
|
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||||
|
// this will enable fast math path in libdevice
|
||||||
|
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||||
|
// sqrt.approx.ftz.f32
|
||||||
|
auto &ctx = module.getContext();
|
||||||
|
llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
|
||||||
|
llvm::Metadata *mdFour =
|
||||||
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
|
||||||
|
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
||||||
|
llvm::Metadata *mdOne =
|
||||||
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
|
||||||
|
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
||||||
|
module.addModuleFlag(reflect);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||||
|
llvm::StringRef path) {
|
||||||
|
llvm::SMDiagnostic err;
|
||||||
|
auto &ctx = module.getContext();
|
||||||
|
|
||||||
|
auto extMod = llvm::parseIRFile(path, err, ctx);
|
||||||
|
if (!extMod) {
|
||||||
|
llvm::errs() << "Failed to load " << path;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
extMod->setTargetTriple(module.getTargetTriple());
|
||||||
|
extMod->setDataLayout(module.getDataLayout());
|
||||||
|
|
||||||
|
if (llvm::Linker::linkModules(module, std::move(extMod),
|
||||||
|
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
||||||
|
llvm::errs() << "Failed to link " << path;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (name == "libdevice") {
|
||||||
|
linkLibdevice(module);
|
||||||
|
} else {
|
||||||
|
assert(false && "unknown extern lib: ");
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<llvm::Module>
|
std::unique_ptr<llvm::Module>
|
||||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||||
auto context = module->getContext();
|
|
||||||
DialectRegistry registry;
|
DialectRegistry registry;
|
||||||
mlir::registerLLVMDialectTranslation(registry);
|
mlir::registerLLVMDialectTranslation(registry);
|
||||||
mlir::registerNVVMDialectTranslation(registry);
|
mlir::registerNVVMDialectTranslation(registry);
|
||||||
context->appendDialectRegistry(registry);
|
module->getContext()->appendDialectRegistry(registry);
|
||||||
|
|
||||||
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
|
||||||
extractNVVMMetadata(module, &nvvmMetadata);
|
extractNVVMMetadata(module, &nvvmMetadata);
|
||||||
@@ -100,6 +197,20 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Link external libraries before performing optimizations
|
||||||
|
// Note from libdevice users guide:
|
||||||
|
// https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
|
||||||
|
// The standard process for linking with libdevice is to first link it with
|
||||||
|
// the target module, then run the standard LLVM optimization and code
|
||||||
|
// generation passes. This allows the optimizers to inline and perform
|
||||||
|
// analyses on the used library functions, and eliminate any unused functions as
|
||||||
|
// dead code.
|
||||||
|
auto externLibs = getExternLibs(module);
|
||||||
|
for (auto &lib : externLibs) {
|
||||||
|
if (linkExternLib(*llvmModule, lib.first, lib.second))
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||||
/*optLevel=*/3, /*sizeLevel=*/0,
|
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||||
/*targetMachine=*/nullptr);
|
/*targetMachine=*/nullptr);
|
||||||
@@ -147,49 +258,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, std::string> externLibs;
|
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
|
||||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
if (!llvmIR) {
|
||||||
module.walk([&](LLVM::LLVMFuncOp func) {
|
|
||||||
if (func.isExternal())
|
|
||||||
funcs.push_back(func);
|
|
||||||
});
|
|
||||||
|
|
||||||
for (auto &func : funcs) {
|
|
||||||
if (func.getOperation()->hasAttr("libname")) {
|
|
||||||
auto name =
|
|
||||||
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
|
||||||
auto path =
|
|
||||||
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
|
||||||
if (name) {
|
|
||||||
std::string lib_name = name.str();
|
|
||||||
externLibs[lib_name] = path.str();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
|
||||||
auto dict = module.getOperation()
|
|
||||||
->getAttr("triton_gpu.externs")
|
|
||||||
.dyn_cast<DictionaryAttr>();
|
|
||||||
for (auto &attr : dict) {
|
|
||||||
externLibs[attr.getName().strref().trim().str()] =
|
|
||||||
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
|
|
||||||
if (!llvmir) {
|
|
||||||
llvm::errs() << "Translate to LLVM IR failed";
|
llvm::errs() << "Translate to LLVM IR failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
return llvmIR;
|
||||||
llvm::SMDiagnostic err;
|
|
||||||
for (auto &lib : externLibs) {
|
|
||||||
if (linkExternLib(*llvmir, lib.second))
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return llvmir;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void addExternalLibs(mlir::ModuleOp &module,
|
void addExternalLibs(mlir::ModuleOp &module,
|
||||||
@@ -211,27 +285,5 @@ void addExternalLibs(mlir::ModuleOp &module,
|
|||||||
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
|
|
||||||
llvm::SMDiagnostic err;
|
|
||||||
auto &ctx = module.getContext();
|
|
||||||
|
|
||||||
auto extMod = llvm::parseIRFile(path, err, ctx);
|
|
||||||
if (!extMod) {
|
|
||||||
llvm::errs() << "Failed to load " << path;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
extMod->setTargetTriple(module.getTargetTriple());
|
|
||||||
extMod->setDataLayout(module.getDataLayout());
|
|
||||||
|
|
||||||
if (llvm::Linker::linkModules(module, std::move(extMod),
|
|
||||||
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
|
||||||
llvm::errs() << "Failed to link " << path;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -8,7 +8,6 @@
|
|||||||
#include "llvm/MC/TargetRegistry.h"
|
#include "llvm/MC/TargetRegistry.h"
|
||||||
#include "llvm/Support/TargetSelect.h"
|
#include "llvm/Support/TargetSelect.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
#include <filesystem>
|
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
@@ -31,68 +30,29 @@ static bool findAndReplace(std::string &str, const std::string &begin,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void linkExternal(llvm::Module &module) {
|
|
||||||
bool hasExternal = false;
|
|
||||||
for (auto &func : module) {
|
|
||||||
if (func.hasExternalLinkage()) {
|
|
||||||
hasExternal = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasExternal) {
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
// [triton root dir]/python/triton/language/libdevice.10.bc
|
|
||||||
static const fs::path libdevice = fs::path(__FILE__)
|
|
||||||
.parent_path()
|
|
||||||
.parent_path()
|
|
||||||
.parent_path()
|
|
||||||
.parent_path() /
|
|
||||||
"python" / "triton" / "language" /
|
|
||||||
"libdevice.10.bc";
|
|
||||||
if (mlir::triton::linkExternLib(module, libdevice.string()))
|
|
||||||
llvm::errs() << "link failed for: " << libdevice.string();
|
|
||||||
|
|
||||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
|
||||||
// this will enable fast math path in libdevice
|
|
||||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
|
||||||
// sqrt.approx.ftz.f32
|
|
||||||
auto &ctx = module.getContext();
|
|
||||||
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
|
||||||
llvm::Metadata *mdFour =
|
|
||||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
|
||||||
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
|
||||||
llvm::Metadata *mdOne =
|
|
||||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
|
||||||
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
|
||||||
module.addModuleFlag(reflect);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||||
linkExternal(module);
|
// LLVM version in use may not officially support target hardware.
|
||||||
|
// Supported versions for LLVM 14 are here:
|
||||||
// LLVM version in use may not officially support target hardware
|
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
|
||||||
int maxNNVMCC = 75;
|
int maxPTX = std::min(75, version);
|
||||||
|
int maxCC = std::min(86, cc);
|
||||||
// options
|
// options
|
||||||
auto options = llvm::cl::getRegisteredOptions();
|
auto options = llvm::cl::getRegisteredOptions();
|
||||||
auto *shortPtr =
|
auto *shortPtr =
|
||||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||||
assert(shortPtr);
|
assert(shortPtr);
|
||||||
shortPtr->setValue(true);
|
shortPtr->setValue(true);
|
||||||
// compute capability
|
std::string sm = "sm_" + std::to_string(maxCC);
|
||||||
std::string sm = "sm_" + std::to_string(cc);
|
|
||||||
// max PTX version
|
// max PTX version
|
||||||
int ptxMajor = version / 10;
|
int ptxMajor = maxPTX / 10;
|
||||||
int ptxMinor = version % 10;
|
int ptxMinor = maxPTX % 10;
|
||||||
// create
|
// create
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
std::string triple = "nvptx64-nvidia-cuda";
|
std::string triple = "nvptx64-nvidia-cuda";
|
||||||
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
|
std::string proc = "sm_" + std::to_string(maxCC);
|
||||||
std::string layout = "";
|
std::string layout = "";
|
||||||
std::string features = "";
|
std::string features = "";
|
||||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
// std::string features = "+ptx" + std::to_string(maxPTX);
|
||||||
// max_nvvm_ptx));
|
|
||||||
initLLVM();
|
initLLVM();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
|
@@ -15,5 +15,5 @@ def kernel(X, stride_xm,
|
|||||||
tl.store(Zs, tl.load(Xs))
|
tl.store(Zs, tl.load(Xs))
|
||||||
|
|
||||||
|
|
||||||
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
|
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
|
||||||
print(ret)
|
print(ret)
|
||||||
|
@@ -173,7 +173,7 @@ setup(
|
|||||||
author_email="phil@openai.com",
|
author_email="phil@openai.com",
|
||||||
description="A language and compiler for custom Deep Learning operations",
|
description="A language and compiler for custom Deep Learning operations",
|
||||||
long_description="",
|
long_description="",
|
||||||
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
|
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"cmake",
|
"cmake",
|
||||||
"filelock",
|
"filelock",
|
||||||
|
@@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'):
|
|||||||
# ---------------
|
# ---------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
|
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
|
||||||
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||||
dtype = getattr(torch, dtype_str)
|
dtype = getattr(torch, dtype_str)
|
||||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||||
@@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
|||||||
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
|
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
|
||||||
in_offsets = tl.arange(0, out_size)
|
in_offsets = tl.arange(0, out_size)
|
||||||
# Load inputs.
|
# Load inputs.
|
||||||
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
|
x = GENERATE_TEST_HERE
|
||||||
# Store output
|
# Store output
|
||||||
output_offsets = tl.arange(0, out_size)
|
output_offsets = tl.arange(0, out_size)
|
||||||
tl.store(out_ptr + output_offsets, x)
|
tl.store(out_ptr + output_offsets, x)
|
||||||
|
|
||||||
_kernel[(1,)](input, output, input_size, output_size)
|
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
|
||||||
|
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
|
||||||
|
kernel[(1,)](input, output, input_size, output_size)
|
||||||
|
|
||||||
reference_out = input
|
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
|
||||||
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
|
|
||||||
triton.testing.allclose(output, reference_out)
|
triton.testing.allclose(output, reference_out)
|
||||||
|
|
||||||
# 'bfloat16': torch.bfloat16,
|
|
||||||
# Testing masked loads with an intermediate copy to shared memory run.
|
# Testing masked loads with an intermediate copy to shared memory run.
|
||||||
|
|
||||||
|
|
||||||
|
@@ -734,10 +734,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
assert len(node.values) == 2
|
assert len(node.values) == 2
|
||||||
lhs = self.visit(node.values[0])
|
lhs = self.visit(node.values[0])
|
||||||
rhs = self.visit(node.values[1])
|
rhs = self.visit(node.values[1])
|
||||||
if isinstance(lhs, triton.language.constexpr):
|
|
||||||
lhs = lhs.value
|
|
||||||
if isinstance(rhs, triton.language.constexpr):
|
|
||||||
rhs = rhs.value
|
|
||||||
|
|
||||||
fn = {
|
fn = {
|
||||||
ast.And: 'logical_and',
|
ast.And: 'logical_and',
|
||||||
@@ -967,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
|
|||||||
'''
|
'''
|
||||||
assert isinstance(cuda_version, str)
|
assert isinstance(cuda_version, str)
|
||||||
major, minor = map(int, cuda_version.split('.'))
|
major, minor = map(int, cuda_version.split('.'))
|
||||||
version = major * 1000 + minor * 10
|
if major == 12:
|
||||||
if version >= 11040:
|
return 80 + minor
|
||||||
return 74
|
if major == 11:
|
||||||
if version >= 11030:
|
return 70 + minor
|
||||||
return 73
|
if major == 10:
|
||||||
if version >= 11020:
|
return 63 + minor
|
||||||
return 72
|
|
||||||
if version >= 11010:
|
|
||||||
return 71
|
|
||||||
if version >= 11000:
|
|
||||||
return 70
|
|
||||||
if version >= 10020:
|
|
||||||
return 65
|
|
||||||
if version >= 10010:
|
|
||||||
return 64
|
|
||||||
if version >= 10000:
|
|
||||||
return 63
|
|
||||||
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||||
|
|
||||||
|
|
||||||
|
@@ -403,6 +403,18 @@ class constexpr:
|
|||||||
def __neg__(self):
|
def __neg__(self):
|
||||||
return constexpr(-self.value)
|
return constexpr(-self.value)
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
return constexpr(self.value & other.value)
|
||||||
|
|
||||||
|
def logical_and(self, other):
|
||||||
|
return constexpr(self.value and other.value)
|
||||||
|
|
||||||
|
def __or__(self, other):
|
||||||
|
return constexpr(self.value | other.value)
|
||||||
|
|
||||||
|
def logical_or(self, other):
|
||||||
|
return constexpr(self.value or other.value)
|
||||||
|
|
||||||
def __pos__(self):
|
def __pos__(self):
|
||||||
return constexpr(+self.value)
|
return constexpr(+self.value)
|
||||||
|
|
||||||
@@ -818,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
|
|||||||
'type cache_modifier: str, optional
|
'type cache_modifier: str, optional
|
||||||
"""
|
"""
|
||||||
# mask, other can be constexpr
|
# mask, other can be constexpr
|
||||||
if mask is not None:
|
if _constexpr_to_value(mask) is not None:
|
||||||
mask = _to_tensor(mask, _builder)
|
mask = _to_tensor(mask, _builder)
|
||||||
if other is not None:
|
if _constexpr_to_value(other) is not None:
|
||||||
other = _to_tensor(other, _builder)
|
other = _to_tensor(other, _builder)
|
||||||
cache_modifier = _constexpr_to_value(cache_modifier)
|
cache_modifier = _constexpr_to_value(cache_modifier)
|
||||||
eviction_policy = _constexpr_to_value(eviction_policy)
|
eviction_policy = _constexpr_to_value(eviction_policy)
|
||||||
@@ -844,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None):
|
|||||||
"""
|
"""
|
||||||
# value can be constexpr
|
# value can be constexpr
|
||||||
value = _to_tensor(value, _builder)
|
value = _to_tensor(value, _builder)
|
||||||
if mask is not None:
|
if _constexpr_to_value(mask) is not None:
|
||||||
mask = _to_tensor(mask, _builder)
|
mask = _to_tensor(mask, _builder)
|
||||||
return semantic.store(pointer, value, mask, _builder)
|
return semantic.store(pointer, value, mask, _builder)
|
||||||
|
|
||||||
|
@@ -1116,11 +1116,13 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
|||||||
|
|
||||||
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||||
x, y = binary_op_type_checking_impl(x, y, builder)
|
x, y = binary_op_type_checking_impl(x, y, builder)
|
||||||
|
# FIXME(Keren): not portable, should be fixed
|
||||||
from . import libdevice
|
from . import libdevice
|
||||||
return libdevice.mulhi(x, y, _builder=builder)
|
return libdevice.mulhi(x, y, _builder=builder)
|
||||||
|
|
||||||
|
|
||||||
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||||
|
# FIXME(Keren): not portable, should be fixed
|
||||||
from . import libdevice
|
from . import libdevice
|
||||||
return libdevice.floor(x, _builder=builder)
|
return libdevice.floor(x, _builder=builder)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user