[code generation] search space pruning

Philippe Tillet
2019-03-25 14:10:24 -07:00
parent deb7a1cc5c
commit 8d35c98920
14 changed files with 131 additions and 118 deletions

View File

@@ -6,9 +6,9 @@
const char* src =
R"(
const tunable int32 TM;
const tunable int32 TN;
const tunable int32 TK;
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8, 16};
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K, int32 bound){
@@ -26,20 +26,8 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
pa = pa + TK*M;
pb = pb + TK*K;
k = k - TK;
int1 checka[TM, TK] = k > bound;
int1 checkb[TN, TK] = k > bound;
@checka a = *pa;
@checkb b = *pb;
if(k > bound)
continue;
int1 checka0[TM] = rxa < M;
int1 checka1[TK] = rka < k;
int1 checkb0[TN] = ryb < N;
int1 checkb1[TK] = rkb < k;
checka = checka0[:, newaxis] && checka1[newaxis, :];
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
@@ -87,22 +75,17 @@ T min(std::vector<T> x)
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync)
double bench(OP const & op, SYNC const & sync, unsigned repeat = 20)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = 1;
tmr.start();
tmr.start();
for(unsigned i = 0; i < repeat; i++)
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
sync();
double time = tmr.get().count();
return time / repeat;
}
int main() {
@@ -111,16 +94,16 @@ int main() {
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 128, N = 128, K = 128;
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = 1;
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = 1;
hb[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
@@ -163,11 +146,10 @@ int main() {
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
// double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
// [&](){ stream->synchronize(); });
double ts = 1;
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); });
ts = ts * 1e-9;
double tflops = 2*M*N*K / ts * 1e-12;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
};
@@ -177,11 +159,12 @@ int main() {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 1, 8,
4, 1
8, 8,
4,
};
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
// jit.autotune(src, benchmark);
jit.autotune(src, benchmark);
jit.add_module(src, params);
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
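
The rewritten bench() above trades the old adaptive scheme (accumulate timed batches until 1 ms has elapsed, return the minimum) for a fixed repeat count with a single synchronization, which gives each autotuning candidate a predictable timing cost; with the explicit spaces TM, TN in {16, 32, 64} and TK in {8, 16}, the tile shape alone already contributes 3 * 3 * 2 = 18 candidates. A minimal self-contained sketch of the new scheme, using std::chrono in place of the repository's timer class:

#include <chrono>

// Sketch of the fixed-repeat benchmark: one warm-up call, `repeat`
// enqueues, a single synchronization, average latency in nanoseconds.
template<class OP, class SYNC>
double bench_sketch(OP const & op, SYNC const & sync, unsigned repeat = 20) {
  op();     // warm-up
  sync();
  auto start = std::chrono::steady_clock::now();
  for(unsigned i = 0; i < repeat; i++)
    op();
  sync();   // one synchronization amortized over all iterations
  auto end = std::chrono::steady_clock::now();
  double ns = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
  return ns / repeat;
}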

View File

@@ -94,10 +94,15 @@ abstract_declarator
direct_abstract_declarator
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
constant :
constant:
CONSTANT { $$ = new constant(atoi(yytext)); }
;
constant_list:
constant { $$ = new list<constant*>((constant*)$1); }
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
;
type_name
: declaration_specifiers { $$ = new type_name($1, nullptr); }
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
@@ -259,7 +264,7 @@ expression
/* Initialization */
initialization_expression
: assignment_expression { $$ = $1; }
| '{' constant '}' { $$ = $2; }
| '{' constant_list '}' { $$ = $2; }
;
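
This grammar change is what allows a tunable's initializer to be a brace-enclosed, comma-separated list of constants (the {16, 32, 64} syntax used in the matmul example above). For illustration, a hypothetical recursive-descent equivalent of the new constant_list production; the real parser is the yacc grammar above:

#include <sstream>
#include <stdexcept>
#include <vector>

// Parse '{' constant (',' constant)* '}' into a vector of integers.
std::vector<int> parse_constant_list(std::istringstream& in) {
  std::vector<int> values;
  char c;
  if(!(in >> c) || c != '{') throw std::runtime_error("expected '{'");
  int v;
  while(in >> v) {
    values.push_back(v);
    in >> c;                  // ',' or '}'
    if(c == '}') return values;
    if(c != ',') break;
  }
  throw std::runtime_error("malformed constant list");
}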

View File

@@ -38,18 +38,24 @@ class context;
class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
public:
using polymorphic_resource::polymorphic_resource;
virtual size_t max_threads_per_block() const = 0;
virtual size_t max_shared_memory() const = 0;
};
// Host device
class host_device: public device {
public:
host_device(): device(host_device_t(), true){ }
size_t max_threads_per_block() const { return 1; }
size_t max_shared_memory() const { return 0; }
};
// OpenCL device
class ocl_device: public device {
public:
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
};
// CUDA device
@@ -87,8 +93,6 @@ public:
std::string infos() const;
size_t address_bits() const;
std::vector<size_t> max_block_dim() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
size_t warp_size() const;
//Compute Capability
void interpret_as(std::pair<size_t, size_t> cc);
@@ -99,7 +103,8 @@ public:
//Clocks
size_t current_sm_clock() const;
size_t current_mem_clock() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
size_t max_sm_clock() const;
size_t max_mem_clock() const;
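
Hoisting max_threads_per_block() and max_shared_memory() into the device base class as pure virtuals is what makes the pruning backend-agnostic: the autotuner can reject a candidate with one check whether it targets CUDA, OpenCL, or the host. A sketch of that predicate, with the helper name fits_device invented here for illustration:

#include <cstddef>
#include "triton/driver/device.h"

// Hypothetical helper: a candidate configuration is viable only if its
// shared-memory footprint and thread count fit the target device.
bool fits_device(triton::driver::device* dev,
                 size_t shared_bytes, size_t num_threads) {
  return shared_bytes <= dev->max_shared_memory()
      && num_threads  <= dev->max_threads_per_block();
}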

View File

@@ -87,7 +87,7 @@ public:
static bool cudnninit();
static void release();
//OpenCL
// OpenCL
static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);

View File

@@ -105,20 +105,21 @@ public:
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
};
template<class CUType>
template<class T>
class handle{
public:
template<class, class> friend class handle_interface;
public:
//Constructors
handle(CUType cu = CUType(), bool take_ownership = true);
handle(T h, bool take_ownership = true);
handle();
~handle();
CUType& operator*() { return *h_; }
CUType const & operator*() const { return *h_; }
CUType* operator->() const { return h_.get(); }
T& operator*() { return *h_; }
T const & operator*() const { return *h_; }
T* operator->() const { return h_.get(); }
protected:
std::shared_ptr<CUType> h_;
std::shared_ptr<T> h_;
bool has_ownership_;
};
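
The rename from CUType to T reflects that this wrapper is no longer CUDA-specific: the same handle<T> now wraps CUmodule, cl_kernel, or host handles alike. A condensed, self-contained sketch of the ownership pattern, where destroy() stands in for the per-type _delete overloads defined in the implementation file:

#include <memory>

// Stand-in for the type-specific _delete overloads (cuModuleUnload,
// clReleaseProgram, ...); a real backend supplies one per handle type.
template<class T> void destroy(T&) { }

template<class T>
class handle {
public:
  handle(T h, bool take_ownership = true): h_(new T(h)), has_ownership_(take_ownership) { }
  handle(): has_ownership_(false) { }
  // The native handle is released only when the last *owning* copy dies.
  ~handle() { if(has_ownership_ && h_ && h_.use_count() == 1) destroy(*h_); }
  T& operator*() { return *h_; }
private:
  std::shared_ptr<T> h_;
  bool has_ownership_;
};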

View File

@@ -44,19 +44,19 @@ protected:
};
/* Metaparameter int */
class metaparameter: public constant_int{
metaparameter(type *ty, unsigned lo, unsigned hi);
class metaparameter: public constant_int {
private:
metaparameter(type *ty, const std::vector<unsigned>& space);
public:
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
static metaparameter *create(context &ctx, type *ty, const std::vector<unsigned>& space);
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
bool has_value() { return has_value_; }
unsigned get_lo() { return lo_; }
unsigned get_hi() { return hi_; }
const std::vector<unsigned>& get_space() { return space_; }
private:
unsigned lo_;
unsigned hi_;
std::vector<unsigned> space_;
bool has_value_;
};

View File

@@ -410,12 +410,16 @@ ir::value* initializer::codegen(ir::module * mod) const{
std::string name = decl_->id()->name();
ir::value *value = ir::undef_value::get(ty);
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
assert(expr_ == nullptr);
//TODO: implement ranges
value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
if(csts == nullptr)
throw std::runtime_error("must specify constant list for metaparameters");
std::vector<unsigned> values;
for(constant* cst: csts->values())
values.push_back(cst->value());
value = ir::metaparameter::create(mod->get_context(), ty, values);
mod->register_global(name, value);
}
if(expr_){
else if(expr_){
value = expr_->codegen(mod);
value = explicit_cast(mod->get_builder(), value, ty);
implicit_broadcast(mod, value, ty);

View File

@@ -144,11 +144,23 @@ void tune::run(ir::module &mod) {
// Layout parameters
while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
}
}
// Simplify metaparameters
std::set<ir::metaparameter*> fixed_io_nts;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(dynamic_cast<ir::load_inst*>(i) || dynamic_cast<ir::store_inst*>(i))
if(i->get_type()->is_tile_ty())
for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++)
fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d)));
for(ir::metaparameter* mp: fixed_io_nts)
mp->set_value(1);
}
void tune::init(ir::module &mod) {
@@ -234,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
if(num_threads % 32 != 0)
if(num_threads % 64 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");

View File

@@ -25,7 +25,7 @@
#include <sstream>
#include <cstring>
#include <memory>
#include "triton/driver/helpers/CL/infos.hpp"
#include "triton/driver/device.h"
#include "triton/driver/context.h"
@@ -40,6 +40,14 @@ namespace driver
// OpenCL //
/* ------------------------ */
// maximum amount of shared memory per block
size_t ocl_device::max_shared_memory() const {
return ocl::info<CL_DEVICE_LOCAL_MEM_SIZE>(*cl_);
}
size_t ocl_device::max_threads_per_block() const {
return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
}
/* ------------------------ */
// CUDA //
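
The new OpenCL overrides are thin wrappers over clGetDeviceInfo via the ocl::info helper. Roughly what they expand to, as a sketch rather than the repository's templated helper; note the commit uses CL_DEVICE_MAX_WORK_ITEM_SIZES[0] as its thread limit, where CL_DEVICE_MAX_WORK_GROUP_SIZE would be the closer analogue of CUDA's max threads per block:

#include <CL/cl.h>

size_t cl_max_shared_memory(cl_device_id dev) {
  cl_ulong bytes = 0;  // local memory is OpenCL's analogue of shared memory
  clGetDeviceInfo(dev, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(bytes), &bytes, NULL);
  return (size_t)bytes;
}

size_t cl_max_threads_per_block(cl_device_id dev) {
  size_t sizes[3] = {0, 0, 0};
  clGetDeviceInfo(dev, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(sizes), sizes, NULL);
  return sizes[0];
}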

View File

@@ -60,13 +60,16 @@ inline void _delete(cu_event_t x) { _delete(x.first); _delete(x.second); }
inline void _delete(CUPlatform){}
//Constructor
template<class CUType>
handle<CUType>::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership)
template<class T>
handle<T>::handle(T cu, bool take_ownership): h_(new T(cu)), has_ownership_(take_ownership)
{ }
template<class T>
handle<T>::handle(): has_ownership_(false){ }
template<class CUType>
handle<CUType>::~handle(){
template<class T>
handle<T>::~handle(){
if(has_ownership_ && h_ && h_.unique())
_delete(*h_);
}

View File

@@ -53,6 +53,10 @@
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include "llvm/Transforms/Utils/Cloning.h"
#include "lld/Common/Driver.h"
#include "lld/Common/Args.h"
#include "lld/Common/ErrorHandler.h"
#include "lld/Common/LLVM.h"
namespace triton
{
@@ -110,36 +114,17 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
// opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
// opt.UnsafeFPMath = false;
// opt.NoInfsFPMath = false;
// opt.NoNaNsFPMath = true;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// link
for (std::string& path: paths) {
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, module->getContext());
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
std::cerr << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(module->getTargetTriple());
mlib->setDataLayout(module->getDataLayout());
for (llvm::Function &f : mlib->functions()) {
f.addFnAttr(llvm::Attribute::AlwaysInline);
}
llvm::Linker::linkModules(*module, std::move(mlib));
}
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
@@ -187,12 +172,10 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
init_llvm();
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer);
std::ofstream output("tmp.o", std::ios::binary);
std::ofstream output("/tmp/tmp.o", std::ios::binary);
std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
system("ld.lld tmp.o -shared -o test.o");
std::ifstream input("test.o", std::ios::in | std::ios::binary );
system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
std::ifstream input("/tmp/tmp.o", std::ios::in | std::ios::binary );
std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
size_t sizes[] = {in_buffer.size()};
const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
@@ -208,7 +191,6 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
char log[2048];
dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
std::cout << log << std::endl;
std::cout << "T_T" << std::endl;
throw;
}
}

View File

@@ -111,7 +111,8 @@ void cl_stream::synchronize() {
}
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL));
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
}
void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
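
The one-line fix above corrects a genuine API mismatch: CUDA's launch interface takes a grid measured in blocks, while clEnqueueNDRangeKernel's global_work_size argument is the total number of work-items per dimension. The conversion, isolated:

#include <array>
#include <cstddef>

// CUDA-style (grid, block) -> OpenCL global work size: element-wise multiply.
std::array<size_t, 3> to_global_work_size(const std::array<size_t, 3>& grid,
                                          const std::array<size_t, 3>& block) {
  return {grid[0] * block[0], grid[1] * block[1], grid[2] * block[2]};
}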

View File

@@ -98,12 +98,22 @@ constant *constant_fp::get(context &ctx, double v){
}
// metaparameter
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
: constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ }
metaparameter::metaparameter(type *ty, const std::vector<unsigned> &space)
: constant_int(ty, 0), space_(space), has_value_(false){ }
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
context_impl *impl = ctx.p_impl.get();
metaparameter *result = new metaparameter(ty, lo, hi);
std::vector<unsigned> space;
for(unsigned i = lo; i <= hi; i *= 2)
space.push_back(i);
metaparameter *result = new metaparameter(ty, space);
impl->mp_constants_.push_back(result);
return result;
}
metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<unsigned> &space) {
context_impl *impl = ctx.p_impl.get();
metaparameter *result = new metaparameter(ty, space);
impl->mp_constants_.push_back(result);
return result;
}
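
After this change the (lo, hi) factory is kept only for backward compatibility: it expands the old implicit power-of-two range into an explicit space, so both creation paths feed the same vector-based representation. The expansion, isolated as a free function for clarity:

#include <vector>

// e.g. make_pow2_space(8, 64) yields {8, 16, 32, 64}.
std::vector<unsigned> make_pow2_space(unsigned lo, unsigned hi) {
  std::vector<unsigned> space;
  for(unsigned i = lo; i <= hi; i *= 2)
    space.push_back(i);
  return space;
}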

View File

@@ -5,6 +5,7 @@
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/driver/device.h"
#include "triton/driver/error.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
@@ -71,6 +72,7 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
passes.selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
launch_info_map.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
launch_info_map.num_threads = passes.tune.get_num_threads();
@@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps){
std::vector<unsigned> current;
for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2)
current.push_back(x);
ranges.push_back(current);
}
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// iterate over parameters
unsigned i;
double best = 0;
@@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
}
passes.tune.init(tt_module);
passes.init(tt_module);
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// return;
// if(passes.tune.get_num_threads() > device->max_threads_per_block())
// return;
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
driver::module* module = driver::module::create(driver_context_, &*ll_module);
driver::kernel* kernel = driver::kernel::create(module, "matmul");
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
launch_information info = launch_info_map_.at("matmul");
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
double perf = benchmark(kernel, info);
double perf;
perf = benchmark(kernel.get(), info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
});
@@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
passes.tune.check_constraints(errors);
if(errors.size())
throw std::runtime_error("invalid parameters");
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// throw std::runtime_error("invalid parameters");
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code
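
The autotune loop now draws candidate values directly from each metaparameter's get_space() and, per the re-enabled checks above, rejects configurations that exceed the device's shared memory or thread limits before compiling them. A sketch of the enumeration feeding those checks, written as a standalone depth-first walk over the Cartesian product of the spaces:

#include <functional>
#include <vector>

// Invoke fn once per point in the Cartesian product of `ranges`.
void for_each_config(const std::vector<std::vector<unsigned>>& ranges,
                     std::vector<unsigned>& current,
                     const std::function<void(const std::vector<unsigned>&)>& fn) {
  if(current.size() == ranges.size()) { fn(current); return; }
  for(unsigned v : ranges[current.size()]) {
    current.push_back(v);
    for_each_config(ranges, current, fn);
    current.pop_back();
  }
}

For the matmul example above, the tile-shape spaces alone enumerate 18 points; the device-limit checks then discard the ones that cannot possibly launch, which is the pruning this commit is named for.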