[FRONTEND] Added a compilation flag to force use of the .nc cache modifier in DRAM loads (#134)

/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\
Committed by: Philippe Tillet
Parent: 2824345065
Commit: 01276b5153
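
For context, the flag is exposed as a new keyword argument at kernel launch time. The sketch below is hypothetical and not part of the commit: the kernel `_copy`, its body, and the language builtins (`triton.program_id`, `triton.arange`, `triton.load`, `triton.store`) are assumptions about the Python frontend of this period; only the `force_nc_cache=True` launch argument is what this commit adds.

```python
import torch
import triton

@triton.jit
def _copy(X, Y, **meta):
    # Hypothetical kernel: copy BLOCK contiguous elements per program instance.
    pid = triton.program_id(0)
    idx = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    triton.store(Y + idx, triton.load(X + idx))

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = lambda meta: [x.numel() // meta['BLOCK']]
# force_nc_cache=True asks the code generator to emit ld.global.nc for the
# DRAM loads of this kernel. That is only valid here because X is never
# written during the launch; misuse can silently break correctness.
_copy[grid](x, y, force_nc_cache=True, BLOCK=1024)
```
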
				
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 

@@ -122,7 +122,8 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps);
+            unsigned num_warps,
+            bool force_nc_cache = false);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
 
+  unsigned num_warps_;
+  bool force_nc_cache_;
+
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;

@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);

@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                     analysis::allocation *alloc,
                     analysis::swizzle *swizzle,
                     target *tgt,
-                    unsigned num_warps)
+                    unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -626,7 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
     // -----
     std::ostringstream asm_oss;
     asm_oss << "@$" << n_words; // predicate
-    asm_oss << " ld.global.cg";
+    if(force_nc_cache_)
+      asm_oss << " ld.global.nc";
+    else
+      asm_oss << " ld.global.cg";
     if(n_words > 1)
       asm_oss << ".v" << n_words; // vector width
     asm_oss << ".b" << width; // word size
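
This hunk is the core of the change: when `force_nc_cache_` is set, the inline-PTX load uses the `.nc` modifier (non-coherent / read-only data cache, valid only while the loaded memory is not being written by the kernel) instead of `.cg` (cache at L2 and below). For illustration only, a small Python mirror of the opcode string the C++ above assembles, where `n_words` and `width` are the vector width and word size already computed by `visit_load_inst`:

```python
def load_opcode(force_nc_cache: bool, n_words: int, width: int) -> str:
    """Illustrative mirror of the string building in visit_load_inst."""
    op = "ld.global.nc" if force_nc_cache else "ld.global.cg"
    if n_words > 1:
        op += f".v{n_words}"  # vector width
    op += f".b{width}"        # word size in bits
    return op

print(load_opcode(True, 4, 32))   # ld.global.nc.v4.b32
print(load_opcode(False, 4, 32))  # ld.global.cg.v4.b32
```
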
@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());

@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise  OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
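
Read together with the binding change above, these Python hunks complete the plumbing. A schematic of how the flag travels through the stack, using only names that appear in this diff (comments only, not runnable code):

```python
# Kernel.__call__(*wargs, grid, num_warps=4, num_stages=2,
#                 force_nc_cache=False, **meta)
#   -> Kernel._compile(..., num_warps, num_stages, force_nc_cache, **meta)
#      -> _triton.code_gen.add_passes_to_emit_bin(module, device, num_warps,
#                                                 num_stages, force_nc_cache)
#         -> codegen::add_passes_to_emit_bin(...)        # C++ entry point
#            -> codegen::generator(..., force_nc_cache)  # stored as force_nc_cache_
#               -> generator::visit_load_inst()          # ld.global.nc vs ld.global.cg
#   -> Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
```
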
@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
 
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 