[GENERAL] Removed deprecated driver files and added basic compatibility with rocm (#268)

- Removed driver module -- accelerator runtime is handled by pytorch
- Added basic support for ROCM based on @micmelesse 's PR -- now can execute empty kernel on AMD devices without any compile-time changes
- Now only using PREFER_SHARED for kernels when the size of shared memory is greater than 49k. Otherwise there can be poor L1 performance for broadcast tensors
This commit is contained in:
Philippe Tillet
2021-09-09 00:04:28 -07:00
committed by GitHub
parent 8bedcce9be
commit 94c83d30ce
47 changed files with 1376 additions and 30232 deletions

View File

@@ -13,45 +13,40 @@
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace triton {
namespace codegen {
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) {
// generate llvm code
llvm::LLVMContext ctx;
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
std::unique_ptr<codegen::target> target = dev->make_target();
bool cts_use_async = target->as_nvidia()->sm() >= 80;
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
// codegen::transform::reassociate reassociate;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target.get());
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache);
// run passes
dce.run(ir);
peephole.run(ir);
@@ -72,15 +67,12 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
// exit(1);
align.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
dce.run(ir);
align.run(ir);
// ir::print(ir, std::cout);
axes.run(ir);
layouts.run(ir);
peephole.run(ir);
@@ -93,11 +85,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
// ir.print(std::cout);
isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str());
shared_mem = allocation.allocated_size();
shared_static = allocation.allocated_size();
return llvm;
}
} // namespace codegen