better fp16 support for dot
This commit is contained in:
@@ -26,8 +26,8 @@ struct perf_t {
|
|||||||
|
|
||||||
|
|
||||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||||
typedef half NumericT;
|
typedef float NumericT;
|
||||||
std::string ty = "half";
|
std::string ty = "float";
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
size_t dt_nbytes = sizeof(NumericT);
|
||||||
triton::driver::context* context = stream->context();
|
triton::driver::context* context = stream->context();
|
||||||
std::vector<NumericT> hc(M*N);
|
std::vector<NumericT> hc(M*N);
|
||||||
@@ -112,7 +112,11 @@ int main() {
|
|||||||
std::vector<config_t> configs = {
|
std::vector<config_t> configs = {
|
||||||
// {false, false, 8192, 512, 512},
|
// {false, false, 8192, 512, 512},
|
||||||
// {false, true, 8192, 8192, 8192}
|
// {false, true, 8192, 8192, 8192}
|
||||||
{false, true, 128, 128, 128},
|
// {false, true, 128, 128, 128},
|
||||||
|
// {false, false, 128, 128, 128},
|
||||||
|
// {true, false, 128, 128, 128},
|
||||||
|
{true, true, 128, 128, 128}
|
||||||
|
|
||||||
// {false, true, 32768, 256, 512}
|
// {false, true, 32768, 256, 512}
|
||||||
// {true, false, 8192, 512, 512},
|
// {true, false, 8192, 512, 512},
|
||||||
// {true, true, 8192, 512, 512}
|
// {true, true, 8192, 512, 512}
|
||||||
|
@@ -62,11 +62,11 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
|||||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
params_t params = heuristics();
|
// params_t params = heuristics();
|
||||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||||
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
|
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
|
||||||
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
|
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
|
||||||
// params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT
|
params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT
|
||||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||||
}
|
}
|
||||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||||
|
Reference in New Issue
Block a user