[codegen] more work on hmma coalescing
This commit is contained in:
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
double total_time = 0;
|
||||
op();
|
||||
stream->synchronize();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
while(total_time*1e-9 < 1e-1){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
|
@@ -487,6 +487,9 @@ void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
// std::cout << v->get_name() << std::endl;
|
||||
// if(max_contiguous_[v].size() == 2)
|
||||
// std::cout << max_contiguous_[v][0] << " " << max_contiguous_[v][1] << std::endl;
|
||||
}
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
|
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -161,6 +161,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
for(auto it: opt_space_.defines)
|
||||
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
||||
cpp.Process(tokens);
|
||||
// tokens.Print(stdout);
|
||||
// parse
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
@@ -215,16 +216,19 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
coalesce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
tiles.run(module);
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
liveness.run(module);
|
||||
|
@@ -48,7 +48,7 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"16"}});
|
||||
opt.num_warps = {2, 4, 8};
|
||||
opt.num_warps = {4};
|
||||
// create function
|
||||
rt::function function(src::dot, opt);
|
||||
// benchmark available libraries
|
||||
@@ -79,12 +79,9 @@ int main() {
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
||||
{false, true},
|
||||
{true, false},
|
||||
{true, true}}){
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{x[0], x[1], 2048, 2048, 2048}
|
||||
config_t{x[0], x[1], 4096, 4096, 4096}
|
||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||
// config_t{x[0], x[1], 32, 2048, 2048},
|
||||
// config_t{x[0], x[1], 64, 2048, 2048},
|
||||
|
Reference in New Issue
Block a user