[CODEGEN] More work on the CPU backend
This commit is contained in:
committed by
Philippe Tillet
parent
64eaec016f
commit
840308ab5d
@@ -6,6 +6,7 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -113,7 +114,8 @@ struct scanline_layout: public data_layout {
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align);
|
||||
analysis::align* align,
|
||||
target* tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
@@ -172,7 +174,7 @@ private:
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
@@ -190,6 +192,7 @@ private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
target* tgt_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
|
@@ -19,6 +19,7 @@ public:
|
||||
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
|
||||
buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
|
||||
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
|
||||
uintptr_t addr_as_uintptr_t();
|
||||
static buffer* create(driver::context* ctx, size_t size);
|
||||
driver::context* context();
|
||||
size_t size();
|
||||
|
@@ -9,6 +9,15 @@
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "llvm/ExecutionEngine/JITSymbol.h"
|
||||
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Core.h"
|
||||
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "triton/tools/thread_pool.h"
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
@@ -42,13 +51,21 @@ struct host_context_t{
|
||||
};
|
||||
|
||||
struct host_stream_t{
|
||||
|
||||
std::shared_ptr<ThreadPool> pool;
|
||||
};
|
||||
|
||||
struct host_module_t{
|
||||
std::string error;
|
||||
llvm::ExecutionEngine* engine;
|
||||
std::map<std::string, llvm::Function*> functions;
|
||||
void(*fn)(char**, int32_t, int32_t, int32_t);
|
||||
llvm::orc::ExecutionSession* ES;
|
||||
llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer;
|
||||
llvm::orc::IRCompileLayer* CompileLayer;
|
||||
llvm::DataLayout* DL;
|
||||
llvm::orc::MangleAndInterner* Mangle;
|
||||
llvm::orc::ThreadSafeContext* Ctx;
|
||||
llvm::orc::JITDylib *MainJD;
|
||||
};
|
||||
|
||||
struct host_function_t{
|
||||
|
@@ -32,7 +32,7 @@ public:
|
||||
driver::context* context() const;
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -66,7 +66,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -80,7 +80,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
@@ -15,11 +15,65 @@
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
ThreadPool(size_t threads)
|
||||
: stop(false) {
|
||||
for(size_t i = 0;i < threads;++i)
|
||||
workers.emplace_back(
|
||||
[this] {
|
||||
for(;;){
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
~ThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
@@ -32,69 +86,5 @@ private:
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: stop(false)
|
||||
{
|
||||
for(size_t i = 0;i<threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -168,9 +168,9 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
|
||||
analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){
|
||||
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = num_warps * 32;
|
||||
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
|
||||
nts_.resize(shape_.size());
|
||||
mts_.resize(shape_.size());
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
@@ -324,8 +324,8 @@ shared_layout::shared_layout(const data_layout *arg,
|
||||
* ---- Layouts Inference Pass ---- *
|
||||
* -------------------------------- */
|
||||
|
||||
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps) { }
|
||||
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }
|
||||
|
||||
|
||||
void layouts::connect(ir::value *x, ir::value *y) {
|
||||
@@ -382,7 +382,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
||||
}
|
||||
else
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
|
||||
}
|
||||
|
||||
void layouts::run(ir::module &mod) {
|
||||
|
@@ -488,41 +488,47 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
|
||||
ptr = gep->getPointerOperand();
|
||||
}
|
||||
ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1));
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
arg_ty.push_back(ty->getScalarType());
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
|
||||
// asm string
|
||||
std::string asm_str;
|
||||
asm_str += "@$0 st.global";
|
||||
if(vector_size > 1)
|
||||
asm_str += ".v" + std::to_string(vector_size);
|
||||
asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
|
||||
if(vector_size > 1)
|
||||
asm_str += "{";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
if(v > 0)
|
||||
asm_str += ", ";
|
||||
asm_str += "$" + std::to_string(2 + v);
|
||||
if(tgt_->is_gpu()){
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
arg_ty.push_back(ty->getScalarType());
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
|
||||
// asm string
|
||||
std::string asm_str;
|
||||
asm_str += "@$0 st.global";
|
||||
if(vector_size > 1)
|
||||
asm_str += ".v" + std::to_string(vector_size);
|
||||
asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
|
||||
if(vector_size > 1)
|
||||
asm_str += "{";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
if(v > 0)
|
||||
asm_str += ", ";
|
||||
asm_str += "$" + std::to_string(2 + v);
|
||||
}
|
||||
if(vector_size > 1)
|
||||
asm_str += "}";
|
||||
asm_str += ";";
|
||||
// asm constraint
|
||||
std::string constraint = "b,l";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
constraint += ",";
|
||||
constraint += (nbits == 32 ? "r" : "h");
|
||||
}
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
|
||||
builder_->CreateCall(iasm, args);
|
||||
}
|
||||
if(vector_size > 1)
|
||||
asm_str += "}";
|
||||
asm_str += ";";
|
||||
// asm constraint
|
||||
std::string constraint = "b,l";
|
||||
for(int v = 0; v < vector_size; v++){
|
||||
constraint += ",";
|
||||
constraint += (nbits == 32 ? "r" : "h");
|
||||
else{
|
||||
builder_->CreateMaskedStore(elt, ptr, alignment, builder_->CreateVectorSplat(vector_size, pred));
|
||||
}
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(int v = 0; v < vector_size; v++)
|
||||
args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
|
||||
builder_->CreateCall(iasm, args);
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1302,17 +1308,22 @@ void generator::visit_function(ir::function* fn) {
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
for(ir::attribute attr: attr_pair.second)
|
||||
if(attr.is_llvm_attr())
|
||||
ret->addAttribute(id, llvm_attr(ctx, attr));
|
||||
if(attr.is_llvm_attr()){
|
||||
llvm::Attribute llattr = llvm_attr(ctx, attr);
|
||||
if(llattr.getKindAsEnum() != llvm::Attribute::None)
|
||||
ret->addAttribute(id, llvm_attr(ctx, attr));
|
||||
}
|
||||
}
|
||||
// set metadata
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
if(tgt_->is_gpu()){
|
||||
tgt_->set_kernel(*builder_, ctx, mod_, ret);
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(ret),
|
||||
MDString::get(ctx, "maxntidx"),
|
||||
ValueAsMetadata::get(builder_->getInt32(num_warps_*32))
|
||||
};
|
||||
mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
}
|
||||
// set arguments
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
|
||||
|
@@ -47,6 +47,12 @@ void backend::platforms::init() {
|
||||
if(dispatch::cuinit()){
|
||||
cache_.push_back(new cu_platform());
|
||||
}
|
||||
//if host should be added
|
||||
bool host_visible = true;
|
||||
if(host_visible){
|
||||
cache_.push_back(new host_platform());
|
||||
}
|
||||
|
||||
// //if OpenCL is here
|
||||
// if(dispatch::clinit()){
|
||||
// cl_uint num_platforms;
|
||||
@@ -56,11 +62,7 @@ void backend::platforms::init() {
|
||||
// for(cl_platform_id id: ids)
|
||||
// cache_.push_back(new cl_platform(id));
|
||||
// }
|
||||
// //if host is here
|
||||
// bool host_visible = true;
|
||||
// if(host_visible){
|
||||
// cache_.push_back(new host_platform());
|
||||
// }
|
||||
|
||||
if(cache_.empty())
|
||||
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
|
||||
}
|
||||
|
@@ -53,6 +53,14 @@ size_t buffer::size() {
|
||||
return size_;
|
||||
}
|
||||
|
||||
uintptr_t buffer::addr_as_uintptr_t() {
|
||||
switch(backend_){
|
||||
case CUDA: return *cu_;
|
||||
case Host: return (uintptr_t)hst_->data;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
buffer* buffer::create(driver::context* ctx, size_t size) {
|
||||
switch(ctx->backend()){
|
||||
|
@@ -135,12 +135,6 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
||||
|
||||
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
|
||||
init_llvm();
|
||||
// host info
|
||||
// std::string triple = llvm::sys::getDefaultTargetTriple();
|
||||
// std::string cpu = llvm::sys::getHostCPUName();
|
||||
// llvm::SmallVector<char, 0> buffer;
|
||||
// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
|
||||
|
||||
// create kernel wrapper
|
||||
llvm::LLVMContext &ctx = src->getContext();
|
||||
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
|
||||
@@ -148,37 +142,72 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
|
||||
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
|
||||
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
|
||||
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
|
||||
llvm::Function* fn = src->getFunction("matmul");
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
|
||||
llvm::Function* fn = &*src->getFunctionList().begin();
|
||||
llvm::FunctionType *fn_ty = fn->getFunctionType();
|
||||
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
|
||||
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
|
||||
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
|
||||
llvm::IRBuilder<> ir_builder(ctx);
|
||||
ir_builder.SetInsertPoint(entry);
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
|
||||
auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
|
||||
llvm::Value* base = main->arg_begin();
|
||||
llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
|
||||
|
||||
size_t offset = 0;
|
||||
for(unsigned i = 0; i < ptrs.size(); i++){
|
||||
llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
|
||||
fn_args[i] = ir_builder.CreateLoad(addr);
|
||||
ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
|
||||
size_t nbytes = get_size(fn_ty->getParamType(i));
|
||||
offset += nbytes;
|
||||
if(i < ptrs.size() - 1){
|
||||
size_t np1bytes = get_size(fn_ty->getParamType(i+1));
|
||||
offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
|
||||
}
|
||||
}
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
|
||||
|
||||
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
|
||||
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
|
||||
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
|
||||
ir_builder.CreateCall(fn, fn_args);
|
||||
ir_builder.CreateRetVoid();
|
||||
|
||||
// llvm::legacy::PassManager pm;
|
||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pm.add(llvm::createVerifierPass());
|
||||
// pm.run(*src);
|
||||
|
||||
// create execution engine
|
||||
// create execution engine
|
||||
for(llvm::Function& fn: src->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
|
||||
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
// auto DL = JTMB.getDefaultDataLayoutForTarget();
|
||||
// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
|
||||
// hst_->ES = new llvm::orc::ExecutionSession();
|
||||
// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
|
||||
// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
|
||||
// hst_->DL = new llvm::DataLayout(std::move(*DL));
|
||||
// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
|
||||
// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
|
||||
// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
|
||||
// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
|
||||
// hst_->DL->getGlobalPrefix())));
|
||||
// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
|
||||
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
|
||||
|
||||
|
||||
|
||||
llvm::EngineBuilder builder(std::move(src));
|
||||
builder.setErrorStr(&hst_->error);
|
||||
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
|
||||
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
|
||||
builder.setEngineKind(llvm::EngineKind::JIT);
|
||||
builder.setUseOrcMCJITReplacement(true);
|
||||
hst_->engine = builder.create();
|
||||
hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
|
||||
}
|
||||
|
||||
std::unique_ptr<buffer> host_module::symbol(const char *name) const {
|
||||
|
@@ -72,21 +72,20 @@ driver::context* stream::context() const {
|
||||
/* ------------------------ */
|
||||
|
||||
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
|
||||
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
}
|
||||
|
||||
void host_stream::synchronize() {
|
||||
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
}
|
||||
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
|
||||
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
|
||||
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
|
||||
ThreadPool pool(4);
|
||||
auto hst = kernel->module()->hst();
|
||||
for(size_t i = 0; i < grid[0]; i++)
|
||||
for(size_t j = 0; j < grid[1]; j++)
|
||||
for(size_t k = 0; k < grid[2]; k++)
|
||||
fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k));
|
||||
hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k));
|
||||
}
|
||||
|
||||
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||
@@ -112,7 +111,7 @@ void cl_stream::synchronize() {
|
||||
check(dispatch::clFinish(*cl_));
|
||||
}
|
||||
|
||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
|
||||
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
|
||||
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
|
||||
}
|
||||
@@ -149,11 +148,16 @@ void cu_stream::synchronize() {
|
||||
dispatch::cuStreamSynchronize(*cu_);
|
||||
}
|
||||
|
||||
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** extra) {
|
||||
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** args, size_t args_size) {
|
||||
cu_context::context_switcher ctx_switch(*ctx_);
|
||||
void *config[] = {
|
||||
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
if(event)
|
||||
dispatch::cuEventRecord(event->cu()->first, *cu_);
|
||||
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra);
|
||||
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
|
||||
if(event)
|
||||
dispatch::cuEventRecord(event->cu()->second, *cu_);
|
||||
}
|
||||
|
@@ -163,11 +163,6 @@ function::caller::caller(ir::function *ir,
|
||||
|
||||
|
||||
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
|
||||
void *config[] = {
|
||||
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
// set grid
|
||||
if(_grid.size() > 3)
|
||||
throw std::runtime_error("grid size must be no greater than 3");
|
||||
@@ -175,7 +170,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
||||
for(size_t i = 0; i < 3; i++)
|
||||
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
|
||||
// enqueue
|
||||
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, config);
|
||||
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, args, args_size);
|
||||
}
|
||||
|
||||
|
||||
@@ -203,7 +198,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::transform::disassociate disassociate;
|
||||
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps);
|
||||
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
@@ -220,15 +215,18 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
cts.run(module);
|
||||
if(target->is_gpu())
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
coalesce.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
reassociate.run(module);
|
||||
cts.run(module);
|
||||
if(target->is_gpu()){
|
||||
reassociate.run(module);
|
||||
cts.run(module);
|
||||
}
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
@@ -260,11 +258,11 @@ function::caller* function::make(driver::stream *stream, options_t opt) {
|
||||
auto ir = make_ir(parser);
|
||||
// triton-ir -> binary
|
||||
std::unique_ptr<driver::module> bin;
|
||||
try{
|
||||
// try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(const std::runtime_error&){
|
||||
return nullptr;
|
||||
}
|
||||
// }catch(const std::runtime_error&){
|
||||
// return nullptr;
|
||||
// }
|
||||
// create callable
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
caller* ret = new caller(tmp, std::move(bin), opt);
|
||||
|
@@ -74,7 +74,7 @@ class CMakeBuild(build_ext):
|
||||
'-DLLVM_CONFIG=' + find_llvm()]
|
||||
# configuration
|
||||
cfg = 'Debug' if self.debug else 'Release'
|
||||
cfg = 'Release'
|
||||
cfg = 'Debug'
|
||||
build_args = ['--config', cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
|
@@ -15,7 +15,7 @@ using namespace triton;
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
typedef std::pair<size_t, size_t> map_key_t;
|
||||
typedef std::pair<int, int> map_key_t;
|
||||
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
std::map<size_t, double> fp64scalar_map;
|
||||
|
@@ -8,22 +8,31 @@
|
||||
#include "torch/script.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x);
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
typedef std::pair<size_t, size_t> map_key_t;
|
||||
typedef std::pair<int, int> map_key_t;
|
||||
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
std::shared_ptr<drv::device> host_device;
|
||||
std::shared_ptr<drv::context> host_context;
|
||||
std::shared_ptr<drv::stream> host_stream;
|
||||
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
triton::driver::cu_stream stream(custream, false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
if(dev_id == -1){
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
}
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
}
|
||||
else{
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
triton::driver::cu_stream stream(custream, false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@@ -67,6 +67,7 @@ class kernel:
|
||||
for x in args:
|
||||
if isinstance(x, torch.Tensor):
|
||||
device = x.device.index
|
||||
device = -1 if device is None else device
|
||||
break
|
||||
# lazily register function for device
|
||||
if device not in self.registered:
|
||||
|
@@ -14,9 +14,9 @@
|
||||
|
||||
|
||||
struct dot_arg_t{
|
||||
CUdeviceptr a;
|
||||
CUdeviceptr b;
|
||||
CUdeviceptr c;
|
||||
uintptr_t a;
|
||||
uintptr_t b;
|
||||
uintptr_t c;
|
||||
float alpha;
|
||||
int M;
|
||||
int N;
|
||||
@@ -24,7 +24,7 @@ struct dot_arg_t{
|
||||
int lda;
|
||||
int ldb;
|
||||
int ldc;
|
||||
CUdeviceptr locks;
|
||||
uintptr_t locks;
|
||||
};
|
||||
|
||||
template<class T, bool AT, bool BT>
|
||||
@@ -98,7 +98,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
|
||||
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
|
||||
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
|
||||
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
|
||||
// ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
|
||||
|
||||
// macros
|
||||
rt::function::options_space_t opt;
|
||||
@@ -127,17 +127,17 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
opt.num_warps = {nwarp};
|
||||
}
|
||||
if(mode == BENCH) {
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"16"}});
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.defines.push_back({"TZ", {"1"}});
|
||||
opt.num_warps = {4};
|
||||
}
|
||||
|
||||
// kernels
|
||||
rt::function function(src::dot, opt);
|
||||
dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(),
|
||||
1, M, N, K, lda, ldb, ldc, *dlocks->cu()};
|
||||
dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(),
|
||||
1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()};
|
||||
|
||||
auto grid = [M, N](const rt::function::options_t& x) {
|
||||
return rt::grid_t{ceil(M, x.D<int>("TM")),
|
||||
|
Reference in New Issue
Block a user