[FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch: 1. Use static triton types (e.g., tl.int32) instead of creating new types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same thing. 2. The name of an extern inst should be empty but not the symbol name of the inst. TTIR generator will assign names automatically. Otherwise, we have the same variable name when there are multiple same extern insts. Before the PR: ```bash __nv_exp = extern_elementwise f64<1024> %11; __nv_exp = extern_elementwise f64<1024> %11; ``` After the PR: ```bash %12 = extern_elementwise f64<1024> %11; %13 = extern_elementwise f64<1024> %11; ```
This commit is contained in:
@@ -1119,7 +1119,8 @@ class extern_elementwise_inst : public instruction {
|
|||||||
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
|
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
|
||||||
type *dst_ty, const std::string &lib_name,
|
type *dst_ty, const std::string &lib_name,
|
||||||
const std::string &extern_lib_path,
|
const std::string &extern_lib_path,
|
||||||
const std::string &symbol_name, instruction *next);
|
const std::string &symbol_name,
|
||||||
|
const std::string &name, instruction *next);
|
||||||
std::string repr_impl() const { return "extern_elementwise"; }
|
std::string repr_impl() const { return "extern_elementwise"; }
|
||||||
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
|
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
|
||||||
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
|
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
|
||||||
@@ -1128,14 +1129,17 @@ class extern_elementwise_inst : public instruction {
|
|||||||
static extern_elementwise_inst *create(
|
static extern_elementwise_inst *create(
|
||||||
context &ctx, const std::vector<value *> &args, type *dst_ty,
|
context &ctx, const std::vector<value *> &args, type *dst_ty,
|
||||||
const std::string &lib_name = "", const std::string &lib_path = "",
|
const std::string &lib_name = "", const std::string &lib_path = "",
|
||||||
const std::string &symbol_name = "", instruction *next = nullptr);
|
const std::string &symbol_name = "", const std::string &name = "",
|
||||||
|
instruction *next = nullptr);
|
||||||
|
|
||||||
const std::string &get_lib_name() const { return lib_name_; }
|
const std::string &get_lib_name() const { return lib_name_; }
|
||||||
const std::string &get_lib_path() const { return lib_path_; }
|
const std::string &get_lib_path() const { return lib_path_; }
|
||||||
|
const std::string &get_symbol_name() const { return symbol_name_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string lib_name_ = "";
|
std::string lib_name_;
|
||||||
std::string lib_path_ = "";
|
std::string lib_path_;
|
||||||
|
std::string symbol_name_;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -3531,7 +3531,7 @@ void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
|
|||||||
FunctionType *FT =
|
FunctionType *FT =
|
||||||
FunctionType::get(ret_type, std::move(operand_types), false);
|
FunctionType::get(ret_type, std::move(operand_types), false);
|
||||||
Function *F = llvm::cast<llvm::Function>(
|
Function *F = llvm::cast<llvm::Function>(
|
||||||
mod_->getOrInsertFunction(i->get_name(), FT).getCallee());
|
mod_->getOrInsertFunction(i->get_symbol_name(), FT).getCallee());
|
||||||
for (auto idx : idxs_.at(i)) {
|
for (auto idx : idxs_.at(i)) {
|
||||||
std::vector<llvm::Value *> args;
|
std::vector<llvm::Value *> args;
|
||||||
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||||
|
@@ -1007,11 +1007,11 @@ globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name
|
|||||||
extern_elementwise_inst::extern_elementwise_inst(
|
extern_elementwise_inst::extern_elementwise_inst(
|
||||||
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
||||||
const std::string &lib_name, const std::string &lib_path,
|
const std::string &lib_name, const std::string &lib_path,
|
||||||
const std::string &symbol_name, instruction *next)
|
const std::string &symbol_name, const std::string &name, instruction *next)
|
||||||
: instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), symbol_name,
|
: instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), name, next),
|
||||||
next),
|
|
||||||
lib_name_(lib_name),
|
lib_name_(lib_name),
|
||||||
lib_path_(lib_path) {
|
lib_path_(lib_path),
|
||||||
|
symbol_name_(symbol_name) {
|
||||||
for (size_t i = 0; i < args.size(); i++) {
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
set_operand(i, args[i]);
|
set_operand(i, args[i]);
|
||||||
}
|
}
|
||||||
@@ -1020,9 +1020,10 @@ extern_elementwise_inst::extern_elementwise_inst(
|
|||||||
extern_elementwise_inst *extern_elementwise_inst::create(
|
extern_elementwise_inst *extern_elementwise_inst::create(
|
||||||
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
||||||
const std::string &lib_name, const std::string &lib_path,
|
const std::string &lib_name, const std::string &lib_path,
|
||||||
const std::string &symbol_name, instruction *next) {
|
const std::string &symbol_name, const std::string &name,
|
||||||
|
instruction *next) {
|
||||||
return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
|
return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
|
||||||
symbol_name, next);
|
symbol_name, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
// clock
|
// clock
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -51,9 +51,9 @@ def convert_type(type_str):
|
|||||||
elif type_str == "u64":
|
elif type_str == "u64":
|
||||||
return "uint64"
|
return "uint64"
|
||||||
elif type_str == "float":
|
elif type_str == "float":
|
||||||
return "fp32"
|
return "float32"
|
||||||
elif type_str == "double":
|
elif type_str == "double":
|
||||||
return "fp64"
|
return "float64"
|
||||||
else:
|
else:
|
||||||
# ignore other types, such as pointer types
|
# ignore other types, such as pointer types
|
||||||
return None
|
return None
|
||||||
@@ -268,8 +268,8 @@ class Libdevice(ExternLibrary):
|
|||||||
for symbol in symbols:
|
for symbol in symbols:
|
||||||
arg_type_symbol_dict_str += "("
|
arg_type_symbol_dict_str += "("
|
||||||
for arg_type in symbol.arg_types:
|
for arg_type in symbol.arg_types:
|
||||||
arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\"),"
|
arg_type_symbol_dict_str += f"core.{arg_type},"
|
||||||
ret_type = f"core.dtype(\"{symbol.ret_type}\")"
|
ret_type = f"core.{symbol.ret_type}"
|
||||||
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
|
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
|
||||||
arg_type_symbol_dict_str += "}"
|
arg_type_symbol_dict_str += "}"
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user