[code generation] search space pruning
This commit is contained in:
@@ -6,9 +6,9 @@
|
|||||||
|
|
||||||
const char* src =
|
const char* src =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM;
|
const tunable int32 TM = {16, 32, 64};
|
||||||
const tunable int32 TN;
|
const tunable int32 TN = {16, 32, 64};
|
||||||
const tunable int32 TK;
|
const tunable int32 TK = {8, 16};
|
||||||
|
|
||||||
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||||
int32 M, int32 N, int32 K, int32 bound){
|
int32 M, int32 N, int32 K, int32 bound){
|
||||||
@@ -26,20 +26,8 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
|||||||
pa = pa + TK*M;
|
pa = pa + TK*M;
|
||||||
pb = pb + TK*K;
|
pb = pb + TK*K;
|
||||||
k = k - TK;
|
k = k - TK;
|
||||||
int1 checka[TM, TK] = k > bound;
|
a = *pa;
|
||||||
int1 checkb[TN, TK] = k > bound;
|
b = *pb;
|
||||||
@checka a = *pa;
|
|
||||||
@checkb b = *pb;
|
|
||||||
if(k > bound)
|
|
||||||
continue;
|
|
||||||
int1 checka0[TM] = rxa < M;
|
|
||||||
int1 checka1[TK] = rka < k;
|
|
||||||
int1 checkb0[TN] = ryb < N;
|
|
||||||
int1 checkb1[TK] = rkb < k;
|
|
||||||
checka = checka0[:, newaxis] && checka1[newaxis, :];
|
|
||||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];
|
|
||||||
a = checka ? *pa : 0;
|
|
||||||
b = checkb ? *pb : 0;
|
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = get_global_range[TM](0);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);
|
int32 ryc[TN] = get_global_range[TN](1);
|
||||||
@@ -87,22 +75,17 @@ T min(std::vector<T> x)
|
|||||||
|
|
||||||
|
|
||||||
template<class OP, class SYNC>
|
template<class OP, class SYNC>
|
||||||
double bench(OP const & op, SYNC const & sync)
|
double bench(OP const & op, SYNC const & sync, unsigned repeat = 20)
|
||||||
{
|
{
|
||||||
timer tmr;
|
timer tmr;
|
||||||
std::vector<size_t> times;
|
|
||||||
double total_time = 0;
|
|
||||||
op();
|
op();
|
||||||
sync();
|
sync();
|
||||||
while(total_time*1e-9 < 1e-3){
|
tmr.start();
|
||||||
float norm = 1;
|
for(unsigned i = 0; i < repeat; i++)
|
||||||
tmr.start();
|
|
||||||
op();
|
op();
|
||||||
sync();
|
sync();
|
||||||
times.push_back(norm*tmr.get().count());
|
double time = tmr.get().count();
|
||||||
total_time+=times.back();
|
return time / repeat;
|
||||||
}
|
|
||||||
return min(times);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
@@ -111,16 +94,16 @@ int main() {
|
|||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
|
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 128, N = 128, K = 128;
|
int32_t M = 512, N = 512, K = 512;
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
std::vector<float> hb(K*N);
|
std::vector<float> hb(K*N);
|
||||||
srand(0);
|
srand(0);
|
||||||
for(size_t i = 0; i < ha.size(); i++)
|
for(size_t i = 0; i < ha.size(); i++)
|
||||||
ha[i] = 1;
|
ha[i] = (float)rand()/RAND_MAX;
|
||||||
for(size_t i = 0; i < hb.size(); i++)
|
for(size_t i = 0; i < hb.size(); i++)
|
||||||
hb[i] = 1;
|
hb[i] = (float)rand()/RAND_MAX;
|
||||||
for(size_t i = 0; i < hc.size(); i++)
|
for(size_t i = 0; i < hc.size(); i++)
|
||||||
hc[i] = 0;
|
hc[i] = 0;
|
||||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||||
@@ -163,11 +146,10 @@ int main() {
|
|||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
// benchmark
|
// benchmark
|
||||||
// double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||||
// [&](){ stream->synchronize(); });
|
[&](){ stream->synchronize(); });
|
||||||
double ts = 1;
|
|
||||||
ts = ts * 1e-9;
|
ts = ts * 1e-9;
|
||||||
double tflops = 2*M*N*K / ts * 1e-12;
|
double tflops = 2.*M*N*K / ts * 1e-12;
|
||||||
return tflops;
|
return tflops;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -177,11 +159,12 @@ int main() {
|
|||||||
16, 2, 64,
|
16, 2, 64,
|
||||||
32, 2, 64,
|
32, 2, 64,
|
||||||
16, 8, 2, 2,
|
16, 8, 2, 2,
|
||||||
8, 1, 8,
|
8, 8,
|
||||||
4, 1
|
4,
|
||||||
};
|
};
|
||||||
|
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
||||||
|
|
||||||
// jit.autotune(src, benchmark);
|
jit.autotune(src, benchmark);
|
||||||
jit.add_module(src, params);
|
jit.add_module(src, params);
|
||||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
|
@@ -94,10 +94,15 @@ abstract_declarator
|
|||||||
direct_abstract_declarator
|
direct_abstract_declarator
|
||||||
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
||||||
|
|
||||||
constant :
|
constant:
|
||||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
constant_list:
|
||||||
|
constant { $$ = new list<constant*>((constant*)$1); }
|
||||||
|
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
|
||||||
|
;
|
||||||
|
|
||||||
type_name
|
type_name
|
||||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||||
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
||||||
@@ -259,7 +264,7 @@ expression
|
|||||||
/* Initialization */
|
/* Initialization */
|
||||||
initialization_expression
|
initialization_expression
|
||||||
: assignment_expression { $$ = $1; }
|
: assignment_expression { $$ = $1; }
|
||||||
| '{' constant '}' { $$ = $2; }
|
| '{' constant_list '}' { $$ = $2; }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|
||||||
|
@@ -38,18 +38,24 @@ class context;
|
|||||||
class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
|
class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
|
||||||
public:
|
public:
|
||||||
using polymorphic_resource::polymorphic_resource;
|
using polymorphic_resource::polymorphic_resource;
|
||||||
|
virtual size_t max_threads_per_block() const = 0;
|
||||||
|
virtual size_t max_shared_memory() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Host device
|
// Host device
|
||||||
class host_device: public device {
|
class host_device: public device {
|
||||||
public:
|
public:
|
||||||
host_device(): device(host_device_t(), true){ }
|
host_device(): device(host_device_t(), true){ }
|
||||||
|
size_t max_threads_per_block() const { return 1; }
|
||||||
|
size_t max_shared_memory() const { return 0; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// OpenCL device
|
// OpenCL device
|
||||||
class ocl_device: public device {
|
class ocl_device: public device {
|
||||||
public:
|
public:
|
||||||
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
|
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
|
||||||
|
size_t max_threads_per_block() const;
|
||||||
|
size_t max_shared_memory() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// CUDA device
|
// CUDA device
|
||||||
@@ -87,8 +93,6 @@ public:
|
|||||||
std::string infos() const;
|
std::string infos() const;
|
||||||
size_t address_bits() const;
|
size_t address_bits() const;
|
||||||
std::vector<size_t> max_block_dim() const;
|
std::vector<size_t> max_block_dim() const;
|
||||||
size_t max_threads_per_block() const;
|
|
||||||
size_t max_shared_memory() const;
|
|
||||||
size_t warp_size() const;
|
size_t warp_size() const;
|
||||||
//Compute Capability
|
//Compute Capability
|
||||||
void interpret_as(std::pair<size_t, size_t> cc);
|
void interpret_as(std::pair<size_t, size_t> cc);
|
||||||
@@ -99,7 +103,8 @@ public:
|
|||||||
//Clocks
|
//Clocks
|
||||||
size_t current_sm_clock() const;
|
size_t current_sm_clock() const;
|
||||||
size_t current_mem_clock() const;
|
size_t current_mem_clock() const;
|
||||||
|
size_t max_threads_per_block() const;
|
||||||
|
size_t max_shared_memory() const;
|
||||||
size_t max_sm_clock() const;
|
size_t max_sm_clock() const;
|
||||||
size_t max_mem_clock() const;
|
size_t max_mem_clock() const;
|
||||||
|
|
||||||
|
@@ -87,7 +87,7 @@ public:
|
|||||||
static bool cudnninit();
|
static bool cudnninit();
|
||||||
static void release();
|
static void release();
|
||||||
|
|
||||||
//OpenCL
|
// OpenCL
|
||||||
static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
|
static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
|
||||||
static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
|
static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
|
||||||
static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);
|
static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);
|
||||||
|
@@ -105,20 +105,21 @@ public:
|
|||||||
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
|
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class CUType>
|
template<class T>
|
||||||
class handle{
|
class handle{
|
||||||
public:
|
public:
|
||||||
template<class, class> friend class handle_interface;
|
template<class, class> friend class handle_interface;
|
||||||
public:
|
public:
|
||||||
//Constructors
|
//Constructors
|
||||||
handle(CUType cu = CUType(), bool take_ownership = true);
|
handle(T h, bool take_ownership = true);
|
||||||
|
handle();
|
||||||
~handle();
|
~handle();
|
||||||
CUType& operator*() { return *h_; }
|
T& operator*() { return *h_; }
|
||||||
CUType const & operator*() const { return *h_; }
|
T const & operator*() const { return *h_; }
|
||||||
CUType* operator->() const { return h_.get(); }
|
T* operator->() const { return h_.get(); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<CUType> h_;
|
std::shared_ptr<T> h_;
|
||||||
bool has_ownership_;
|
bool has_ownership_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -44,19 +44,19 @@ protected:
|
|||||||
};
|
};
|
||||||
|
|
||||||
/* Metaparameter int */
|
/* Metaparameter int */
|
||||||
class metaparameter: public constant_int{
|
class metaparameter: public constant_int {
|
||||||
metaparameter(type *ty, unsigned lo, unsigned hi);
|
private:
|
||||||
|
metaparameter(type *ty, const std::vector<unsigned>& space);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
||||||
|
static metaparameter *create(context &ctx, type *ty, const std::vector<unsigned>& space);
|
||||||
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
|
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
|
||||||
bool has_value() { return has_value_; }
|
bool has_value() { return has_value_; }
|
||||||
unsigned get_lo() { return lo_; }
|
const std::vector<unsigned>& get_space() { return space_; }
|
||||||
unsigned get_hi() { return hi_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned lo_;
|
std::vector<unsigned> space_;
|
||||||
unsigned hi_;
|
|
||||||
bool has_value_;
|
bool has_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -410,12 +410,16 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
std::string name = decl_->id()->name();
|
std::string name = decl_->id()->name();
|
||||||
ir::value *value = ir::undef_value::get(ty);
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||||
assert(expr_ == nullptr);
|
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
||||||
//TODO: implement ranges
|
if(csts == nullptr)
|
||||||
value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
|
throw std::runtime_error("must specify constant list for metaparameters");
|
||||||
|
std::vector<unsigned> values;
|
||||||
|
for(constant* cst: csts->values())
|
||||||
|
values.push_back(cst->value());
|
||||||
|
value = ir::metaparameter::create(mod->get_context(), ty, values);
|
||||||
mod->register_global(name, value);
|
mod->register_global(name, value);
|
||||||
}
|
}
|
||||||
if(expr_){
|
else if(expr_){
|
||||||
value = expr_->codegen(mod);
|
value = expr_->codegen(mod);
|
||||||
value = explicit_cast(mod->get_builder(), value, ty);
|
value = explicit_cast(mod->get_builder(), value, ty);
|
||||||
implicit_broadcast(mod, value, ty);
|
implicit_broadcast(mod, value, ty);
|
||||||
|
@@ -144,11 +144,23 @@ void tune::run(ir::module &mod) {
|
|||||||
// Layout parameters
|
// Layout parameters
|
||||||
while(!nodes_.empty()){
|
while(!nodes_.empty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
|
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
|
||||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simplify metaparameters
|
||||||
|
std::set<ir::metaparameter*> fixed_io_nts;
|
||||||
|
for(ir::function *fn: mod.get_function_list())
|
||||||
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
for(ir::instruction *i : block->get_inst_list())
|
||||||
|
if(dynamic_cast<ir::load_inst*>(i) || dynamic_cast<ir::store_inst*>(i))
|
||||||
|
if(i->get_type()->is_tile_ty())
|
||||||
|
for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++)
|
||||||
|
fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d)));
|
||||||
|
for(ir::metaparameter* mp: fixed_io_nts)
|
||||||
|
mp->set_value(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void tune::init(ir::module &mod) {
|
void tune::init(ir::module &mod) {
|
||||||
@@ -234,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
|||||||
int num_threads = 1;
|
int num_threads = 1;
|
||||||
for(size_t k = 0; k < shapes.size(); k++)
|
for(size_t k = 0; k < shapes.size(); k++)
|
||||||
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
||||||
if(num_threads % 32 != 0)
|
if(num_threads % 64 != 0)
|
||||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
|
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
|
||||||
if(num_threads != num_threads_)
|
if(num_threads != num_threads_)
|
||||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||||
|
@@ -25,7 +25,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include "triton/driver/helpers/CL/infos.hpp"
|
||||||
#include "triton/driver/device.h"
|
#include "triton/driver/device.h"
|
||||||
#include "triton/driver/context.h"
|
#include "triton/driver/context.h"
|
||||||
|
|
||||||
@@ -40,6 +40,14 @@ namespace driver
|
|||||||
// OpenCL //
|
// OpenCL //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
|
|
||||||
|
// maximum amount of shared memory per block
|
||||||
|
size_t ocl_device::max_shared_memory() const {
|
||||||
|
return ocl::info<CL_DEVICE_LOCAL_MEM_SIZE>(*cl_);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ocl_device::max_threads_per_block() const {
|
||||||
|
return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
|
||||||
|
}
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// CUDA //
|
// CUDA //
|
||||||
|
@@ -60,13 +60,16 @@ inline void _delete(cu_event_t x) { _delete(x.first); _delete(x.second); }
|
|||||||
inline void _delete(CUPlatform){}
|
inline void _delete(CUPlatform){}
|
||||||
|
|
||||||
//Constructor
|
//Constructor
|
||||||
template<class CUType>
|
template<class T>
|
||||||
handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
|
handle<T>::handle(T cu, bool take_ownership): h_(new T(cu)), has_ownership_(take_ownership)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
handle<T>::handle(): has_ownership_(false){ }
|
||||||
|
|
||||||
template<class CUType>
|
|
||||||
handle<CUType>::~handle(){
|
template<class T>
|
||||||
|
handle<T>::~handle(){
|
||||||
if(has_ownership_ && h_ && h_.unique())
|
if(has_ownership_ && h_ && h_.unique())
|
||||||
_delete(*h_);
|
_delete(*h_);
|
||||||
}
|
}
|
||||||
|
@@ -53,6 +53,10 @@
|
|||||||
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
|
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
|
||||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||||
#include "llvm/Transforms/Utils/Cloning.h"
|
#include "llvm/Transforms/Utils/Cloning.h"
|
||||||
|
#include "lld/Common/Driver.h"
|
||||||
|
#include "lld/Common/Args.h"
|
||||||
|
#include "lld/Common/ErrorHandler.h"
|
||||||
|
#include "lld/Common/LLVM.h"
|
||||||
|
|
||||||
namespace triton
|
namespace triton
|
||||||
{
|
{
|
||||||
@@ -110,36 +114,17 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
|||||||
std::string error;
|
std::string error;
|
||||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||||
llvm::TargetOptions opt;
|
llvm::TargetOptions opt;
|
||||||
// opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||||
// opt.UnsafeFPMath = false;
|
opt.UnsafeFPMath = false;
|
||||||
// opt.NoInfsFPMath = false;
|
opt.NoInfsFPMath = false;
|
||||||
// opt.NoNaNsFPMath = true;
|
opt.NoNaNsFPMath = true;
|
||||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
|
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
|
||||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||||
|
|
||||||
// set data layout
|
// set data layout
|
||||||
if(layout.empty())
|
if(layout.empty())
|
||||||
module->setDataLayout(machine->createDataLayout());
|
module->setDataLayout(machine->createDataLayout());
|
||||||
else
|
else
|
||||||
module->setDataLayout(layout);
|
module->setDataLayout(layout);
|
||||||
|
|
||||||
// link
|
|
||||||
for (std::string& path: paths) {
|
|
||||||
llvm::SMDiagnostic err;
|
|
||||||
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext());
|
|
||||||
if (mlib.get() == nullptr) {
|
|
||||||
std::string msg = err.getMessage();
|
|
||||||
std::cerr << "Fail to load bitcode file " << path << "\n"
|
|
||||||
<< "line " << err.getLineNo() << ":" << msg;
|
|
||||||
}
|
|
||||||
mlib->setTargetTriple(module->getTargetTriple());
|
|
||||||
mlib->setDataLayout(module->getDataLayout());
|
|
||||||
for (llvm::Function &f : mlib->functions()) {
|
|
||||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
|
||||||
}
|
|
||||||
llvm::Linker::linkModules(*module, std::move(mlib));
|
|
||||||
}
|
|
||||||
|
|
||||||
// emit machine code
|
// emit machine code
|
||||||
for (llvm::Function &f : module->functions())
|
for (llvm::Function &f : module->functions())
|
||||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||||
@@ -187,12 +172,10 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
|
|||||||
init_llvm();
|
init_llvm();
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
|
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
|
||||||
|
std::ofstream output("/tmp/tmp.o", std::ios::binary);
|
||||||
std::ofstream output("tmp.o", std::ios::binary);
|
|
||||||
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
|
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
|
||||||
system("ld.lld tmp.o -shared -o test.o");
|
system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
|
||||||
|
std::ifstream input("/tmp/tmp.o", std::ios::in | std::ios::binary );
|
||||||
std::ifstream input("test.o", std::ios::in | std::ios::binary );
|
|
||||||
std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
|
std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
|
||||||
size_t sizes[] = {in_buffer.size()};
|
size_t sizes[] = {in_buffer.size()};
|
||||||
const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
|
const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
|
||||||
@@ -208,7 +191,6 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
|
|||||||
char log[2048];
|
char log[2048];
|
||||||
dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
|
dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
|
||||||
std::cout << log << std::endl;
|
std::cout << log << std::endl;
|
||||||
std::cout << "T_T" << std::endl;
|
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -111,7 +111,8 @@ void cl_stream::synchronize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
|
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
|
||||||
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL));
|
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
|
||||||
|
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
|
||||||
}
|
}
|
||||||
|
|
||||||
void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||||
|
@@ -98,12 +98,22 @@ constant *constant_fp::get(context &ctx, double v){
|
|||||||
}
|
}
|
||||||
|
|
||||||
// metaparameter
|
// metaparameter
|
||||||
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
|
metaparameter::metaparameter(type *ty, const std::vector<unsigned> &space)
|
||||||
: constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ }
|
: constant_int(ty, 0), space_(space), has_value_(false){ }
|
||||||
|
|
||||||
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
||||||
context_impl *impl = ctx.p_impl.get();
|
context_impl *impl = ctx.p_impl.get();
|
||||||
metaparameter *result = new metaparameter(ty, lo, hi);
|
std::vector<unsigned> space;
|
||||||
|
for(unsigned i = lo; i <= hi; i *= 2)
|
||||||
|
space.push_back(i);
|
||||||
|
metaparameter *result = new metaparameter(ty, space);
|
||||||
|
impl->mp_constants_.push_back(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<unsigned> &space) {
|
||||||
|
context_impl *impl = ctx.p_impl.get();
|
||||||
|
metaparameter *result = new metaparameter(ty, space);
|
||||||
impl->mp_constants_.push_back(result);
|
impl->mp_constants_.push_back(result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
33
lib/jit.cpp
33
lib/jit.cpp
@@ -5,6 +5,7 @@
|
|||||||
#include "triton/ir/context.h"
|
#include "triton/ir/context.h"
|
||||||
#include "triton/ir/context_impl.h"
|
#include "triton/ir/context_impl.h"
|
||||||
#include "triton/driver/device.h"
|
#include "triton/driver/device.h"
|
||||||
|
#include "triton/driver/error.h"
|
||||||
#include "llvm/IR/IRPrintingPasses.h"
|
#include "llvm/IR/IRPrintingPasses.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LLVMContext.h"
|
#include "llvm/IR/LLVMContext.h"
|
||||||
@@ -71,6 +72,7 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
|||||||
passes.selection.run(module, *result);
|
passes.selection.run(module, *result);
|
||||||
// launch information
|
// launch information
|
||||||
auto &launch_info_map = launch_info_map_[result->getName()];
|
auto &launch_info_map = launch_info_map_[result->getName()];
|
||||||
|
launch_info_map.global_range_size.clear();
|
||||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||||
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||||
launch_info_map.num_threads = passes.tune.get_num_threads();
|
launch_info_map.num_threads = passes.tune.get_num_threads();
|
||||||
@@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
|||||||
auto mps = passes.tune.get_params(tt_module);
|
auto mps = passes.tune.get_params(tt_module);
|
||||||
// create parameter ranges
|
// create parameter ranges
|
||||||
std::vector<std::vector<unsigned>> ranges;
|
std::vector<std::vector<unsigned>> ranges;
|
||||||
for(ir::metaparameter *mp: mps){
|
for(ir::metaparameter *mp: mps)
|
||||||
std::vector<unsigned> current;
|
ranges.push_back(mp->get_space());
|
||||||
for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2)
|
|
||||||
current.push_back(x);
|
|
||||||
ranges.push_back(current);
|
|
||||||
}
|
|
||||||
// iterate over parameters
|
// iterate over parameters
|
||||||
unsigned i;
|
unsigned i;
|
||||||
double best = 0;
|
double best = 0;
|
||||||
@@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
|||||||
}
|
}
|
||||||
passes.tune.init(tt_module);
|
passes.tune.init(tt_module);
|
||||||
passes.init(tt_module);
|
passes.init(tt_module);
|
||||||
// driver::device* device = driver_context_->device();
|
driver::device* device = driver_context_->device();
|
||||||
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||||
// return;
|
return;
|
||||||
// if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||||
// return;
|
return;
|
||||||
// Compile
|
// Compile
|
||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
auto ll_module = make_llvm_module(tt_module, passes);
|
||||||
driver::module* module = driver::module::create(driver_context_, &*ll_module);
|
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||||
driver::kernel* kernel = driver::kernel::create(module, "matmul");
|
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
|
||||||
launch_information info = launch_info_map_.at("matmul");
|
launch_information info = launch_info_map_.at("matmul");
|
||||||
for(unsigned p: params)
|
for(unsigned p: params)
|
||||||
std::cout << p << " " << std::flush;
|
std::cout << p << " " << std::flush;
|
||||||
// add globals
|
// add globals
|
||||||
for(auto x: tt_module.globals())
|
for(auto x: tt_module.globals())
|
||||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||||
double perf = benchmark(kernel, info);
|
double perf;
|
||||||
|
perf = benchmark(kernel.get(), info);
|
||||||
best = std::max(perf, best);
|
best = std::max(perf, best);
|
||||||
std::cout << perf << " [ " << best << " ] " << std::endl;
|
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||||
});
|
});
|
||||||
@@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
|||||||
passes.tune.check_constraints(errors);
|
passes.tune.check_constraints(errors);
|
||||||
if(errors.size())
|
if(errors.size())
|
||||||
throw std::runtime_error("invalid parameters");
|
throw std::runtime_error("invalid parameters");
|
||||||
// driver::device* device = driver_context_->device();
|
driver::device* device = driver_context_->device();
|
||||||
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||||
// throw std::runtime_error("invalid parameters");
|
throw std::runtime_error("invalid parameters");
|
||||||
// triton module -> llvm module
|
// triton module -> llvm module
|
||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
auto ll_module = make_llvm_module(tt_module, passes);
|
||||||
// llvm module -> machine code
|
// llvm module -> machine code
|
||||||
|
Reference in New Issue
Block a user