aco, radv: vectorize f2f16 if rounding mode is rtz
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25952>
This commit is contained in:
parent
b781bd478c
commit
ab87831ae8
2 changed files with 40 additions and 0 deletions
|
|
@ -1366,6 +1366,28 @@ usub32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
|
|||
return dst.getTemp();
|
||||
}
|
||||
|
||||
void
|
||||
emit_vec2_f2f16(isel_context* ctx, nir_alu_instr* instr, Temp dst)
|
||||
{
|
||||
Builder bld(ctx->program, ctx->block);
|
||||
Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa);
|
||||
RegClass rc = RegClass(src.regClass().type(), instr->src[0].src.ssa->bit_size / 32);
|
||||
Temp src0 = emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc);
|
||||
Temp src1 = emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc);
|
||||
|
||||
if (instr->src[0].src.ssa->bit_size == 64) {
|
||||
src0 = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src0);
|
||||
src1 = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src1);
|
||||
}
|
||||
|
||||
src1 = as_vgpr(ctx, src1);
|
||||
if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9)
|
||||
bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src0, src1);
|
||||
else
|
||||
bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);
|
||||
emit_split_vector(ctx, dst, 2);
|
||||
}
|
||||
|
||||
void
|
||||
visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
{
|
||||
|
|
@ -2892,6 +2914,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
|||
}
|
||||
case nir_op_f2f16:
|
||||
case nir_op_f2f16_rtne: {
|
||||
if (instr->def.num_components == 2) {
|
||||
/* Vectorizing f2f16 is only possible with rtz. */
|
||||
assert(instr->op != nir_op_f2f16_rtne);
|
||||
assert(ctx->block->fp_mode.round16_64 == fp_round_tz ||
|
||||
!ctx->block->fp_mode.care_about_round16_64);
|
||||
emit_vec2_f2f16(ctx, instr, dst);
|
||||
break;
|
||||
}
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 64)
|
||||
src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
|
||||
|
|
@ -2905,6 +2935,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
|||
break;
|
||||
}
|
||||
case nir_op_f2f16_rtz: {
|
||||
if (instr->def.num_components == 2) {
|
||||
emit_vec2_f2f16(ctx, instr, dst);
|
||||
break;
|
||||
}
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
if (instr->src[0].src.ssa->bit_size == 64)
|
||||
src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
|
||||
|
|
|
|||
|
|
@ -483,6 +483,11 @@ opt_vectorize_callback(const nir_instr *instr, const void *_)
|
|||
return 1;
|
||||
|
||||
switch (alu->op) {
|
||||
case nir_op_f2f16: {
|
||||
nir_shader *shader = nir_cf_node_get_function(&instr->block->cf_node)->function->shader;
|
||||
unsigned execution_mode = shader->info.float_controls_execution_mode;
|
||||
return nir_is_rounding_mode_rtz(execution_mode, 16) ? 2 : 1;
|
||||
}
|
||||
case nir_op_fadd:
|
||||
case nir_op_fsub:
|
||||
case nir_op_fmul:
|
||||
|
|
@ -494,6 +499,7 @@ opt_vectorize_callback(const nir_instr *instr, const void *_)
|
|||
case nir_op_fsat:
|
||||
case nir_op_fmin:
|
||||
case nir_op_fmax:
|
||||
case nir_op_f2f16_rtz:
|
||||
case nir_op_iabs:
|
||||
case nir_op_iadd:
|
||||
case nir_op_iadd_sat:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue