[TRITON][CODEGEN] Fixed flawed assert()
This commit is contained in:
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
|
||||
double total_time = 0;
|
||||
op();
|
||||
stream->synchronize();
|
||||
while(total_time*1e-9 < 1e-1){
|
||||
while(total_time*1e-9 < 1e-2){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
if(normalize)
|
||||
|
@@ -192,11 +192,11 @@ machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
|
||||
unsigned wpt_0 = layout->wpt(0);
|
||||
unsigned wpt_1 = layout->wpt(1);
|
||||
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
|
||||
// hmma warp tile size
|
||||
// mma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
|
||||
// hmma block tile size
|
||||
// mma block tile size
|
||||
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
||||
|
@@ -36,7 +36,8 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
return;
|
||||
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
|
||||
assert(layout);
|
||||
if(!layout)
|
||||
return;
|
||||
if(alloc_->has_offset(layout)){
|
||||
unsigned offset = alloc_->offset(layout);
|
||||
res.push_back(interval_t(offset, offset + layout->get_size()));
|
||||
|
@@ -223,7 +223,7 @@ class kernel:
|
||||
defines.append((k, values))
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [2, 4]
|
||||
opt.num_warps = [4]
|
||||
# create unique id for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
|
@@ -12,8 +12,8 @@ int main() {
|
||||
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {true, false}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{ord, x[0], x[1], 512, 512, 512},
|
||||
config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 512, 512, 512},
|
||||
config_t{ord, x[0], x[1], 8192, 8192, 8192},
|
||||
// config_t{ord, x[0], x[1], 127008, 768, 576},
|
||||
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
|
||||
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
||||
@@ -36,7 +36,7 @@ int main() {
|
||||
for(const auto& c: configs){
|
||||
std::tie(ord, AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << c ;
|
||||
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
|
||||
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
|
||||
std::cout << ", " << perf << std::flush;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
@@ -16,7 +16,7 @@ int main() {
|
||||
for(int nwarps: std::vector<int>{4})
|
||||
for(bool AT: std::array<bool, 2>{false, true})
|
||||
for(bool BT: std::array<bool, 2>{false, true}){
|
||||
configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
||||
configs.push_back(config_t{HALF, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
|
||||
}
|
||||
// test
|
||||
dtype_t dtype;
|
||||
|
Reference in New Issue
Block a user