[CODEGEN] More work on the CPU backend
This commit is contained in:
committed by
Philippe Tillet
parent
64eaec016f
commit
840308ab5d
@@ -6,6 +6,7 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -113,7 +114,8 @@ struct scanline_layout: public data_layout {
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align);
|
||||
analysis::align* align,
|
||||
target* tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
@@ -172,7 +174,7 @@ private:
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
@@ -190,6 +192,7 @@ private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
target* tgt_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
|
@@ -19,6 +19,7 @@ public:
|
||||
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
|
||||
buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
|
||||
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
|
||||
uintptr_t addr_as_uintptr_t();
|
||||
static buffer* create(driver::context* ctx, size_t size);
|
||||
driver::context* context();
|
||||
size_t size();
|
||||
|
@@ -9,6 +9,15 @@
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "llvm/ExecutionEngine/JITSymbol.h"
|
||||
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Core.h"
|
||||
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "triton/tools/thread_pool.h"
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
@@ -42,13 +51,21 @@ struct host_context_t{
|
||||
};
|
||||
|
||||
struct host_stream_t{
|
||||
|
||||
std::shared_ptr<ThreadPool> pool;
|
||||
};
|
||||
|
||||
struct host_module_t{
|
||||
std::string error;
|
||||
llvm::ExecutionEngine* engine;
|
||||
std::map<std::string, llvm::Function*> functions;
|
||||
void(*fn)(char**, int32_t, int32_t, int32_t);
|
||||
llvm::orc::ExecutionSession* ES;
|
||||
llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer;
|
||||
llvm::orc::IRCompileLayer* CompileLayer;
|
||||
llvm::DataLayout* DL;
|
||||
llvm::orc::MangleAndInterner* Mangle;
|
||||
llvm::orc::ThreadSafeContext* Ctx;
|
||||
llvm::orc::JITDylib *MainJD;
|
||||
};
|
||||
|
||||
struct host_function_t{
|
||||
|
@@ -32,7 +32,7 @@ public:
|
||||
driver::context* context() const;
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -66,7 +66,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -80,7 +80,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
@@ -15,11 +15,65 @@
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
ThreadPool(size_t threads)
|
||||
: stop(false) {
|
||||
for(size_t i = 0;i < threads;++i)
|
||||
workers.emplace_back(
|
||||
[this] {
|
||||
for(;;){
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
~ThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
@@ -32,69 +86,5 @@ private:
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: stop(false)
|
||||
{
|
||||
for(size_t i = 0;i<threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user