[FRONTEND] Added compilation flag to force use of the .nc cache modifier in DRAM loads (#134)

/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\
committed by Philippe Tillet
parent 2824345065
commit 01276b5153
@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
            # compile and cache configuration if necessary
            cache[key] = self._compile(
-               *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+               *wargs, device=device, attributes=attributes,
+               num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+               constants=constants, **meta
            )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
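For context, a minimal launch sketch of where the new flag surfaces. Everything below is a hypothetical placeholder (kernel name, pointer arguments, grid, BLOCK) except the force_nc_cache keyword, which is taken from the patched Kernel.__call__ signature above:

# Hypothetical usage sketch: add_kernel, its arguments, grid and BLOCK are placeholders.
# Only the force_nc_cache keyword comes from this patch (it defaults to False).
grid = lambda meta: ((n_elements + meta['BLOCK'] - 1) // meta['BLOCK'],)
add_kernel(x_ptr, y_ptr, z_ptr, n_elements,
           grid=grid,
           num_warps=4,
           num_stages=2,
           force_nc_cache=True,   # force the .nc cache modifier on DRAM loads
           BLOCK=1024)

The warning in the commit title reflects the semantics of .nc loads: they go through the non-coherent, read-only cache path, so forcing them on data that is not actually read-only for the duration of the kernel can return stale values. The flag is therefore off by default and must be enabled explicitly per launch.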