From 07ad6fd34a6ed32b74a3f9697545261a3fd84de2 Mon Sep 17 00:00:00 2001 From: Bas Nieuwenhuizen Date: Wed, 20 Dec 2023 00:19:55 +0100 Subject: [PATCH] radv: Use correct writemask for cooperative matrix ordering. Not expecting this to actually fix anything externally visible, but reduces some invalid usage when the resulting vector is not 16 elements long (e.g. the C/result matrix). Fixes: 9df4703fbb5 ("radv: Add cooperative matrix lowering.") Part-of: --- .../nir/radv_nir_lower_cooperative_matrix.c | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index d81231b0137..e882100e141 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *elem = intr->src[1].ssa; nir_def *r = nir_vector_insert(&b, src1, elem, index); - nir_store_deref(&b, dst_deref, r, 0xffff); + nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components)); nir_instr_remove(instr); progress = true; break; @@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size)); - nir_store_deref(&b, dst_deref, r, 0xffff); + nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components)); nir_instr_remove(instr); progress = true; break; @@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) } nir_def *mat = nir_vec(&b, vars, length); - nir_store_deref(&b, dst_deref, mat, 0xffff); + nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components)); nir_instr_remove(instr); progress = true; break; @@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr), .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr)); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) ret = nir_vec(&b, components, ret->num_components * 2); } - nir_store_deref(&b, dst_deref, ret, 0xffff); + nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, src2); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; } case nir_intrinsic_cmat_bitcast: { nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, + nir_component_mask(src1->num_components)); nir_instr_remove(instr); progress = true; break;