Use mlir::Block to replace MlirBlock
This commit is contained in:
@@ -668,12 +668,12 @@ void init_triton_ir(py::module &&m) {
|
|||||||
;
|
;
|
||||||
|
|
||||||
py::class_<MlirOperation>(m, "operation")
|
py::class_<MlirOperation>(m, "operation")
|
||||||
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
|
.def("add_entry_block", [](MlirOperation &self) -> mlir::Block {
|
||||||
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
||||||
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
||||||
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
||||||
mlir::Block *entry = funcOp.addEntryBlock();
|
mlir::Block *entry = funcOp.addEntryBlock();
|
||||||
return wrap(entry);
|
return *entry;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
||||||
} else
|
} else
|
||||||
@@ -684,12 +684,14 @@ void init_triton_ir(py::module &&m) {
|
|||||||
})
|
})
|
||||||
;
|
;
|
||||||
|
|
||||||
py::class_<MlirValue>(m, "value")
|
py::class_<mlir::Value>(m, "value")
|
||||||
|
;
|
||||||
|
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
||||||
;
|
;
|
||||||
|
|
||||||
py::class_<MlirBlock>(m, "block")
|
py::class_<mlir::Block>(m, "block")
|
||||||
.def("arg", [](MlirBlock &self, int index) -> MlirValue {
|
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
||||||
return wrap(unwrap(self)->getArgument(index));
|
return self.getArgument(index);
|
||||||
})
|
})
|
||||||
;
|
;
|
||||||
|
|
||||||
@@ -741,12 +743,16 @@ void init_triton_ir(py::module &&m) {
|
|||||||
// .def("br", &ir::builder::create_br, ret::reference)
|
// .def("br", &ir::builder::create_br, ret::reference)
|
||||||
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||||
// .def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
// .def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||||
// // insertion block/point, insert points are represented as (*bb, *instr)
|
// insertion block/point
|
||||||
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, MlirBlock &block) -> void{
|
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
||||||
self.setInsertionPointToStart(unwrap(block));
|
self.setInsertionPointToStart(&block);
|
||||||
})
|
})
|
||||||
// .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
|
||||||
// .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
self.setInsertionPointToEnd(&block);
|
||||||
|
})
|
||||||
|
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & {
|
||||||
|
return *self.getInsertionBlock();
|
||||||
|
}, ret::reference)
|
||||||
// .def("get_insert_point", [](ir::builder *self) {
|
// .def("get_insert_point", [](ir::builder *self) {
|
||||||
// ir::basic_block *bb = self->get_insert_block();
|
// ir::basic_block *bb = self->get_insert_block();
|
||||||
// ir::basic_block::iterator it = self->get_insert_point();
|
// ir::basic_block::iterator it = self->get_insert_point();
|
||||||
@@ -768,11 +774,11 @@ void init_triton_ir(py::module &&m) {
|
|||||||
// Use arith.ConstantOp to create constants
|
// Use arith.ConstantOp to create constants
|
||||||
// // Constants
|
// // Constants
|
||||||
// .def("get_int1", &ir::builder::get_int1, ret::reference)
|
// .def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||||
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue {
|
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||||
loc, v, self.getI32Type()
|
loc, v, self.getI32Type()
|
||||||
)));
|
));
|
||||||
})
|
})
|
||||||
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
|
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||||
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||||
@@ -818,9 +824,15 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||||
return wrap(self.getF64Type());
|
return wrap(self.getF64Type());
|
||||||
})
|
})
|
||||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
|
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
|
||||||
return wrap(
|
return wrap(
|
||||||
mlir::triton::PointerType::get(unwrap(type))
|
mlir::triton::PointerType::get(unwrap(type), addrSpace)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
|
||||||
|
std::vector<int64_t> &shape) -> MlirType {
|
||||||
|
return wrap(
|
||||||
|
mlir::RankedTensorType::get(shape, unwrap(elementType))
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
.def("get_function_ty", [](mlir::OpBuilder &self,
|
.def("get_function_ty", [](mlir::OpBuilder &self,
|
||||||
|
@@ -30,7 +30,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.prototype = prototype
|
self.prototype = prototype
|
||||||
self.gscope = gscope
|
self.gscope = gscope
|
||||||
self.lscope = dict()
|
self.lscope = dict()
|
||||||
self.is_arg_lscope = dict() # name => is_arg: {str: bool}
|
|
||||||
self.attributes = attributes
|
self.attributes = attributes
|
||||||
self.constants = constants
|
self.constants = constants
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
@@ -69,33 +68,32 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ret = self.builtins[name]
|
ret = self.builtins[name]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'{name} is not defined')
|
raise ValueError(f'{name} is not defined')
|
||||||
if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]:
|
if self.is_triton_tensor(ret):
|
||||||
return self._get_tensor(name)
|
return self._get_tensor(name)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def set_value(self, name: str,
|
def set_value(self, name: str,
|
||||||
value: Union[triton.language.tensor, triton.language.constexpr],
|
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
|
||||||
is_arg: bool = False) -> None:
|
|
||||||
''' This function:
|
''' This function:
|
||||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||||
1. record local defined name (FIXME: should consider control flow)
|
1. record local defined name (FIXME: should consider control flow)
|
||||||
2. store tensor in self.lvalue
|
2. store tensor in self.lvalue
|
||||||
'''
|
'''
|
||||||
self.lscope[name] = value
|
self.lscope[name] = value
|
||||||
# if this value is an argument, we don't need to create phis for it
|
if isinstance(value, triton.language.tensor):
|
||||||
self.is_arg_lscope[name] = is_arg
|
self._set_value(name, self.builder.get_insertion_block(), value)
|
||||||
if isinstance(value, triton.language.tensor) and not is_arg:
|
|
||||||
self._set_value(name, self.builder.get_insert_block(), value)
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# SSA-construction
|
# SSA-construction
|
||||||
#
|
#
|
||||||
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
|
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
|
||||||
if not bb:
|
if not bb:
|
||||||
bb = self.builder.get_insert_block()
|
bb = self.builder.get_insertion_block()
|
||||||
# local value numbering
|
# local value numbering
|
||||||
if (name, bb) in self.lvalues:
|
if (name, bb) in self.lvalues:
|
||||||
return self.lvalues[(name, bb)]
|
return self.lvalues[(name, bb)]
|
||||||
|
print(self.lvalues)
|
||||||
|
assert False, f'Cannot find {name} in {bb}'
|
||||||
# global value numbering
|
# global value numbering
|
||||||
saved_insert_point = self.builder.get_insert_point()
|
saved_insert_point = self.builder.get_insert_point()
|
||||||
result = self._get_tensor_recursive(name, bb)
|
result = self._get_tensor_recursive(name, bb)
|
||||||
@@ -115,8 +113,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
elif len(preds) == 1:
|
elif len(preds) == 1:
|
||||||
# one predecessor: no phi needed, try get value from pred
|
# one predecessor: no phi needed, try get value from pred
|
||||||
result = self._get_tensor(name, preds[0])
|
result = self._get_tensor(name, preds[0])
|
||||||
|
elif len(preds) == 0:
|
||||||
|
result = self._get_tensor(name, None)
|
||||||
else: # multiple preds
|
else: # multiple preds
|
||||||
assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)'
|
|
||||||
phi = self._make_phi(type, len(preds), bb)
|
phi = self._make_phi(type, len(preds), bb)
|
||||||
self._set_value(name, bb, phi)
|
self._set_value(name, bb, phi)
|
||||||
result = self._add_phi_operands(name, phi)
|
result = self._add_phi_operands(name, phi)
|
||||||
@@ -148,8 +147,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
||||||
self.lvalues[(name, bb)] = value
|
self.lvalues[(name, bb)] = value
|
||||||
# TODO: why we need this?
|
# # TODO: why we need this?
|
||||||
self.module.set_instr_metadata(name, value.handle)
|
# self.module.set_instr_metadata(name, value.handle)
|
||||||
|
|
||||||
def _seal_block(self, bb: _triton.ir.basic_block):
|
def _seal_block(self, bb: _triton.ir.basic_block):
|
||||||
# complete all incomplete phis
|
# complete all incomplete phis
|
||||||
@@ -220,7 +219,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if inline:
|
if inline:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder))
|
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
|
||||||
|
self.module.push_back(fn)
|
||||||
arg_values = []
|
arg_values = []
|
||||||
idx = 0
|
idx = 0
|
||||||
for i, arg_name in enumerate(arg_names):
|
for i, arg_name in enumerate(arg_names):
|
||||||
@@ -230,25 +230,27 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cst = triton.language.constexpr(self.constants[i])
|
cst = triton.language.constexpr(self.constants[i])
|
||||||
arg_values.append(cst)
|
arg_values.append(cst)
|
||||||
else:
|
else:
|
||||||
if i in self.attributes:
|
pass
|
||||||
is_ptr = fn.args[idx].type.is_ptr()
|
# TODO: ...
|
||||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
# if i in self.attributes:
|
||||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
# is_ptr = fn.args[idx].type.is_ptr()
|
||||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||||
fn.add_attr(idx + 1, attr)
|
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||||
fn.args[idx].name = arg_name
|
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||||
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
# fn.add_attr(idx + 1, attr)
|
||||||
idx += 1
|
# fn.args[idx].name = arg_name
|
||||||
|
# arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||||
|
# idx += 1
|
||||||
|
|
||||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||||
self.set_value(arg_name, arg_value, is_arg=True)
|
self.set_value(arg_name, arg_value)
|
||||||
if inline:
|
if inline:
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
return self.last_ret
|
return self.last_ret
|
||||||
else:
|
else:
|
||||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
entry = fn.add_entry_block()
|
||||||
self._seal_block(entry)
|
self._seal_block(entry)
|
||||||
self.builder.set_insert_block(entry)
|
self.builder.set_insertion_point_to_start(entry)
|
||||||
# visit function body
|
# visit function body
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
# finalize function
|
# finalize function
|
||||||
@@ -358,7 +360,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = self.visit(node.test)
|
cond = self.visit(node.test)
|
||||||
if isinstance(cond, triton.language.tensor):
|
if isinstance(cond, triton.language.tensor):
|
||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
current_bb = self.builder.get_insert_block()
|
current_bb = self.builder.get_insertion_block()
|
||||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||||
@@ -445,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return getattr(op, fn)()
|
return getattr(op, fn)()
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
current_bb = self.builder.get_insert_block()
|
current_bb = self.builder.get_insertion_block()
|
||||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||||
|
|
||||||
@@ -457,7 +459,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.builder.set_insert_block(loop_bb)
|
self.builder.set_insert_block(loop_bb)
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insertion_block()
|
||||||
self._seal_block(stop_bb)
|
self._seal_block(stop_bb)
|
||||||
self._seal_block(loop_bb)
|
self._seal_block(loop_bb)
|
||||||
self._seal_block(next_bb)
|
self._seal_block(next_bb)
|
||||||
@@ -512,7 +514,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# cond_node = neg_cond_node
|
# cond_node = neg_cond_node
|
||||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||||
# code generation
|
# code generation
|
||||||
current_bb = self.builder.get_insert_block()
|
current_bb = self.builder.get_insertion_block()
|
||||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||||
|
|
||||||
@@ -528,7 +530,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
# TODO: handle case where body breaks control flow
|
# TODO: handle case where body breaks control flow
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insertion_block()
|
||||||
self._seal_block(stop_bb)
|
self._seal_block(stop_bb)
|
||||||
self._seal_block(loop_bb)
|
self._seal_block(loop_bb)
|
||||||
self._seal_block(next_bb)
|
self._seal_block(next_bb)
|
||||||
@@ -845,10 +847,11 @@ class Kernel:
|
|||||||
|
|
||||||
# create IR module
|
# create IR module
|
||||||
context = _triton.ir.context()
|
context = _triton.ir.context()
|
||||||
|
context.load_triton()
|
||||||
# get just-in-time proto-type of kernel
|
# get just-in-time proto-type of kernel
|
||||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||||
ret_type = triton.language.void
|
ret_type = triton.language.void
|
||||||
prototype = triton.language.function_type(ret_type, arg_types)
|
prototype = triton.language.function_type([ret_type], arg_types)
|
||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
# export symbols visible from self into code-generator object
|
# export symbols visible from self into code-generator object
|
||||||
gscope = self.__globals__
|
gscope = self.__globals__
|
||||||
@@ -1179,10 +1182,11 @@ class JITFunction:
|
|||||||
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
|
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
|
||||||
# create IR module
|
# create IR module
|
||||||
context = _triton.ir.context()
|
context = _triton.ir.context()
|
||||||
|
context.load_triton()
|
||||||
# get just-in-time proto-type of kernel
|
# get just-in-time proto-type of kernel
|
||||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||||
ret_type = triton.language.void
|
ret_type = triton.language.void
|
||||||
prototype = triton.language.function_type(ret_type, arg_types)
|
prototype = triton.language.function_type([ret_type], arg_types)
|
||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
# export symbols visible from self into code-generator object
|
# export symbols visible from self into code-generator object
|
||||||
gscope = self.__globals__
|
gscope = self.__globals__
|
||||||
|
@@ -209,7 +209,7 @@ class pointer_type(dtype):
|
|||||||
self.name = self.__str__()
|
self.name = self.__str__()
|
||||||
|
|
||||||
def to_ir(self, builder: ir.builder) -> ir.pointer_type:
|
def to_ir(self, builder: ir.builder) -> ir.pointer_type:
|
||||||
return ir.type.make_ptr(self.element_ty.to_ir(builder), 1)
|
return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'pointer<{self.element_ty}>'
|
return f'pointer<{self.element_ty}>'
|
||||||
@@ -247,7 +247,7 @@ class block_type(dtype):
|
|||||||
self.name = self.__str__()
|
self.name = self.__str__()
|
||||||
|
|
||||||
def to_ir(self, builder: ir.builder) -> ir.block_type:
|
def to_ir(self, builder: ir.builder) -> ir.block_type:
|
||||||
return ir.type.make_block(self.element_ty.to_ir(builder), self.shape)
|
return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'<{self.shape}, {self.element_ty}>'
|
return f'<{self.shape}, {self.element_ty}>'
|
||||||
@@ -275,8 +275,8 @@ class block_type(dtype):
|
|||||||
|
|
||||||
|
|
||||||
class function_type(dtype):
|
class function_type(dtype):
|
||||||
def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None:
|
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||||
self.ret_type = ret_type
|
self.ret_types = ret_types
|
||||||
self.param_types = param_types
|
self.param_types = param_types
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -284,7 +284,8 @@ class function_type(dtype):
|
|||||||
|
|
||||||
def to_ir(self, builder: ir.builder):
|
def to_ir(self, builder: ir.builder):
|
||||||
ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
|
ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
|
||||||
return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types)
|
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
|
||||||
|
return builder.get_function_ty(ir_param_types, ret_types)
|
||||||
|
|
||||||
|
|
||||||
# scalar types
|
# scalar types
|
||||||
@@ -425,8 +426,8 @@ class tensor:
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
# Block shape
|
# Block shape
|
||||||
self.shape = (1, )
|
self.shape = (1, )
|
||||||
if self.handle.type.is_block():
|
if type.is_block():
|
||||||
self.shape = self.handle.type.shape
|
self.shape = type.shape
|
||||||
self.numel = 1
|
self.numel = 1
|
||||||
for s in self.shape:
|
for s in self.shape:
|
||||||
self.numel *= s
|
self.numel *= s
|
||||||
|
@@ -17,11 +17,13 @@ i64_ty = builder.get_int64_ty()
|
|||||||
|
|
||||||
f16_ty = builder.get_half_ty()
|
f16_ty = builder.get_half_ty()
|
||||||
|
|
||||||
f16_ptr_ty = builder.get_ptr_ty(f16_ty)
|
f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1)
|
||||||
|
|
||||||
func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
|
func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
|
||||||
func = builder.create_function('foo', func_ty)
|
func = builder.create_function('foo', func_ty)
|
||||||
|
|
||||||
|
module.push_back(func)
|
||||||
|
|
||||||
# ...
|
# ...
|
||||||
entry = func.add_entry_block()
|
entry = func.add_entry_block()
|
||||||
builder.set_insertion_point_to_start(entry)
|
builder.set_insertion_point_to_start(entry)
|
||||||
@@ -51,5 +53,5 @@ builder.create_store(c_ptrs, c)
|
|||||||
|
|
||||||
# func.dump()
|
# func.dump()
|
||||||
|
|
||||||
module.push_back(func)
|
|
||||||
module.dump()
|
module.dump()
|
||||||
|
Reference in New Issue
Block a user