[triton/python/conv]: Added cache for compiled kernels
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include "common.hpp"
|
||||
-#include "triton/jit.h"
|
||||
+#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
@@ -158,8 +158,8 @@ int main() {
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// initialize constant memory
|
||||
-triton::driver::buffer* delta = jit.get_buffer("delta");
|
||||
-triton::driver::buffer* masks = jit.get_buffer("masks");
|
||||
+triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta");
|
||||
+triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks");
|
||||
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
|
||||
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
|
||||
stream->synchronize();
|
||||
|
Reference in New Issue
Block a user