[FRONTEND] Added compilation flag to force use of .nc cache modifier (#134) in DRAM loads.
/!\ USE CAREFULLY - THIS CAN BREAK CORRECTNESS IF MISUSED /!\
committed by Philippe Tillet
parent 2824345065
commit 01276b5153
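
Usage sketch (not part of the diff): the kernel name `_kernel`, its pointer arguments, `N`, and `BLOCK` below are placeholders; only the `force_nc_cache=True` keyword, its `False` default, the `num_warps` keyword, and the `kernel[grid](...)` launch form come from the changes that follow.

    # Hypothetical launch site illustrating the new flag (sketch only).
    # force_nc_cache defaults to False; when True, every DRAM load emitted for the
    # kernel uses the .nc (non-coherent, read-only) cache modifier instead of .cg.
    # This is only safe if the loaded memory is not written while the kernel runs.
    BLOCK = 1024
    grid = lambda opt: [(N + BLOCK - 1) // BLOCK]
    _kernel[grid](x_ptr, y_ptr, N, num_warps=4, force_nc_cache=True, BLOCK=BLOCK)
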
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 
@@ -122,7 +122,8 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps);
+            unsigned num_warps,
+            bool force_nc_cache = false);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -208,9 +209,11 @@ private:
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *shmem_;
-  unsigned num_warps_;
   std::set<ir::value*> seen_;
 
+  unsigned num_warps_;
+  bool force_nc_cache_;
+
   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
   std::map<analysis::data_layout*, Value*> offset_b_k_;
@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;
@@ -51,7 +51,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
   // run passes
   dce.run(ir);
   peephole.run(ir);
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps)
+                     unsigned num_warps, bool force_nc_cache)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -626,6 +626,9 @@ void generator::visit_load_inst(ir::load_inst* x){
       // -----
       std::ostringstream asm_oss;
       asm_oss << "@$" << n_words; // predicate
+      if(force_nc_cache_)
+        asm_oss << " ld.global.nc";
+      else
       asm_oss << " ld.global.cg";
       if(n_words > 1)
         asm_oss << ".v" << n_words; // vector width
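
For illustration (an assumption, not part of the diff): with `n_words == 4`, the branch above produces one of the following inline-asm prefixes before the rest of the load instruction is appended; only the cache modifier differs.

    # Sketch of the two prefixes built by the C++ above for n_words == 4.
    prefix_nc = "@$4 ld.global.nc.v4"  # force_nc_cache_ == true: non-coherent, read-only cache path
    prefix_cg = "@$4 ld.global.cg.v4"  # default: cache-global load, cached in L2 only
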
@@ -78,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages, bool force_nc_cache) {
        drv::module *mod;
        drv::kernel *ker;
        size_t shared_mem;
-       triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
+       triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, force_nc_cache, mod, ker, shared_mem);
        std::stringstream ss;
        ir::print(ir, ss);
        return std::make_tuple(mod, ker, shared_mem, ss.str());
@@ -408,7 +408,7 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
@@ -416,6 +416,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.force_nc_cache = force_nc_cache
         self.sass = None
 
     def asm(self, mode):
@@ -524,7 +525,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -546,12 +547,12 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
         if shared_mem > tt_device.max_shared_memory():
             raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
+        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -573,7 +574,9 @@ class Kernel:
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
+                *wargs, device=device, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -183,7 +183,7 @@ class _softmax(torch.autograd.Function):
         }
         grid = lambda opt: [spdims[0] * spdims[1] * block, M]
         _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\
-                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)
+                       stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, force_nc_cache=True, **meta)
 
         # save to context
         ctx.mark_dirty(x)
@@ -207,7 +207,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 