[CI] Some fixes for the build (#451)
This commit is contained in:
@@ -150,7 +150,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
for(int i = 0; i < len; i++){
|
for(int i = 0; i < len; i++){
|
||||||
cache_key += "_";
|
cache_key += "_";
|
||||||
py::int_ py_i = py::int_(i);
|
py::int_ py_i = py::int_(i);
|
||||||
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
|
bool specialize = !do_not_specialize.contains(py_i);
|
||||||
py::object arg = args[i];
|
py::object arg = args[i];
|
||||||
auto arg_ptr = arg.ptr();
|
auto arg_ptr = arg.ptr();
|
||||||
|
|
||||||
|
@@ -169,8 +169,6 @@ import triton.language as tl
|
|||||||
],
|
],
|
||||||
key=['M', 'N', 'K'],
|
key=['M', 'N', 'K'],
|
||||||
)
|
)
|
||||||
# %
|
|
||||||
# We can now define our kernel as normal, using all the techniques presented above
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def matmul_kernel(
|
def matmul_kernel(
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
|
Reference in New Issue
Block a user