[FRONTEND] Added compilation flag to force use of .nc cache modifier (#134)

in DRAM loads. /!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED
/!\
This commit is contained in:
Philippe Tillet
2021-07-20 17:58:06 -04:00
committed by Philippe Tillet
parent 2824345065
commit 01276b5153
7 changed files with 27 additions and 18 deletions

View File

@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
void init_triton_codegen(py::module &&m) {
m.def(
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
drv::module *mod;
drv::kernel *ker;
size_t shared_mem;
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
std::stringstream ss;
ir::print(ir, ss);
return std::make_tuple(mod, ker, shared_mem, ss.str());

View File

@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
class Binary:
def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
# cache ir asm
self.ir_asm = ir_asm
self.module = module
@@ -416,6 +416,7 @@ class Binary:
self.shared_mem = shared_mem
self.num_warps = num_warps
self.num_stages = num_stages
self.force_nc_cache = force_nc_cache
self.sass = None
def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
# explicitly set device
torch.cuda.set_device(device.index)
# create IR module
@@ -546,12 +547,12 @@ class Kernel:
raise CompilationError(self.fn.src, node, e)
tt_device = _triton.driver.cu_device(device.index, False)
# Compile to machine code
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
if shared_mem > tt_device.max_shared_memory():
raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
if key not in cache:
# compile and cache configuration if necessary
cache[key] = self._compile(
*wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
*wargs, device=device, attributes=attributes,
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
constants=constants, **meta
)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])

View File

@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
}
grid = lambda opt: [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
# save to context
ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
# run kernel
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None