[code generation] search space pruning
@@ -6,9 +6,9 @@
 const char* src =
 R"(
-const tunable int32 TM;
-const tunable int32 TN;
-const tunable int32 TK;
+const tunable int32 TM = {16, 32, 64};
+const tunable int32 TN = {16, 32, 64};
+const tunable int32 TK = {8, 16};

 void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
             int32 M, int32 N, int32 K, int32 bound){
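With the candidate values now declared inline, the three tunables above span a 3 x 3 x 2 = 18-point tile-shape space before any pruning. A minimal standalone sketch of the cross product those declarations describe (names chosen for illustration only):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<unsigned> TM = {16, 32, 64}, TN = {16, 32, 64}, TK = {8, 16};
      for (unsigned tm : TM)
        for (unsigned tn : TN)
          for (unsigned tk : TK)
            std::printf("TM=%u TN=%u TK=%u\n", tm, tn, tk); // 18 candidate tile shapes
      return 0;
    }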
@@ -26,20 +26,8 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
     pa = pa + TK*M;
     pb = pb + TK*K;
     k = k - TK;
-    int1 checka[TM, TK] = k > bound;
-    int1 checkb[TN, TK] = k > bound;
-    @checka a = *pa;
-    @checkb b = *pb;
-    if(k > bound)
-      continue;
-    int1 checka0[TM] = rxa < M;
-    int1 checka1[TK] = rka < k;
-    int1 checkb0[TN] = ryb < N;
-    int1 checkb1[TK] = rkb < k;
-    checka = checka0[:, newaxis] && checka1[newaxis, :];
-    checkb = checkb0[:, newaxis] && checkb1[newaxis, :];
-    a = checka ? *pa : 0;
-    b = checkb ? *pb : 0;
+    a = *pa;
+    b = *pb;
   }
   int32 rxc[TM] = get_global_range[TM](0);
   int32 ryc[TN] = get_global_range[TN](1);
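The deleted epilogue relied on Triton-C's predicated statements: as we read the syntax, "@checka a = *pa;" loads only the elements whose int1 mask bit is set. A scalar C++ sketch of that per-element meaning (illustrative only, not the actual lowering):

    #include <cstddef>
    #include <vector>

    // Scalar reading of "@check a = *p;": load where the mask is set,
    // leave masked-off elements untouched.
    void masked_load(std::vector<float>& a, const std::vector<float>& p,
                     const std::vector<bool>& check) {
      for (std::size_t i = 0; i < a.size(); i++)
        if (check[i])
          a[i] = p[i];
    }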
@@ -87,22 +75,17 @@ T min(std::vector<T> x)


 template<class OP, class SYNC>
-double bench(OP const & op, SYNC const & sync)
+double bench(OP const & op, SYNC const & sync, unsigned repeat = 20)
 {
   timer tmr;
-  std::vector<size_t> times;
-  double total_time = 0;
   op();
   sync();
-  while(total_time*1e-9 < 1e-3){
-    float norm = 1;
-    tmr.start();
+  tmr.start();
+  for(unsigned i = 0; i < repeat; i++)
     op();
-    sync();
-    times.push_back(norm*tmr.get().count());
-    total_time+=times.back();
-  }
-  return min(times);
+  sync();
+  double time = tmr.get().count();
+  return time / repeat;
 }

 int main() {
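The reworked bench warms up once, times repeat back-to-back launches, and returns the mean rather than the minimum of per-iteration samples. A usage sketch mirroring the benchmark lambda below (stream, kernel, grid and nthreads as set up in main()):

    double ns = bench([&](){ stream->enqueue(kernel, grid, {nthreads, 1, 1}); },
                      [&](){ stream->synchronize(); }); // mean nanoseconds per launch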
@@ -111,16 +94,16 @@ int main() {
   triton::jit jit(context);

   // matrix multiplication parameters
-  int32_t M = 128, N = 128, K = 128;
+  int32_t M = 512, N = 512, K = 512;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
   std::vector<float> ha(M*K);
   std::vector<float> hb(K*N);
   srand(0);
   for(size_t i = 0; i < ha.size(); i++)
-    ha[i] = 1;
+    ha[i] = (float)rand()/RAND_MAX;
   for(size_t i = 0; i < hb.size(); i++)
-    hb[i] = 1;
+    hb[i] = (float)rand()/RAND_MAX;
   for(size_t i = 0; i < hc.size(); i++)
     hc[i] = 0;
   triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
@@ -163,11 +146,10 @@ int main() {
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
   stream->synchronize();
   // benchmark
-  // double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
-  //                   [&](){ stream->synchronize(); });
-  double ts = 1;
+  double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
+                    [&](){ stream->synchronize(); });
   ts = ts * 1e-9;
-  double tflops = 2*M*N*K / ts * 1e-12;
+  double tflops = 2.*M*N*K / ts * 1e-12;
   return tflops;
 };
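The "2." literal matters: 2*M*N*K is evaluated in 32-bit integer arithmetic and overflows once the operands reach the 1024-cube scale, while the double literal promotes the whole product first. A worked check:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int32_t M = 1024, N = 1024, K = 1024;
      // 2*M*N*K would be 2^31, one past INT32_MAX: signed overflow.
      double flops = 2. * M * N * K; // promoted to double: exactly 2147483648.0
      std::printf("%.0f\n", flops);
      return 0;
    }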
@@ -177,11 +159,12 @@ int main() {
     16, 2, 64,
     32, 2, 64,
     16, 8, 2, 2,
-    8, 1, 8,
-    4, 1
+    8, 8,
+    4,
   };
   // params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};

-  // jit.autotune(src, benchmark);
+  jit.autotune(src, benchmark);
   jit.add_module(src, params);
   triton::driver::kernel* kernel = jit.get_function("matmul");
   triton::jit::launch_information info = jit.get_launch_info("matmul");
@@ -94,10 +94,15 @@ abstract_declarator
 direct_abstract_declarator
     : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }

-constant :
+constant:
     CONSTANT { $$ = new constant(atoi(yytext)); }
     ;

+constant_list:
+    constant { $$ = new list<constant*>((constant*)$1); }
+    | constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
+    ;

 type_name
     : declaration_specifiers { $$ = new type_name($1, nullptr); }
     | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
@@ -259,7 +264,7 @@ expression
 /* Initialization */
 initialization_expression
     : assignment_expression { $$ = $1; }
-    | '{' constant '}' { $$ = $2; }
+    | '{' constant_list '}' { $$ = $2; }
     ;

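Together with the codegen change further down, the constant_list rule is what lets a tunable carry its entire search space in source form, as in the matmul kernel at the top of this commit:

    // now accepted: '{' constant_list '}' as an initializer
    const tunable int32 TK = {8, 16};
    const tunable int32 TM = {16, 32, 64};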
@@ -38,18 +38,24 @@ class context;
 class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
 public:
   using polymorphic_resource::polymorphic_resource;
+  virtual size_t max_threads_per_block() const = 0;
+  virtual size_t max_shared_memory() const = 0;
 };

 // Host device
 class host_device: public device {
 public:
   host_device(): device(host_device_t(), true){ }
+  size_t max_threads_per_block() const { return 1; }
+  size_t max_shared_memory() const { return 0; }
 };

 // OpenCL device
 class ocl_device: public device {
 public:
   ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
+  size_t max_threads_per_block() const;
+  size_t max_shared_memory() const;
 };

 // CUDA device
@@ -87,8 +93,6 @@ public:
   std::string infos() const;
   size_t address_bits() const;
   std::vector<size_t> max_block_dim() const;
-  size_t max_threads_per_block() const;
-  size_t max_shared_memory() const;
   size_t warp_size() const;
   //Compute Capability
   void interpret_as(std::pair<size_t, size_t> cc);

@@ -99,7 +103,8 @@ public:
   //Clocks
   size_t current_sm_clock() const;
   size_t current_mem_clock() const;
+  size_t max_threads_per_block() const;
+  size_t max_shared_memory() const;
   size_t max_sm_clock() const;
   size_t max_mem_clock() const;

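Hoisting max_threads_per_block() and max_shared_memory() into the abstract device class is what makes the pruning backend-agnostic; jit::autotune (below) can reject a parameter point on any device before touching LLVM:

    // from jit::autotune further down:
    driver::device* device = driver_context_->device();
    if(passes.allocation.get_allocated_size() > device->max_shared_memory())
      return; // tile configuration cannot fit in shared memory
    if(passes.tune.get_num_threads() > device->max_threads_per_block())
      return; // exceeds the hardware thread-block / work-group limit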
@@ -87,7 +87,7 @@ public:
   static bool cudnninit();
   static void release();

-  //OpenCL
+  // OpenCL
   static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
   static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
   static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);
@@ -105,20 +105,21 @@ public:
   bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
 };

-template<class CUType>
+template<class T>
 class handle{
 public:
   template<class, class> friend class handle_interface;
 public:
   //Constructors
-  handle(CUType cu = CUType(), bool take_ownership = true);
+  handle(T h, bool take_ownership = true);
+  handle();
   ~handle();
-  CUType& operator*() { return *h_; }
-  CUType const & operator*() const { return *h_; }
-  CUType* operator->() const { return h_.get(); }
+  T& operator*() { return *h_; }
+  T const & operator*() const { return *h_; }
+  T* operator->() const { return h_.get(); }

 protected:
-  std::shared_ptr<CUType> h_;
+  std::shared_ptr<T> h_;
   bool has_ownership_;
 };
@@ -44,19 +44,19 @@ protected:
 };

 /* Metaparameter int */
-class metaparameter: public constant_int{
-  metaparameter(type *ty, unsigned lo, unsigned hi);
+class metaparameter: public constant_int {
+private:
+  metaparameter(type *ty, const std::vector<unsigned>& space);

 public:
   static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
+  static metaparameter *create(context &ctx, type *ty, const std::vector<unsigned>& space);
   void set_value(uint64_t value) { has_value_ = true; value_ = value; }
   bool has_value() { return has_value_; }
-  unsigned get_lo() { return lo_; }
-  unsigned get_hi() { return hi_; }
+  const std::vector<unsigned>& get_space() { return space_; }

 private:
-  unsigned lo_;
-  unsigned hi_;
+  std::vector<unsigned> space_;
   bool has_value_;
 };
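Both factory overloads now funnel into an explicit value set. A usage sketch (ctx and ty as obtained in the codegen and tune passes below):

    // from a "const tunable int32 TK = {8, 16};" declaration:
    ir::metaparameter *tk  = ir::metaparameter::create(ctx, ty, {8u, 16u});
    // legacy lo/hi overload, kept for the layout pass: space becomes {4, 8, 16, 32}
    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);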
@@ -410,12 +410,16 @@ ir::value* initializer::codegen(ir::module * mod) const{
   std::string name = decl_->id()->name();
   ir::value *value = ir::undef_value::get(ty);
   if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
-    assert(expr_ == nullptr);
-    //TODO: implement ranges
-    value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
+    auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
+    if(csts == nullptr)
+      throw std::runtime_error("must specify constant list for metaparameters");
+    std::vector<unsigned> values;
+    for(constant* cst: csts->values())
+      values.push_back(cst->value());
+    value = ir::metaparameter::create(mod->get_context(), ty, values);
     mod->register_global(name, value);
   }
-  if(expr_){
+  else if(expr_){
     value = expr_->codegen(mod);
     value = explicit_cast(mod->get_builder(), value, ty);
     implicit_broadcast(mod, value, ty);
@@ -144,11 +144,23 @@ void tune::run(ir::module &mod) {
   // Layout parameters
   while(!nodes_.empty()){
     ir::type *ty = mod.get_builder().get_int32_ty();
-    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
+    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
     ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
     connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
   }
-}
+
+  // Simplify metaparameters
+  std::set<ir::metaparameter*> fixed_io_nts;
+  for(ir::function *fn: mod.get_function_list())
+  for(ir::basic_block *block: fn->blocks())
+  for(ir::instruction *i : block->get_inst_list())
+  if(dynamic_cast<ir::load_inst*>(i) || dynamic_cast<ir::store_inst*>(i))
+  if(i->get_type()->is_tile_ty())
+  for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++)
+    fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d)));
+  for(ir::metaparameter* mp: fixed_io_nts)
+    mp->set_value(1);
+}

 void tune::init(ir::module &mod) {
@@ -234,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors)
   int num_threads = 1;
   for(size_t k = 0; k < shapes.size(); k++)
     num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
-  if(num_threads % 32 != 0)
+  if(num_threads % 64 != 0)
     errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
   if(num_threads != num_threads_)
     errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
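A worked instance of the tightened constraint, with hypothetical layout parameters:

    // mts.d0 = 8, mts.d1 = 8  ->  num_threads = 64: 64 % 64 == 0, accepted
    // mts.d0 = 8, mts.d1 = 4  ->  num_threads = 32: 32 % 64 != 0, rejected
    // (the error string still says "multiple of 32"; the check now requires 64)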
@@ -25,7 +25,7 @@
 #include <sstream>
 #include <cstring>
 #include <memory>

+#include "triton/driver/helpers/CL/infos.hpp"
 #include "triton/driver/device.h"
 #include "triton/driver/context.h"
@@ -40,6 +40,14 @@ namespace driver
 // OpenCL //
 /* ------------------------ */

+// maximum amount of shared memory per block
+size_t ocl_device::max_shared_memory() const {
+  return ocl::info<CL_DEVICE_LOCAL_MEM_SIZE>(*cl_);
+}
+
+size_t ocl_device::max_threads_per_block() const {
+  return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
+}
+
 /* ------------------------ */
 // CUDA //
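One caveat: CL_DEVICE_MAX_WORK_ITEM_SIZES bounds each work-group dimension independently, while the total work-group size is capped by CL_DEVICE_MAX_WORK_GROUP_SIZE, so taking .at(0) is an optimistic proxy. A stricter variant, assuming the infos helper also wraps that query:

    size_t ocl_device::max_threads_per_block() const {
      // total work-items per work-group, all dimensions combined
      return ocl::info<CL_DEVICE_MAX_WORK_GROUP_SIZE>(*cl_);
    }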
@@ -60,13 +60,16 @@ inline void _delete(cu_event_t x) { _delete(x.first); _delete(x.second); }
 inline void _delete(CUPlatform){}

 //Constructor
-template<class CUType>
-handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
+template<class T>
+handle<T>::handle(T cu, bool take_ownership): h_(new T(cu)), has_ownership_(take_ownership)
 { }

-template<class CUType>
-handle<CUType>::~handle(){
+template<class T>
+handle<T>::handle(): has_ownership_(false){ }
+
+template<class T>
+handle<T>::~handle(){
   if(has_ownership_ && h_ && h_.unique())
     _delete(*h_);
 }
@@ -53,6 +53,10 @@
 #include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
 #include <llvm/ExecutionEngine/SectionMemoryManager.h>
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "lld/Common/Driver.h"
+#include "lld/Common/Args.h"
+#include "lld/Common/ErrorHandler.h"
+#include "lld/Common/LLVM.h"

 namespace triton
 {
@@ -110,36 +114,17 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
   std::string error;
   auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
   llvm::TargetOptions opt;
-  // opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
-  // opt.UnsafeFPMath = false;
-  // opt.NoInfsFPMath = false;
-  // opt.NoNaNsFPMath = true;
+  opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
+  opt.UnsafeFPMath = false;
+  opt.NoInfsFPMath = false;
+  opt.NoNaNsFPMath = true;
   llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
                                                              llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);

   // set data layout
   if(layout.empty())
     module->setDataLayout(machine->createDataLayout());
   else
     module->setDataLayout(layout);

-  // link
-  for (std::string& path: paths) {
-    llvm::SMDiagnostic err;
-    std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext());
-    if (mlib.get() == nullptr) {
-      std::string msg = err.getMessage();
-      std::cerr << "Fail to load bitcode file " << path << "\n"
-                << "line " << err.getLineNo() << ":" << msg;
-    }
-    mlib->setTargetTriple(module->getTargetTriple());
-    mlib->setDataLayout(module->getDataLayout());
-    for (llvm::Function &f : mlib->functions()) {
-      f.addFnAttr(llvm::Attribute::AlwaysInline);
-    }
-    llvm::Linker::linkModules(*module, std::move(mlib));
-  }
-
   // emit machine code
   for (llvm::Function &f : module->functions())
     f.addFnAttr(llvm::Attribute::AlwaysInline);
@@ -187,12 +172,10 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
   init_llvm();
   llvm::SmallVector<char, 0> buffer;
   module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
-  std::ofstream output("tmp.o", std::ios::binary);
+  std::ofstream output("/tmp/tmp.o", std::ios::binary);
   std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
-  system("ld.lld tmp.o -shared -o test.o");
-
-  std::ifstream input("test.o", std::ios::in | std::ios::binary );
+  system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
+  std::ifstream input("/tmp/tmp.o", std::ios::in | std::ios::binary );
   std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
   size_t sizes[] = {in_buffer.size()};
   const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
@@ -208,7 +191,6 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
     char log[2048];
     dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
     std::cout << log << std::endl;
-    std::cout << "T_T" << std::endl;
     throw;
   }
 }
@@ -111,7 +111,8 @@ void cl_stream::synchronize() {
 }

 void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
-  check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL));
+  std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
+  check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
 }

 void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
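This fixes a real unit mismatch: clEnqueueNDRangeKernel takes the global work size (total work-items per dimension), not a work-group count the way a CUDA grid does, so the grid must be scaled by the block size. A worked example:

    // grid = {16, 16, 1}, block = {64, 1, 1}
    // CUDA semantics:   16 x 16 x 1 blocks of 64 threads each
    // OpenCL semantics: global = {16*64, 16*1, 1*1} = {1024, 16, 1}, local = {64, 1, 1}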
@@ -98,12 +98,22 @@ constant *constant_fp::get(context &ctx, double v){
 }

 // metaparameter
-metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
-  : constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ }
+metaparameter::metaparameter(type *ty, const std::vector<unsigned> &space)
+  : constant_int(ty, 0), space_(space), has_value_(false){ }

 metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
   context_impl *impl = ctx.p_impl.get();
-  metaparameter *result = new metaparameter(ty, lo, hi);
+  std::vector<unsigned> space;
+  for(unsigned i = lo; i <= hi; i *= 2)
+    space.push_back(i);
+  metaparameter *result = new metaparameter(ty, space);
   impl->mp_constants_.push_back(result);
   return result;
 }

+metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<unsigned> &space) {
+  context_impl *impl = ctx.p_impl.get();
+  metaparameter *result = new metaparameter(ty, space);
+  impl->mp_constants_.push_back(result);
+  return result;
+}
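The lo/hi overload now materializes its power-of-two ladder up front (note the loop assumes lo >= 1). Worked expansions of the calls appearing in this commit:

    // lo = 4, hi = 32  ->  {4, 8, 16, 32}   (mts in tune::run)
    // lo = 2, hi = 4   ->  {2, 4}           (nts in tune::run)
    // lo = 8, hi = 64  ->  {8, 16, 32, 64}  (the old hard-coded tunable range)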
lib/jit.cpp
@@ -5,6 +5,7 @@
 #include "triton/ir/context.h"
 #include "triton/ir/context_impl.h"
+#include "triton/driver/device.h"
 #include "triton/driver/error.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/LLVMContext.h"
@@ -71,6 +72,7 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
   passes.selection.run(module, *result);
   // launch information
   auto &launch_info_map = launch_info_map_[result->getName()];
+  launch_info_map.global_range_size.clear();
   for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
     launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
   launch_info_map.num_threads = passes.tune.get_num_threads();
@@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
   auto mps = passes.tune.get_params(tt_module);
   // create parameter ranges
   std::vector<std::vector<unsigned>> ranges;
-  for(ir::metaparameter *mp: mps){
-    std::vector<unsigned> current;
-    for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2)
-      current.push_back(x);
-    ranges.push_back(current);
-  }
+  for(ir::metaparameter *mp: mps)
+    ranges.push_back(mp->get_space());
   // iterate over parameters
   unsigned i;
   double best = 0;
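The tuner then walks the cartesian product of ranges; the iteration helper itself is outside this diff, but a self-contained sketch of that kind of enumeration looks like this (hypothetical helper, not the actual implementation):

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Visit every point of the cartesian product of `ranges`.
    void for_each_point(const std::vector<std::vector<unsigned>>& ranges,
                        std::vector<unsigned>& point, std::size_t d,
                        const std::function<void(const std::vector<unsigned>&)>& fn) {
      if (d == ranges.size()) { fn(point); return; }
      for (unsigned x : ranges[d]) {
        point.push_back(x);
        for_each_point(ranges, point, d + 1, fn);
        point.pop_back();
      }
    }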
@@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
   }
   passes.tune.init(tt_module);
   passes.init(tt_module);
-  // driver::device* device = driver_context_->device();
-  // if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-  //   return;
-  // if(passes.tune.get_num_threads() > device->max_threads_per_block())
-  //   return;
+  driver::device* device = driver_context_->device();
+  if(passes.allocation.get_allocated_size() > device->max_shared_memory())
+    return;
+  if(passes.tune.get_num_threads() > device->max_threads_per_block())
+    return;
   // Compile
   auto ll_module = make_llvm_module(tt_module, passes);
-  driver::module* module = driver::module::create(driver_context_, &*ll_module);
-  driver::kernel* kernel = driver::kernel::create(module, "matmul");
+  std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
+  std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
   launch_information info = launch_info_map_.at("matmul");
   for(unsigned p: params)
     std::cout << p << " " << std::flush;
   // add globals
   for(auto x: tt_module.globals())
     global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-  double perf = benchmark(kernel, info);
+  double perf;
+  perf = benchmark(kernel.get(), info);
   best = std::max(perf, best);
   std::cout << perf << " [ " << best << " ] " << std::endl;
 });
@@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
   passes.tune.check_constraints(errors);
   if(errors.size())
     throw std::runtime_error("invalid parameters");
-  // driver::device* device = driver_context_->device();
-  // if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-  //   throw std::runtime_error("invalid parameters");
+  driver::device* device = driver_context_->device();
+  if(passes.allocation.get_allocated_size() > device->max_shared_memory())
+    throw std::runtime_error("invalid parameters");
   // triton module -> llvm module
   auto ll_module = make_llvm_module(tt_module, passes);
   // llvm module -> machine code