[FRONTEND] Added compilation flag to force use of .nc cache modifier in DRAM loads (#134)

/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\

Committed by: Philippe Tillet
Parent: 2824345065
Commit: 01276b5153
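
For context, a minimal sketch of how the new flag is meant to be used from the Python frontend once this lands. The kernel _copy, its body, the tensors, and the launch shapes are hypothetical, and the kernel-side API (triton.program_id, triton.arange, triton.load, triton.store, **meta) is an assumption about this era of the frontend; only the force_nc_cache= launch keyword (alongside the pre-existing num_warps/num_stages) is taken from this diff.

import torch
import triton

# Hypothetical element-wise copy kernel; the body is illustrative only.
@triton.jit
def _copy(X, Y, N, **meta):
    pid = triton.program_id(0)
    offs = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
    mask = offs < N
    x = triton.load(X + offs, mask=mask)   # the DRAM load the .cg/.nc cache modifier applies to
    triton.store(Y + offs, x, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = lambda opt: [(x.numel() + 1023) // 1024]

# force_nc_cache=True is only safe here because x is never written while the kernel runs.
_copy[grid](x, y, x.numel(), num_warps=4, num_stages=2, force_nc_cache=True, BLOCK=1024)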
				
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 
@@ -122,7 +122,8 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps);
+            unsigned num_warps,
+            bool force_nc_cache = false);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
 
+  unsigned num_warps_;
+  bool force_nc_cache_;
+
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;
@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                     analysis::allocation *alloc,
                     analysis::swizzle *swizzle,
                     target *tgt,
-                    unsigned num_warps)
+                    unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -626,7 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
     // -----
     std::ostringstream asm_oss;
     asm_oss << "@$" << n_words; // predicate
-    asm_oss << " ld.global.cg";
+    if(force_nc_cache_)
+      asm_oss << " ld.global.nc";
+    else
+      asm_oss << " ld.global.cg";
     if(n_words > 1)
       asm_oss << ".v" << n_words; // vector width
     asm_oss << ".b" << width; // word size
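
For illustration, a small standalone sketch (not part of the commit) that mirrors the asm_oss string construction above and shows how the emitted PTX opcode changes. The .nc modifier loads through the non-coherent (read-only) data cache instead of the .cg (cache-global, L2-only) path, which is why it can break correctness if the loaded memory is written while the kernel is running.

def load_opcode(n_words, width, force_nc_cache):
    # Mirrors the asm_oss construction in generator::visit_load_inst (sketch only).
    op = "@$%d" % n_words                                         # predicate
    op += " ld.global.nc" if force_nc_cache else " ld.global.cg"  # cache modifier
    if n_words > 1:
        op += ".v%d" % n_words                                    # vector width
    op += ".b%d" % width                                          # word size
    return op

print(load_opcode(4, 32, False))  # @$4 ld.global.cg.v4.b32
print(load_opcode(4, 32, True))   # @$4 ld.global.nc.v4.b32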
@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());
@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise  OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
 
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 