[FRONTEND] Alignment fix-up (#428)
include/triton/external/CUDA/cuda.h (vendored): 5659 changed lines, file diff suppressed because it is too large.
@@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
 
 }
 
-std::string pow2_divisor(long N){
-  if(N % 16 == 0) return "16";
-  if(N % 8 == 0) return "8";
-  if(N % 4 == 0) return "4";
-  if(N % 2 == 0) return "2";
-  return "1";
+long pow2_divisor(long N){
+  if(N % 16 == 0) return 16;
+  if(N % 8 == 0) return 8;
+  if(N % 4 == 0) return 4;
+  if(N % 2 == 0) return 2;
+  return 1;
 }
 
 // Returns something like "int16", whether dtype is a torch.dtype or
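The hunk above changes pow2_divisor to return a number instead of a string, which is what lets later hunks combine two divisors with std::min before appending them to the cache key. As a quick illustration only (not part of the commit), a minimal Python sketch of the same helper and of why a numeric return value is needed:

    def pow2_divisor(n: int) -> int:
        """Largest power of two (up to 16) that divides n."""
        for d in (16, 8, 4, 2):
            if n % d == 0:
                return d
        return 1

    # With numbers, two divisors can be combined conservatively with min();
    # the old string-returning version could only be appended to the cache key.
    assert pow2_divisor(48) == 16
    assert pow2_divisor(12) == 4
    assert min(pow2_divisor(48), pow2_divisor(8)) == 8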
@@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) {
   }
 }
 
+size_t get_pointer_range_size(uint64_t addr){
+  if(addr == 0)
+    return 0;
+  size_t size;
+  drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
+  return size;
+}
+
 // Launch
 void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                 std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
         continue;
       // values divisible by small powers of 2 are specialized
       cache_key += "[multipleof(";
-      cache_key += pow2_divisor(value);
+      cache_key += std::to_string(pow2_divisor(value));
       cache_key += ")]";
       continue;
     }
@@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       py::object data_ptr = arg.attr("data_ptr")();
       long value = data_ptr.cast<long>();
       params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+      // copy param
       std::memcpy(params_ptr, &value, 8);
       params_ptr += 8;
+      // update cache key
       cache_key += dtype_cache_key_part(arg.attr("dtype"));
       cache_key += "*";
       cache_key += "[multipleof(";
-      cache_key += pow2_divisor(value);
+      size_t range_size = get_pointer_range_size(value);
+      cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
       cache_key += ")]";
       continue;
     }
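In the hunk above, the [multipleof(...)] part of the cache key for a pointer argument is no longer derived from the pointer value alone: it is now the minimum of the pointer's own alignment and the alignment of the allocation's range size returned by the new get_pointer_range_size helper, which keeps the specialization conservative for small or oddly sized allocations. A hedged Python sketch of the resulting arithmetic (the address and size below are made up for illustration):

    def pow2_divisor(n: int) -> int:
        for d in (16, 8, 4, 2):
            if n % d == 0:
                return d
        return 1

    addr = 0x7F00_0000_0040   # base address that happens to be 16-byte aligned
    range_size = 8            # ...but the allocation behind it is only 8 bytes long

    old_key_part = f"[multipleof({pow2_divisor(addr)})]"                                 # [multipleof(16)]
    new_key_part = f"[multipleof({min(pow2_divisor(addr), pow2_divisor(range_size))})]"  # [multipleof(8)]
    print(old_key_part, new_key_part)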
@@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) {
       }
   );
 
+  // get range size for the given pointer
+  m.def("get_pointer_range_size", &get_pointer_range_size);
+
+
   // cache key
   m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                      py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
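The binding registered above is what the Python frontend calls when it builds specialization attributes (see the Kernel.add_to_cache hunk below). A rough usage sketch, assuming a CUDA-capable PyTorch build; the import path mirrors the one used by code_gen.py at the time and should be treated as an assumption:

    import torch
    import triton._C.libtriton.triton as _triton  # assumed extension-module path

    x = torch.empty(100, dtype=torch.float32, device='cuda')
    # Size in bytes of the CUDA allocation containing x.data_ptr()
    # (CU_POINTER_ATTRIBUTE_RANGE_SIZE under the hood); the caching
    # allocator may report a rounded-up block size.
    size = _triton.runtime.get_pointer_range_size(x.data_ptr())
    print(size)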
@@ -674,9 +674,17 @@ class Kernel:
     def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         # attributes
-        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
-                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
+        attributes = dict()
+        for i, arg in enumerate(wargs):
+            if i in self.fn.do_not_specialize:
+                continue
+            if isinstance(arg, int):
+                attributes[i] = Kernel.pow2_divisor(arg)
+            elif i in tensor_idxs:
+                addr = arg.data_ptr()
+                range_size = _triton.runtime.get_pointer_range_size(addr)
+                attributes[i] = min(Kernel.pow2_divisor(addr),
+                                    Kernel.pow2_divisor(range_size))
 
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
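On the Python side, Kernel.add_to_cache now builds the specialization attributes with an explicit loop so that pointer arguments can fold in the allocation's range size, while integer arguments keep using their own alignment. A self-contained sketch of the same logic; pow2_divisor, get_pointer_range_size and FakeTensor below are illustrative stand-ins, not the real Triton APIs:

    def pow2_divisor(n: int) -> int:
        for d in (16, 8, 4, 2):
            if n % d == 0:
                return d
        return 1

    # Stand-in for _triton.runtime.get_pointer_range_size: fake allocation sizes.
    fake_ranges = {0x1000: 64, 0x2010: 24}

    def get_pointer_range_size(addr: int) -> int:
        return fake_ranges.get(addr, 0)

    class FakeTensor:
        def __init__(self, addr):
            self._addr = addr
        def data_ptr(self):
            return self._addr

    def build_attributes(wargs, do_not_specialize=()):
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        attributes = dict()
        for i, arg in enumerate(wargs):
            if i in do_not_specialize:
                continue
            if isinstance(arg, int):
                attributes[i] = pow2_divisor(arg)
            elif i in tensor_idxs:
                addr = arg.data_ptr()
                attributes[i] = min(pow2_divisor(addr),
                                    pow2_divisor(get_pointer_range_size(addr)))
        return attributes

    # ints are specialized on their own alignment; tensors on min(pointer, range size).
    print(build_attributes([FakeTensor(0x1000), FakeTensor(0x2010), 256, 7]))
    # {0: 16, 1: 8, 2: 16, 3: 1}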
@@ -165,6 +165,7 @@ class block:
         self.numel = 1
         for s in self.shape:
             self.numel *= s
+        self.numel = constexpr(self.numel)
         # Data-type wrapper
         self.dtype = block._init_dtype(self.handle.type.scalar)
 
@@ -873,7 +874,7 @@ def ravel(x):
|
|||||||
:param x: the input block
|
:param x: the input block
|
||||||
:type x: Block
|
:type x: Block
|
||||||
"""
|
"""
|
||||||
return triton.language.reshape(x, [x.type.numel])
|
return triton.language.reshape(x, [x.numel])
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||