[examples] some cleaning
This commit is contained in:
@@ -64,6 +64,11 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
|
|||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
// initialize JIT on default device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::jit jit(context);
|
||||||
|
|
||||||
|
// add module
|
||||||
std::vector<unsigned> params = {
|
std::vector<unsigned> params = {
|
||||||
// a0
|
// a0
|
||||||
2, 8, 1, 16,
|
2, 8, 1, 16,
|
||||||
@@ -76,13 +81,7 @@ int main() {
|
|||||||
// b1
|
// b1
|
||||||
1, 8, 1
|
1, 8, 1
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
|
||||||
triton::jit jit(context);
|
|
||||||
jit.add_module(src, params);
|
jit.add_module(src, params);
|
||||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
|
||||||
|
|
||||||
size_t M = 128, N = 128, K = 128;
|
size_t M = 128, N = 128, K = 128;
|
||||||
size_t bound = 8;
|
size_t bound = 8;
|
||||||
@@ -104,6 +103,7 @@ int main() {
|
|||||||
stream.write(da, true, 0, ha);
|
stream.write(da, true, 0, ha);
|
||||||
stream.write(db, true, 0, hb);
|
stream.write(db, true, 0, hb);
|
||||||
stream.write(dc, true, 0, hc);
|
stream.write(dc, true, 0, hc);
|
||||||
|
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||||
kernel.setArg(0, da);
|
kernel.setArg(0, da);
|
||||||
kernel.setArg(1, db);
|
kernel.setArg(1, db);
|
||||||
kernel.setArg(2, dc);
|
kernel.setArg(2, dc);
|
||||||
@@ -111,6 +111,7 @@ int main() {
|
|||||||
kernel.setArg(4, N);
|
kernel.setArg(4, N);
|
||||||
kernel.setArg(5, K);
|
kernel.setArg(5, K);
|
||||||
kernel.setArg(6, bound);
|
kernel.setArg(6, bound);
|
||||||
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
unsigned TM = info.global_range_size[0];
|
unsigned TM = info.global_range_size[0];
|
||||||
unsigned TN = info.global_range_size[1];
|
unsigned TN = info.global_range_size[1];
|
||||||
unsigned nthreads = info.num_threads;
|
unsigned nthreads = info.num_threads;
|
||||||
|
Reference in New Issue
Block a user