[codegen] more work on hmma coalescing
This commit is contained in:
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
|||||||
double total_time = 0;
|
double total_time = 0;
|
||||||
op();
|
op();
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
while(total_time*1e-9 < 1e-3){
|
while(total_time*1e-9 < 1e-1){
|
||||||
float norm = 1;
|
float norm = 1;
|
||||||
// normalize clock if possible to reduce noise in auto-tuning
|
// normalize clock if possible to reduce noise in auto-tuning
|
||||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||||
|
@@ -487,6 +487,9 @@ void align::populate(ir::value *v) {
|
|||||||
populate_is_constant(v);
|
populate_is_constant(v);
|
||||||
populate_starting_multiple(v);
|
populate_starting_multiple(v);
|
||||||
populate_max_contiguous(v);
|
populate_max_contiguous(v);
|
||||||
|
// std::cout << v->get_name() << std::endl;
|
||||||
|
// if(max_contiguous_[v].size() == 2)
|
||||||
|
// std::cout << max_contiguous_[v][0] << " " << max_contiguous_[v][1] << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void align::run(ir::module &mod) {
|
void align::run(ir::module &mod) {
|
||||||
|
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
|
// std::cout << source << std::endl;
|
||||||
cu_context::context_switcher ctx_switch(*context);
|
cu_context::context_switcher ctx_switch(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -161,6 +161,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
|||||||
for(auto it: opt_space_.defines)
|
for(auto it: opt_space_.defines)
|
||||||
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
||||||
cpp.Process(tokens);
|
cpp.Process(tokens);
|
||||||
|
// tokens.Print(stdout);
|
||||||
// parse
|
// parse
|
||||||
Parser parser(tokens);
|
Parser parser(tokens);
|
||||||
parser.Parse();
|
parser.Parse();
|
||||||
@@ -215,16 +216,19 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
// run passes
|
// run passes
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
align.run(module);
|
align.run(module);
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
coalesce.run(module);
|
coalesce.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
align.run(module);
|
align.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
tiles.run(module);
|
tiles.run(module);
|
||||||
reassociate.run(module);
|
reassociate.run(module);
|
||||||
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
|
@@ -48,7 +48,7 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"128"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"128"}});
|
||||||
opt.defines.push_back({"TK", {"16"}});
|
opt.defines.push_back({"TK", {"16"}});
|
||||||
opt.num_warps = {2, 4, 8};
|
opt.num_warps = {4};
|
||||||
// create function
|
// create function
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
// benchmark available libraries
|
// benchmark available libraries
|
||||||
@@ -79,12 +79,9 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
||||||
{false, true},
|
|
||||||
{true, false},
|
|
||||||
{true, true}}){
|
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{x[0], x[1], 2048, 2048, 2048}
|
config_t{x[0], x[1], 4096, 4096, 4096}
|
||||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||||
// config_t{x[0], x[1], 32, 2048, 2048},
|
// config_t{x[0], x[1], 32, 2048, 2048},
|
||||||
// config_t{x[0], x[1], 64, 2048, 2048},
|
// config_t{x[0], x[1], 64, 2048, 2048},
|
||||||
|
Reference in New Issue
Block a user