From 01276b515366f624ebc13ec74b81c7e7f5cc07eb Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 20 Jul 2021 17:58:06 -0400
Subject: [PATCH] [FRONTEND] Added compilation flag to force use of `.nc` cache
 modifier (#134) in DRAM loads.

/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\
---
 include/triton/codegen/pass.h                |  2 +-
 include/triton/codegen/selection/generator.h |  7 +++++--
 lib/codegen/pass.cc                          |  4 ++--
 lib/codegen/selection/generator.cc           |  9 ++++++---
 python/src/triton.cc                         |  4 ++--
 python/triton/code_gen.py                    | 15 +++++++++------
 python/triton/ops/blocksparse/softmax.py     |  4 ++--
 7 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h
index 94ba25a0f..ed37e0868 100644
--- a/include/triton/codegen/pass.h
+++ b/include/triton/codegen/pass.h
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 
diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 8850b9f80..cd091d821 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -122,7 +122,8 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps);
+            unsigned num_warps,
+            bool force_nc_cache = false);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
 
+  unsigned num_warps_;
+  bool force_nc_cache_;
+
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index a88f4259e..ba4547d10 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index a6d73e030..47b96aa08 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps)
+                     unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -626,7 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
     // -----
     std::ostringstream asm_oss;
     asm_oss << "@$" << n_words; // predicate
-    asm_oss << " ld.global.cg";
+    if(force_nc_cache_)
+      asm_oss << " ld.global.nc";
+    else
+      asm_oss << " ld.global.cg";
     if(n_words > 1)
       asm_oss << ".v" << n_words; // vector width
     asm_oss << ".b" << width; // word size
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 6792d7c9c..47db99998 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 5b0863814..c0743c268 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
        # explicitly set device
        torch.cuda.set_device(device.index)
        # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py
index e7fbe1fd8..412530722 100644
--- a/python/triton/ops/blocksparse/softmax.py
+++ b/python/triton/ops/blocksparse/softmax.py
@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
 
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
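
A minimal usage sketch of the new flag from the Python frontend (not taken from
this patch: the kernel, tensor names and BLOCK value are hypothetical, and the
code assumes the `**meta` launch API shown in code_gen.py above). The flag is a
launch-time keyword, alongside `num_warps` and `num_stages`, and is only safe
when the loads in the kernel read memory that is not written concurrently,
since `ld.global.nc` goes through the non-coherent (read-only) cache.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy(X, Y, N, **meta):
        # hypothetical read-only copy kernel: X is never written here, so .nc is safe
        pid = tl.program_id(0)
        offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        mask = offsets < N
        x = tl.load(X + offsets, mask=mask)  # emitted with ld.global.nc when force_nc_cache=True
        tl.store(Y + offsets, x, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
    # force_nc_cache is passed at launch time, like num_warps / num_stages
    _copy[grid](x, y, x.numel(), num_warps=4, force_nc_cache=True, BLOCK=1024)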