2022-03-23 13:32:12 +08:00
|
|
|
import triton._C.libtriton.triton.ir as ir
|
|
|
|
|
|
|
|
ctx = ir.context()
|
|
|
|
ctx.load_triton()
|
|
|
|
|
|
|
|
# TODO
|
|
|
|
builder = ir.builder(ctx)
|
|
|
|
|
2022-03-30 13:32:52 +08:00
|
|
|
module = builder.create_module()
|
2022-03-23 13:32:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
i1_ty = builder.get_int1_ty()
|
|
|
|
i8_ty = builder.get_int8_ty()
|
|
|
|
i16_ty = builder.get_int16_ty()
|
|
|
|
i32_ty = builder.get_int32_ty()
|
|
|
|
i64_ty = builder.get_int64_ty()
|
|
|
|
|
|
|
|
f16_ty = builder.get_half_ty()
|
|
|
|
|
2022-03-30 16:31:03 +08:00
|
|
|
f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1)
|
2022-03-23 13:32:12 +08:00
|
|
|
|
|
|
|
func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
|
|
|
|
func = builder.create_function('foo', func_ty)
|
|
|
|
|
2022-03-30 16:31:03 +08:00
|
|
|
module.push_back(func)
|
|
|
|
|
2022-03-23 13:32:12 +08:00
|
|
|
# ...
|
|
|
|
entry = func.add_entry_block()
|
|
|
|
builder.set_insertion_point_to_start(entry)
|
|
|
|
offsets = builder.create_make_range(0, 128)
|
|
|
|
pid = builder.create_get_program_id(0)
|
|
|
|
_128 = builder.get_int32(128)
|
|
|
|
offset = builder.create_add(pid, _128)
|
|
|
|
offset = builder.create_broadcast(offset, [128])
|
|
|
|
offsets = builder.create_add(offset, offsets)
|
|
|
|
|
|
|
|
|
|
|
|
a_ptrs = builder.create_broadcast(entry.arg(0), [128])
|
|
|
|
b_ptrs = builder.create_broadcast(entry.arg(1), [128])
|
|
|
|
|
|
|
|
a_ptrs = builder.create_gep(a_ptrs, offsets)
|
|
|
|
b_ptrs = builder.create_gep(b_ptrs, offsets)
|
|
|
|
|
|
|
|
a = builder.create_load(a_ptrs)
|
|
|
|
b = builder.create_load(b_ptrs)
|
|
|
|
|
|
|
|
c = builder.create_fadd(a, b)
|
|
|
|
# c.set_attr("ieee_rounding", builder.get_bool_attr(True))
|
|
|
|
|
|
|
|
c_ptrs = builder.create_broadcast(entry.arg(2), [128])
|
|
|
|
c_ptrs = builder.create_gep(c_ptrs, offsets)
|
|
|
|
builder.create_store(c_ptrs, c)
|
|
|
|
|
2022-03-30 13:32:52 +08:00
|
|
|
# func.dump()
|
2022-03-23 13:32:12 +08:00
|
|
|
|
2022-03-30 16:31:03 +08:00
|
|
|
|
2022-03-30 13:32:52 +08:00
|
|
|
module.dump()
|