[FRONTEND] Added compilation flag to force use of .nc cache modifier in DRAM loads (#134)

/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\

Committed by Philippe Tillet

parent 2824345065
commit 01276b5153
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 
@@ -122,7 +122,8 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps);
+            unsigned num_warps,
+            bool force_nc_cache = false);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
 
+  unsigned num_warps_;
+  bool force_nc_cache_;
+
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;

@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);

@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps)
+                     unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -626,7 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
     // -----
     std::ostringstream asm_oss;
     asm_oss << "@$" << n_words; // predicate
-    asm_oss << " ld.global.cg";
+    if(force_nc_cache_)
+      asm_oss << " ld.global.nc";
+    else
+      asm_oss << " ld.global.cg";
     if(n_words > 1)
       asm_oss << ".v" << n_words; // vector width
     asm_oss << ".b" << width; // word size

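The hunk above is where the flag actually changes code generation. Below is a minimal Python sketch (plain Python, not Triton internals; the function name is made up) that mirrors the string construction in `visit_load_inst` and shows the inline-PTX load prefix emitted with and without the flag:

def load_prefix(n_words: int, width: int, force_nc_cache: bool = False) -> str:
    s = "@$" + str(n_words)                      # predicate operand, e.g. "@$4"
    s += " ld.global.nc" if force_nc_cache else " ld.global.cg"
    if n_words > 1:
        s += ".v" + str(n_words)                 # vector width
    s += ".b" + str(width)                       # word size in bits
    return s

print(load_prefix(4, 32))                        # -> @$4 ld.global.cg.v4.b32
print(load_prefix(4, 32, force_nc_cache=True))   # -> @$4 ld.global.nc.v4.b32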
@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());

@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise  OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
            # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
 
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
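For completeness, a launch-side sketch of the new keyword as seen from user code. This is hypothetical: `my_kernel`, its arguments, and `BLOCK` are placeholders and not part of this commit; the commit only adds the `force_nc_cache` keyword threaded through `Kernel.__call__` above, and enabling it is only safe when the kernel never writes the DRAM it loads.

import torch
import triton  # assumes this revision of the Python frontend

# `my_kernel` is a placeholder for any @triton.jit function whose DRAM inputs
# are strictly read-only for the duration of the launch.
x = torch.randn(8192, device='cuda')     # read-only input
y = torch.empty_like(x)                  # output (written, never read)
BLOCK = 1024

grid = lambda opt: [x.numel() // BLOCK]  # same launch pattern as the softmax ops above
my_kernel[grid](x, y, x.numel(),
                num_warps=4,
                force_nc_cache=True,     # new per-launch flag, defaults to False
                BLOCK=BLOCK)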