Files
triton/include/triton/runtime/jit.h
2019-08-16 15:56:58 -07:00

137 lines
4.2 KiB
C++

#ifndef TDL_INCLUDE_JIT_H
#define TDL_INCLUDE_JIT_H
#include <string>
#include <memory>
#include "llvm/IR/LLVMContext.h"
#include "triton/ir/context.h"
#include "triton/ir/print.h"
#include "triton/driver/module.h"
#include "triton/driver/kernel.h"
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/runtime/launch_info.h"
#include <functional>
namespace llvm {
class Module;
}
namespace triton {
namespace lang{
class translation_unit;
}
namespace codegen{
namespace analysis{
class tune;
}
}
namespace ir {
class module;
class context;
class metaparameter;
}
namespace runtime{
class jit {
public:
typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
struct tune_res_t{
double perf;
std::vector<unsigned> params;
};
struct passes_wrapper {
passes_wrapper(codegen::target* target)
: tune(0),
shmem_liveness(&shmem_info),
shmem_allocation(&shmem_liveness, &shmem_info, &tune),
shmem_barriers(&shmem_allocation, &shmem_info),
vectorize(&tune),
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
dce(),
peephole(),
alignment_info(),
reassociate(&tune),
target_(target) { }
void target_independent(ir::module &module) {
peephole.run(module);
dce.run(module);
}
void target_dependent(ir::module &module) {
reassociate.run(module);
peephole.run(module);
if(target_->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
shmem_barriers.run(module);
}
alignment_info.run(module);
vectorize.run(module);
dce.run(module);
}
codegen::selection selection;
codegen::analysis::tune tune;
codegen::analysis::shmem::info shmem_info;
codegen::analysis::shmem::liveness shmem_liveness;
codegen::analysis::shmem::allocation shmem_allocation;
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize;
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate;
codegen::target* target_;
};
private:
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
triton::lang::translation_unit *parse_program(const char *name, const char *src);
public:
jit(driver::context* context, unsigned nthreads = 4);
~jit();
std::vector<unsigned> get_valid(const char *name, const char *src);
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned> > &targets = {});
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
driver::kernel* get_function(const char* name);
launch_information get_launch_info(const char* name);
private:
std::map<std::string, driver::module*> modules_;
driver::context* driver_context_;
llvm::LLVMContext llvm_context_;
ir::context triton_context_;
std::map<std::string, launch_information> launch_info_map_;
std::shared_ptr<triton::codegen::target> target_;
unsigned nthreads_;
};
}
}
#endif