diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index 6d31342376b..8ed937baaed 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -678,7 +678,7 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); nir_def *scalar = intrin->src[1].ssa; nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]); - nir_def *dst_index = intrin->src[3].ssa; + const nir_src dst_index = intrin->src[3]; const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice); @@ -691,24 +691,34 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) const unsigned packing_factor = get_packing_factor(desc, dst_slice->type); const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - nir_def *slice_index = nir_udiv_imm(b, dst_index, packing_factor); - nir_def *vector_index = nir_umod_imm(b, dst_index, packing_factor); + nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor); + nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor); nir_def *results[NIR_MAX_VEC_COMPONENTS]; + const int slice_constant_index = nir_src_is_const(dst_index) + ? nir_src_as_uint(dst_index) / packing_factor + : -1; + for (unsigned i = 0; i < num_components; i++) { nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i); nir_def *insert; - if (packing_factor == 1) { - insert = scalar; - } else { - nir_def *unpacked = nir_unpack_bits(b, val, bits); - nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index); + if (slice_constant_index < 0 || slice_constant_index == i) { + if (packing_factor == 1) { + insert = scalar; + } else { + nir_def *unpacked = nir_unpack_bits(b, val, bits); + nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index); - insert = nir_pack_bits(b, v, bits * packing_factor); + insert = nir_pack_bits(b, v, bits * packing_factor); + } + } else { + insert = val; } - results[i] = nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val); + results[i] = slice_constant_index < 0 + ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val) + : insert; } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),