diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 2c367ab1d6e..71ece847a63 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -5317,6 +5317,7 @@ bool nir_scale_fdiv(nir_shader *shader);
 
 bool nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *data);
 bool nir_lower_alu_width(nir_shader *shader, nir_vectorize_cb cb, const void *data);
+bool nir_lower_alu_vec8_16_srcs(nir_shader *shader);
 bool nir_lower_bool_to_bitsize(nir_shader *shader);
 bool nir_lower_bool_to_float(nir_shader *shader, bool has_fcsel_ne);
 bool nir_lower_bool_to_int32(nir_shader *shader);
diff --git a/src/compiler/nir/nir_lower_alu_width.c b/src/compiler/nir/nir_lower_alu_width.c
index 2704e299fb5..0d3f4e9feb3 100644
--- a/src/compiler/nir/nir_lower_alu_width.c
+++ b/src/compiler/nir/nir_lower_alu_width.c
@@ -454,3 +454,67 @@ nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *
 
    return nir_lower_alu_width(shader, cb ? scalar_cb : NULL, &data);
 }
+
+/*
+ * Rewrites every per-component ALU source that reads from an SSA value with
+ * 8 or more components so that it instead reads, with an identity swizzle,
+ * from a freshly built vector exactly as wide as the ALU destination.
+ *
+ * Sources with a fixed input size (info->input_sizes[i] != 0) are left
+ * alone.  Returns true if any source was rewritten.
+ */
+static bool
+lower_alu_vec8_16_src(nir_builder *b, nir_instr *instr, void *_data)
+{
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   const nir_op_info *info = &nir_op_infos[alu->op];
+   const unsigned num_components = alu->def.num_components;
+
+   bool changed = false;
+   b->cursor = nir_before_instr(instr);
+   for (int i = 0; i < info->num_inputs; i++) {
+      nir_alu_src *alu_src = &alu->src[i];
+      if (alu_src->src.ssa->num_components < 8 || info->input_sizes[i])
+         continue;
+
+      changed = true;
+
+      /* Only the swizzle is modified below, so look the constant up once. */
+      nir_const_value *const_val = nir_src_as_const_value(alu_src->src);
+      const unsigned bit_size = alu_src->src.ssa->bit_size;
+
+      /* Sized for the widest possible destination: nothing in this pass
+       * checks that the ALU has already been narrowed to vec4 or less.
+       */
+      nir_def *comps[NIR_MAX_VEC_COMPONENTS];
+      for (unsigned c = 0; c < num_components; c++) {
+         unsigned swizzle = alu_src->swizzle[c];
+         alu_src->swizzle[c] = c;
+
+         if (const_val) {
+            comps[c] = nir_build_imm(b, 1, bit_size, &const_val[swizzle]);
+         } else {
+            comps[c] = nir_swizzle(b, alu_src->src.ssa, &swizzle, 1);
+         }
+      }
+      nir_def *src = nir_vec(b, comps, num_components);
+      nir_src_rewrite(&alu_src->src, src);
+   }
+
+   return changed;
+}
+
+/*
+ * Applies lower_alu_vec8_16_src to the shader through
+ * nir_shader_instructions_pass() and returns that call's result.
+ */
+bool
+nir_lower_alu_vec8_16_srcs(nir_shader *shader)
+{
+   return nir_shader_instructions_pass(shader, lower_alu_vec8_16_src,
+                                       nir_metadata_block_index | nir_metadata_dominance,
+                                       NULL);
+}