[FRONTEND] Added compilation flag to force use of .nc cache modifier (#134)

in DRAM loads. /!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\
Philippe Tillet authored and committed 2021-07-20 17:58:06 -04:00
parent 2824345065
commit 01276b5153
7 changed files with 27 additions and 18 deletions

View File

@@ -21,7 +21,7 @@ namespace codegen{
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);

View File

@@ -122,7 +122,8 @@ public:
            analysis::allocation *alloc,
            analysis::swizzle *swizzle,
            target *tgt,
-           unsigned num_warps);
+           unsigned num_warps,
+           bool force_nc_cache = false);
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
+  unsigned num_warps_;
+  bool force_nc_cache_;
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;

View File

@@ -26,7 +26,7 @@ namespace codegen {
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);

View File

@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps)
+                     unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 }
@@ -626,6 +626,9 @@ void generator::visit_load_inst(ir::load_inst* x){
   // -----
   std::ostringstream asm_oss;
   asm_oss << "@$" << n_words; // predicate
+  if(force_nc_cache_)
+    asm_oss << " ld.global.nc";
+  else
   asm_oss << " ld.global.cg";
   if(n_words > 1)
     asm_oss << ".v" << n_words; // vector width
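
Note on the two modifiers: `ld.global.cg` caches DRAM loads at the global (L2) level, while `ld.global.nc` additionally routes them through the non-coherent, read-only cache path. The `.nc` path assumes the addressed memory is not written for the duration of the kernel, which is why forcing it on a kernel that also stores to the loaded buffer can return stale data, as the commit message warns. Below is a minimal sketch of a kernel for which the forced `.nc` path is safe; the kernel name, pointer arguments, and the `BLOCK` meta-parameter are illustrative and not part of this commit.

```python
import triton
import triton.language as tl

# Hypothetical element-wise copy kernel: src_ptr is only ever loaded and
# dst_ptr is only ever stored to, so no load can observe a value written by
# this kernel, and routing its loads through the read-only cache cannot
# change the result. An in-place kernel that loaded and stored through the
# same pointer would not have this property.
@triton.jit
def _copy(src_ptr, dst_ptr, n_elements, **meta):
    BLOCK = meta['BLOCK']
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)   # emitted as ld.global.nc when the flag is set
    tl.store(dst_ptr + offsets, x, mask=mask)   # stores are unaffected by the flag
```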

View File

@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());

View File

@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None

     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn

-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)

-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
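
With the changes above, the flag is exposed as a launch-time keyword on `Kernel.__call__` and threaded through `_compile` into `add_passes_to_emit_bin`; because it is an explicit parameter, it is not forwarded to the kernel's `**meta`. A minimal host-side sketch of opting in, reusing the hypothetical `_copy` kernel from the earlier sketch and the lambda-style grid the blocksparse ops below already use; the tensor size and `BLOCK` value are arbitrary.

```python
import torch

x = torch.randn(8192, device='cuda')
y = torch.empty_like(x)

# Number of program instances; the meta-parameters are ignored here, as in
# the blocksparse launches below.
grid = lambda opt: [(x.numel() + 1023) // 1024]

# force_nc_cache=True asks codegen to emit ld.global.nc for this kernel's
# DRAM loads; safe here because _copy never writes to the buffer it reads.
_copy[grid](x, y, x.numel(), BLOCK=1024, force_nc_cache=True)
```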

View File

@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None