@@ -99,6 +99,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i16(...) builder_->getInt16(__VA_ARGS__)
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
@@ -854,6 +855,234 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
  }
}

std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int16_to_float16x8(
    Value *in0, Value *scale_x512, Value *shift
){
  /* unpacking 8 int2s packed into an int16 to 8 float16s
   * the algorithm is similar to
   * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563
   */
  Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
  InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
    "{"
    ".reg .b32 a<2>, b<4>; \n\t"        // input is 0xab,cd,ef,gh,ab,cd,ef,gh; each of a, b, etc. occupies two bits
    "and.b32 a0, 0x30300303, $4; \n\t"  // set a0 to 0x0b,00,0f,00,00,0d,00,0h
    "and.b32 a1, 0xc0c00c0c, $4; \n\t"  // set a1 to 0xa0,00,e0,00,00,c0,00,g0
    "prmt.b32 b0, 0, a0, 0x0504; \n\t"  // set b0 to 0x00,00,00,0d,00,00,00,0h
    "prmt.b32 b1, 0, a1, 0x0504; \n\t"  // set b1 to 0x00,00,00,c0,00,00,00,g0
    "prmt.b32 b2, 0, a0, 0x0706; \n\t"  // set b2 to 0x00,00,0b,00,00,00,0f,00
    "prmt.b32 b3, 0, a1, 0x0706; \n\t"  // set b3 to 0x00,00,a0,00,00,00,e0,00
    "mov.b32 a0, 0x78007800; \n\t"      // a0 = 32768
    "mov.b32 a1, 0x70007000; \n\t"      // a1 = 8192
    "mul.f16x2 b0, b0, a0; \n\t"        // b0 = b0 * 32768.
    "mul.f16x2 b1, b1, a1; \n\t"        // b1 = b1 * 8192.
    "mov.b32 a0, 0x68006800; \n\t"      // a0 = 2048
    "mov.b32 a1, 0x60006000; \n\t"      // a1 = 512
    "mul.f16x2 b2, b2, a0; \n\t"        // b2 = b2 * 2048.
    "mul.f16x2 b3, b3, a1; \n\t"        // b3 = b3 * 512.
    "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
    "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
    "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
    "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
    "}", "=r,=r,=r,=r,r,r,r", false);

  Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2));
  packed_in = insert_elt(packed_in, in0, (int)0);
  packed_in = insert_elt(packed_in, in0, (int)1);
  Value *in = bit_cast(packed_in, i32_ty);

  Value *ret = call(ptx, {in, scale_x512, shift});
  Value *packed_ret0 = extract_val(ret, {0});
  Value *packed_ret1 = extract_val(ret, {1});
  Value *packed_ret2 = extract_val(ret, {2});
  Value *packed_ret3 = extract_val(ret, {3});
  Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
  Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
  Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f
  Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e
  Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d
  Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c
  Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
  Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
  return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}
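// A minimal scalar sketch of what the PTX above computes for one int16 holding eight
// 2-bit codes (illustrative only; float stands in for f16 and the helper name
// dequant_int2x8_ref is not part of the generator). Each masked field, reinterpreted
// as an f16 bit pattern, is a subnormal worth q * 2^(pos-24); multiplying by the
// matching power-of-two constant (0x7800 = 2^15, 0x7000 = 2^13, 0x6800 = 2^11,
// 0x6000 = 2^9) brings every lane to q/512, and the final fma against
// scale_x512 = scale * 512 and shift therefore yields q * scale + shift.
#include <array>
#include <cstdint>

static std::array<float, 8> dequant_int2x8_ref(uint16_t packed, float scale, float shift) {
  std::array<float, 8> out;
  for (int i = 0; i < 8; ++i) {
    int q = (packed >> (2 * i)) & 0x3;  // i-th 2-bit code, in the h..a order of ret0..ret7
    out[i] = float(q) * scale + shift;  // net effect of and + prmt + mul.f16x2 + fma.rn.f16x2
  }
  return out;
}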

std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int32_to_float16x8(
    Value *in0, Value *scale_x512, Value *shift
){
  /* unpacking 8 int4s packed into an int32 to 8 float16s
   * the algorithm is similar to
   * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619
   */
  Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
  InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
    "{"
    ".reg .b32 a<2>, b<4>; \n\t"
    "and.b32 a0, 0x0f0f0f0f, $4; \n\t"  // if input is 0xabcdefgh, set a0 to 0x0b0d0f0h
    "and.b32 a1, 0xf0f0f0f0, $4; \n\t"  // if input is 0xabcdefgh, set a1 to 0xa0c0e0g0
    "prmt.b32 b0, 0, a0, 0x0504; \n\t"  // set b0 to 0x000f000h
    "prmt.b32 b1, 0, a1, 0x0504; \n\t"  // set b1 to 0x00e000g0
    "prmt.b32 b2, 0, a0, 0x0706; \n\t"  // set b2 to 0x000b000d
    "prmt.b32 b3, 0, a1, 0x0706; \n\t"  // set b3 to 0x00a000c0
    "mov.b32 a0, 0x78007800; \n\t"      // a0 = 32768
    "mov.b32 a1, 0x68006800; \n\t"      // a1 = 2048
    "mul.f16x2 b0, b0, a0; \n\t"        // b0 = b0 * 32768.
    "mul.f16x2 b1, b1, a1; \n\t"        // b1 = b1 * 2048.
    "mul.f16x2 b2, b2, a0; \n\t"        // b2 = b2 * 32768.
    "mul.f16x2 b3, b3, a1; \n\t"        // b3 = b3 * 2048.
    "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
    "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
    "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
    "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
    "}", "=r,=r,=r,=r,r,r,r", false);

  Value *ret = call(ptx, {in0, scale_x512, shift});
  Value *packed_ret0 = extract_val(ret, {0});
  Value *packed_ret1 = extract_val(ret, {1});
  Value *packed_ret2 = extract_val(ret, {2});
  Value *packed_ret3 = extract_val(ret, {3});
  Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
  Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
  Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f
  Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e
  Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d
  Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c
  Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
  Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
  return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}
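// A small reference model of the byte-select behavior of "prmt.b32 d, a, b, sel" that
// the unpackers rely on (a sketch under the assumption that only the default mode is
// needed, which holds here since every selector nibble used above is in the range
// 0..7). Source bytes 0..3 come from operand a, bytes 4..7 from operand b, and each
// selector nibble picks one byte of the result, low nibble first.
#include <cstdint>

static uint32_t prmt_b32_ref(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t src = (uint64_t(b) << 32) | a;   // combined source, bytes 0..7
  uint32_t d = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (sel >> (4 * i)) & 0x7;  // byte index; the sign-replicate mode bit is not modeled
    d |= uint32_t((src >> (8 * idx)) & 0xff) << (8 * i);
  }
  return d;
}
// For example, prmt_b32_ref(0, a0, 0x0504) places bytes 0 and 1 of a0 in the low byte
// of each 16-bit half of the result, which is exactly how b0 is formed in the asm above.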

std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){
  /* unpacking 4 int8s packed into an int32 to 4 fp16s
   * the algorithm is similar to
   * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646
   */
  Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
  InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
    "{"
    ".reg .b32 a, b<2>; \n\t"
    "prmt.b32 b0, 0, $2, 0x0504; \n\t"  // if input is 0xabcdefgh, set b0 to 0x00ef00gh
    "prmt.b32 b1, 0, $2, 0x0706; \n\t"  // if input is 0xabcdefgh, set b1 to 0x00ab00cd
    "mov.b32 a, 0x78007800; \n\t"       // a = 32768
    "mul.f16x2 b0, b0, a; \n\t"         // b0 = b0 * 32768.
    "mul.f16x2 b1, b1, a; \n\t"         // b1 = b1 * 32768.
    "fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift.
    "fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift.
    "}", "=r,=r,r,r,r", false);

  Value *ret = call(ptx, {in0, scale_x512, shift});
  Value *packed_ret0 = extract_val(ret, {0});
  Value *packed_ret1 = extract_val(ret, {1});
  Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh
  Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef
  Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd
  Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab
  return std::make_tuple(ret0, ret1, ret2, ret3);
}
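// A compile-time sanity sketch (illustrative; not part of the patch) of the bookkeeping
// shared by the three unpackers: a field whose lowest bit sits at position p inside an
// f16 lane reads as q * 2^(p-24), so multiplying by 2^e leaves q * 2^(p-24+e). Every
// (p, e) pair used above lands at 2^-9, i.e. q/512, which the fma then cancels against
// scale_x512 = scale * 512.
static_assert(0 - 24 + 15 == -9, "fields at bit 0 scaled by 0x7800 (2^15): int2, int4, int8");
static_assert(2 - 24 + 13 == -9, "fields at bit 2 scaled by 0x7000 (2^13): int2");
static_assert(4 - 24 + 11 == -9, "fields at bit 4 scaled by 0x6800 (2^11): int2, int4");
static_assert(6 - 24 +  9 == -9, "fields at bit 6 scaled by 0x6000 (2^9): int2");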

std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){
  Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty));
  Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2));
  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0);
  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1);
  p_scale_x512 = bit_cast(p_scale_x512, i32_ty);

  Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2));
  p_shift = insert_elt(p_shift, shift, (int)0);
  p_shift = insert_elt(p_shift, shift, (int)1);
  p_shift = bit_cast(p_shift, i32_ty);

  return std::make_tuple(p_scale_x512, p_shift);
}
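// Why 0x6000 and why the duplication (a brief sketch; broadcast_f16x2 is an illustrative
// helper, not generator code): 0x6000 has fp16 exponent bits 11000 (= 24) and a zero
// mantissa, so reinterpreted as an f16 it equals 2^(24-15) = 512. prepare_scale_shift
// therefore hands the asm scale * 512, which cancels the q/512 produced by the
// unpackers, and it broadcasts both scalars into the two halves of an i32 so the f16x2
// instructions apply the same scale and shift to each packed lane.
#include <cstdint>

static_assert((0x6000 >> 10) - 15 == 9, "0x6000 encodes 2^9 = 512.0 in fp16");

static uint32_t broadcast_f16x2(uint16_t h) {
  return (uint32_t(h) << 16) | h;  // same 16-bit pattern in both lanes of the b32 operand
}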

/**
 * \brief Code Generation for `dequantize`
 */
void generator::visit_dequantize_inst(ir::dequantize_inst* x) {
  ir::value *op = x->get_operand(0);

  auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits();

  auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
  auto op_last_dim = (op->get_type()->get_block_shapes()).back();

  auto x_idxs = idxs_.at(x);
  auto op_idxs = idxs_.at(op);

  ir::value *scale = x->get_operand(1);
  ir::value *shift = x->get_operand(2);

  Value *p_scale_x512, *p_shift;
  std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]);

  int ld = layouts_->get(x)->get_order(0);
  int contiguous = layouts_->get(x)->to_scanline()->nts(ld);

  int op_ld = layouts_->get(op)->get_order(0);
  int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld);

  std::string err_msg;
  err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): "
            + std::to_string(x_idxs.size()) + "; op_idxs.size(): "
            + std::to_string(op_idxs.size()) + "; contiguous: "
            + std::to_string(contiguous) + "; op_contiguous: "
            + std::to_string(op_contiguous) + ". If the condition "
            "is not met, please try adjusting block_size, num_warps, or "
            "using tl.multiple_of to hint the input/output ptr address.";

  if (ret_last_dim == 8 * op_last_dim) {
    if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) {
      throw std::runtime_error(err_msg);
    }

    auto cvt = [&](
      Value* a, Value* scale, Value* shift
    ){
      if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s
        return int16_to_float16x8(a, scale, shift);
      } else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s
        return int32_to_float16x8(a, scale, shift);
      } else {
        throw std::runtime_error("unsupported conversion");
      }
    };

    for(size_t j = 0; j < op_idxs.size(); j++){
      size_t i = j * 8;
      std::tie(vals_[x][x_idxs[i+0]],
               vals_[x][x_idxs[i+1]],
               vals_[x][x_idxs[i+2]],
               vals_[x][x_idxs[i+3]],
               vals_[x][x_idxs[i+4]],
               vals_[x][x_idxs[i+5]],
               vals_[x][x_idxs[i+6]],
               vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
    }
  } else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s
    if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) {
      throw std::runtime_error(err_msg);
    }

    auto cvt = [&](Value* a, Value* scale, Value* shift){
      return int32_to_float16x4(a, scale, shift);
    };

    for(size_t j = 0; j < op_idxs.size(); j++){
      size_t i = j * 4;
      std::tie(vals_[x][x_idxs[i+0]],
               vals_[x][x_idxs[i+1]],
               vals_[x][x_idxs[i+2]],
               vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
    }
  } else {
    throw std::runtime_error("unsupported dequantization");
  }
  return;
}
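// A condensed view of the dispatch above (illustrative; the enum and function are not
// part of the patch): the bit width of the packed scalar plus the ratio between the
// output and input inner dimensions select the unpacker, and every other combination
// takes one of the runtime_error paths.
enum class dequant_kind { int2_x8, int4_x8, int8_x4, unsupported };

static dequant_kind classify_dequant(unsigned src_bits, unsigned ret_over_op_ratio) {
  if (ret_over_op_ratio == 8 && src_bits == 16) return dequant_kind::int2_x8;  // int16_to_float16x8
  if (ret_over_op_ratio == 8 && src_bits == 32) return dequant_kind::int4_x8;  // int32_to_float16x8
  if (ret_over_op_ratio == 4 && src_bits == 32) return dequant_kind::int8_x4;  // int32_to_float16x4
  return dequant_kind::unsupported;
}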

/**
 * \brief Code Generation for `return`
 */
@@ -907,7 +1136,7 @@ void generator::visit_load_inst(ir::load_inst* x){
    vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
    // TODO: generalize
    is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
                       (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
    if(is_mma_first_row)
      vec = std::min<size_t>(2, aln);
@@ -1009,7 +1238,7 @@ void generator::visit_load_inst(ir::load_inst* x){
    std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
    for(Value *v: others)
      arg_tys.push_back(v->getType());
    if (has_l2_evict_policy)
      arg_tys.push_back(i64_ty);
    FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
    // ---
@@ -1025,7 +1254,7 @@ void generator::visit_load_inst(ir::load_inst* x){
      asm_cstrt += ",";
      asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
    }
    if (has_l2_evict_policy)
      asm_cstrt += ",l";
    // ---
    // finally call inline ASM
@@ -1036,8 +1265,8 @@ void generator::visit_load_inst(ir::load_inst* x){
      args.push_back(v);
    if (has_l2_evict_policy)
      args.push_back(policies_.at(x->get_eviction_policy()));

    Value *_ret = call(inlineAsm, args);
    // if(!op->get_type()->is_block_ty()){
    //   Value* cond = icmp_eq(tid, i32(0));
@@ -1050,7 +1279,7 @@ void generator::visit_load_inst(ir::load_inst* x){
    //   _ret = load(shptr);
    //   add_barrier();
    // }

    // ---
    // extract and store return values
    // ---
@@ -1104,7 +1333,7 @@ void generator::visit_store_inst(ir::store_inst * x){
    // vec = std::min(nts, aln);
    vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
    // TODO: generalize
    bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
                            (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
    if(is_mma_first_row)
      vec = std::min<size_t>(2, aln);
@@ -1166,7 +1395,7 @@ void generator::visit_store_inst(ir::store_inst * x){
    std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
    for(int ii = 0; ii < n_words; ii++)
      arg_tys.push_back(val_arg_ty);
    if (has_l2_evict_policy)
      arg_tys.push_back(i64_ty);
    FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
    // ---
@@ -1177,7 +1406,7 @@ void generator::visit_store_inst(ir::store_inst * x){
      asm_cstrt += ",";
      asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
    }
    if (has_l2_evict_policy)
      asm_cstrt += ",l";
    // ---
    // finally call inline ASM
@@ -1817,13 +2046,13 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
namespace {
class mma16816_smem_loader {
public:
  mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
                       std::vector<unsigned> tile_shape,
                       std::vector<int> instr_shape, std::vector<int> mat_shape,
                       int per_phase, int max_phase, int dtsize, Builder *builder,
                       adder add, multiplier mul, geper gep)
    : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
      instr_shape_(instr_shape), mat_shape_(mat_shape),
      per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
      add(add), mul(mul), gep(gep) {
    // compute compile-time constant variables & types
@@ -1837,7 +2066,7 @@ public:
    need_trans_ = k_order_ != order_[0];
    can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);

    // we need more pointers at the fast-changing axis,
    if (can_use_ldmatrix_)
      num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
    else // warning: this only works for tf32 & need transpose
@@ -1873,7 +2102,7 @@ public:
    Value *s0 = urem(s, i32(2));
    Value *s1 = udiv(s, i32(2));

    // We use different orders for a & b for better performance.
    Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
    Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
    mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
@@ -1884,7 +2113,7 @@ public:
    Value *s_mat_off = mat_off[order_[1]];
    // offset inside a matrix
    Value *s_off_in_mat = c;

    std::vector<Value*> offs(num_ptr_);
    Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
    // pre-compute strided offset
@@ -1898,7 +2127,7 @@ public:
    } else if (dtsize_ == 4 && need_trans_) {
      // load tf32 matrices with lds32
      Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
      Value *s_off_in_mat = urem(lane, i32(4)); //

      Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
      std::vector<Value*> offs(num_ptr_);
@@ -1945,7 +2174,7 @@ public:
      Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
                             mul(nk_mat_arr, i32(mat_arr_stride_)));
      Value *s_mat_off = k_mat_arr; // always 0?

      for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
        for (int elem_off = 0; elem_off < 4; ++elem_off) {
          int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
@@ -1971,10 +2200,10 @@ public:
      throw std::runtime_error("invalid smem load config");
    }

  std::tuple<Value*, Value*, Value*, Value*>
  load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
          Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
          FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
          std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
    assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
    int mat_idx[2] = {mat0, mat1};
@@ -2006,7 +2235,7 @@ public:
      std::string trans = need_trans_ ? ".trans" : "";
      // the offset (in byte) on the strided axis is a constant
      int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
      InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
                                        "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
                                        "{$0, $1, $2, $3}, "
                                        "[$4 + " + std::to_string(s_offset) + "];",
@@ -2015,7 +2244,7 @@ public:
      res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
      if (k == 0 && inc == 1 && is_prefetch)
        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
      return {extract_val(res_v4, std::vector<unsigned>{0}),
              extract_val(res_v4, std::vector<unsigned>{1}),
              extract_val(res_v4, std::vector<unsigned>{2}),
              extract_val(res_v4, std::vector<unsigned>{3})};
@@ -2062,13 +2291,13 @@ public:
      Value *i32_elems[4];
      for (int i=0; i<4; ++i)
        i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));

      Value *elem00, *elem01, *elem02, *elem03;
      Value *elem10, *elem11, *elem12, *elem13;
      Value *elem20, *elem21, *elem22, *elem23;
      Value *elem30, *elem31, *elem32, *elem33;
      Value *i8_elems[4*4];
      if (k_order_ == 1) { //
        i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
        i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
        i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
@@ -2155,7 +2384,7 @@ private:
  int s_mat_stride_;
  // stride when moving to next not-k mat
  int warp_off_stride_;
  int mat_arr_stride_; // matrix arrangement (inside a load) stride
  bool need_trans_, can_use_ldmatrix_;
  int num_ptr_;
@@ -2232,7 +2461,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
    mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
    smem_ptr_ty = ptr_ty(f16_ty, 3);
    ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
    phi_ty = fp16x2_ty;
  } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
    mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
    smem_ptr_ty = ptr_ty(bf16_ty, 3);
@@ -2303,8 +2532,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
  if(is_a_shared) {
    const int per_phase_a = swizzle_->get_per_phase(layout_a);
    const int max_phase_a = swizzle_->get_max_phase(layout_a);
    mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
                                  {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
                                  per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
    std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
    int num_ptr_a = a_loader.get_num_ptr();
@@ -2319,7 +2548,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
    // loading function
    load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
      auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
                                                   shared_next_ptr_[layout_a], off_a, ptrs_a,
                                                   ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
      register_lds2(ha, m, k, inc, ha0, is_prefetch);
      register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
@@ -2389,12 +2618,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
    for(int i = 0; i < num_ptr_b; i++)
      ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);

    // loading function
    std::function<void(int,int,int,bool)> load_b;
    load_b = [&](int n, int k, int inc, bool is_prefetch) {
      auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
                                                   shared_next_ptr_[layout_b], off_b, ptrs_b,
                                                   ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
      register_lds2(hb, n, k, inc, hb0, is_prefetch);
      register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
@@ -2419,7 +2648,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
        (m + 1)*cols_per_thread + (n*2 + 0),
        (m + 1)*cols_per_thread + (n*2 + 1)
      };
      Value *nc = call(mma_ty, mma_fn,
                       {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
                        hb[{n, k}], hb[{n, k+1}],
                        fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
@@ -2608,7 +2837,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
    return visit_mma884(dot, A, B, D, NK);
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
    return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
      A->get_type()->get_scalar_ty()->is_fp32_ty())
    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
  throw std::runtime_error("dot has invalid operand type");
@@ -2710,7 +2939,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
      warps_per_inner = layout->to_mma()->wpt(1);
      col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
      warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
    }
  }
  assert(warp_j != nullptr);

  // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
@@ -3367,7 +3596,7 @@ void generator::visit_constant_fp(ir::constant_fp *x){
  if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
    // highest 16 bits of fp32
    float fp32_value = x->get_value();
    uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
                         & 0xffff0000) >> 16;
    std::stringstream const_str;
    const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
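// A standalone sketch of the same fp32-to-bf16 conversion (assumptions: truncation of
// the low mantissa bits is acceptable here, as in the code above, and memcpy replaces
// the reinterpret_cast only to keep this isolated example free of strict-aliasing
// concerns).
#include <cstdint>
#include <cstring>

static uint16_t fp32_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // same bit pattern the generator reinterprets
  return uint16_t(bits >> 16);           // bf16 keeps fp32's sign, exponent, and top 7 mantissa bits
}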