radv: Use correct writemask for cooperative matrix ordering.
This is not expected to fix anything externally visible, but it
avoids passing an invalid write mask (a hard-coded 0xffff) to nir_store_deref
when the stored vector is not 16 elements long (e.g. the C/result matrix).
Fixes: 9df4703fbb ("radv: Add cooperative matrix lowering.")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26768>
This commit is contained in:
parent
16af090908
commit
07ad6fd34a
1 changed file with 12 additions and 8 deletions
|
|
@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
|
||||
nir_def *elem = intr->src[1].ssa;
|
||||
nir_def *r = nir_vector_insert(&b, src1, elem, index);
|
||||
nir_store_deref(&b, dst_deref, r, 0xffff);
|
||||
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
|
||||
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
|
||||
|
||||
nir_store_deref(&b, dst_deref, r, 0xffff);
|
||||
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
}
|
||||
|
||||
nir_def *mat = nir_vec(&b, vars, length);
|
||||
nir_store_deref(&b, dst_deref, mat, 0xffff);
|
||||
nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
|
||||
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
|
||||
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
ret = nir_vec(&b, components, ret->num_components * 2);
|
||||
}
|
||||
|
||||
nir_store_deref(&b, dst_deref, ret, 0xffff);
|
||||
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
|
||||
nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_bitcast: {
|
||||
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff);
|
||||
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
|
||||
nir_component_mask(src1->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue