radv: Use correct writemask for cooperative matrix lowering.

This is not expected to fix anything externally visible, but it
removes invalid usage of a hard-coded 0xffff write mask when the
resulting vector is not 16 elements long (e.g. the C/result matrix).

Fixes: 9df4703fbb ("radv: Add cooperative matrix lowering.")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26768>
This commit is contained in:
Bas Nieuwenhuizen 2023-12-20 00:19:55 +01:00 committed by Marge Bot
parent 16af090908
commit 07ad6fd34a

View file

@@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *elem = intr->src[1].ssa;
nir_def *r = nir_vector_insert(&b, src1, elem, index);
nir_store_deref(&b, dst_deref, r, 0xffff);
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
nir_store_deref(&b, dst_deref, r, 0xffff);
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
}
nir_def *mat = nir_vec(&b, vars, length);
nir_store_deref(&b, dst_deref, mat, 0xffff);
nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
ret = nir_vec(&b, components, ret->num_components * 2);
}
nir_store_deref(&b, dst_deref, ret, 0xffff);
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
}
case nir_intrinsic_cmat_bitcast: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
nir_component_mask(src1->num_components));
nir_instr_remove(instr);
progress = true;
break;