diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 5148295dddc..d94935f7296 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -638,6 +638,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st .lower_quad_broadcast_dynamic_to_const = gfx7minus, .lower_shuffle_to_swizzle_amd = 1, .lower_ballot_bit_count_to_mbcnt_amd = 1, + .lower_inverse_ballot = 1, }); NIR_PASS(_, nir, nir_lower_load_const_to_scalar); diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 6bc9f85b21b..119b9efc105 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5369,6 +5369,7 @@ typedef struct nir_lower_subgroups_options { bool lower_read_invocation_to_cond : 1; bool lower_rotate_to_shuffle : 1; bool lower_ballot_bit_count_to_mbcnt_amd : 1; + bool lower_inverse_ballot : 1; } nir_lower_subgroups_options; bool nir_lower_subgroups(nir_shader *shader, diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c index 4e39db80ed7..e32c75a6bf2 100644 --- a/src/compiler/nir/nir_divergence_analysis.c +++ b/src/compiler/nir/nir_divergence_analysis.c @@ -478,6 +478,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr) break; /* Intrinsics which are always divergent */ + case nir_intrinsic_inverse_ballot: case nir_intrinsic_load_color0: case nir_intrinsic_load_color1: case nir_intrinsic_load_param: diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 05370121c92..576f9846ebb 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -446,9 +446,11 @@ intrinsic("read_invocation_cond_ir3", src_comp=[0, 1], dest_comp=0, flags=[CAN_E # # OpGroupNonUniformElect # OpSubgroupFirstInvocationKHR +# OpGroupNonUniformInverseBallot intrinsic("elect", dest_comp=1, flags=[CAN_ELIMINATE]) intrinsic("first_invocation", dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE]) intrinsic("last_invocation", dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE]) +intrinsic("inverse_ballot", src_comp=[0], dest_comp=1, flags=[CAN_ELIMINATE]) barrier("begin_invocation_interlock") barrier("end_invocation_interlock") diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 2177634bbed..571f127c5b4 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -683,6 +683,16 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) intrin->def.bit_size); } + case nir_intrinsic_inverse_ballot: + if (options->lower_inverse_ballot) { + return nir_ballot_bitfield_extract(b, 1, intrin->src[0].ssa, + nir_load_subgroup_invocation(b)); + } else if (intrin->src[0].ssa->num_components != options->ballot_components || + intrin->src[0].ssa->bit_size != options->ballot_bit_size) { + return nir_inverse_ballot(b, 1, ballot_type_to_uint(b, intrin->src[0].ssa, options)); + } + break; + case nir_intrinsic_ballot_bitfield_extract: case nir_intrinsic_ballot_bit_count_reduce: case nir_intrinsic_ballot_find_lsb: diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index ea079c7fd7a..de18259a199 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -103,22 +103,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, } case SpvOpGroupNonUniformInverseBallot: { - /* This one is just a BallotBitfieldExtract with subgroup invocation. - * We could add a NIR intrinsic but it's easier to just lower it on the - * spot. - */ - nir_intrinsic_instr *intrin = - nir_intrinsic_instr_create(b->nb.shader, - nir_intrinsic_ballot_bitfield_extract); - - intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4])); - intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb)); - - nir_def_init_for_type(&intrin->instr, &intrin->def, - dest_type->type); - nir_builder_instr_insert(&b->nb, &intrin->instr); - - vtn_push_nir_ssa(b, w[2], &intrin->def); + nir_def *dest = nir_inverse_ballot(&b->nb, 1, vtn_get_nir_ssa(b, w[4])); + vtn_push_nir_ssa(b, w[2], dest); break; } diff --git a/src/freedreno/ir3/ir3_nir.c b/src/freedreno/ir3/ir3_nir.c index d9ec05a0210..567dbc7ba49 100644 --- a/src/freedreno/ir3/ir3_nir.c +++ b/src/freedreno/ir3/ir3_nir.c @@ -564,6 +564,7 @@ ir3_nir_post_finalize(struct ir3_shader *shader) .lower_read_invocation_to_cond = true, .lower_shuffle = true, .lower_relative_shuffle = true, + .lower_inverse_ballot = true, }; if (!((s->info.stage == MESA_SHADER_COMPUTE) || diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c index ce49783da8b..a753844f868 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c @@ -2996,6 +2996,7 @@ lp_build_opt_nir(struct nir_shader *nir) .lower_to_scalar = true, .lower_subgroup_masks = true, .lower_relative_shuffle = true, + .lower_inverse_ballot = true, }; NIR_PASS(progress, nir, nir_lower_subgroups, &subgroups_options); } while (progress); diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c index 11d5ee40a5b..fce4a05d034 100644 --- a/src/gallium/drivers/radeonsi/si_shader_nir.c +++ b/src/gallium/drivers/radeonsi/si_shader_nir.c @@ -254,6 +254,7 @@ const nir_lower_subgroups_options si_nir_subgroups_options = { .lower_subgroup_masks = true, .lower_vote_trivial = false, .lower_vote_eq = true, + .lower_inverse_ballot = true, }; /** diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index d714dd72f19..2b7c34b2de3 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -5423,6 +5423,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir) subgroup_options.subgroup_size = 1; subgroup_options.lower_vote_trivial = true; } + subgroup_options.lower_inverse_ballot = true; NIR_PASS_V(nir, nir_lower_subgroups, &subgroup_options); } diff --git a/src/gallium/frontends/lavapipe/lvp_pipeline.c b/src/gallium/frontends/lavapipe/lvp_pipeline.c index 7942496bf70..aefd4a2ce66 100644 --- a/src/gallium/frontends/lavapipe/lvp_pipeline.c +++ b/src/gallium/frontends/lavapipe/lvp_pipeline.c @@ -401,6 +401,7 @@ lvp_shader_lower(struct lvp_device *pdevice, struct lvp_pipeline *pipeline, nir_ subgroup_opts.lower_quad = true; subgroup_opts.ballot_components = 1; subgroup_opts.ballot_bit_size = 32; + subgroup_opts.lower_inverse_ballot = true; NIR_PASS_V(nir, nir_lower_subgroups, &subgroup_opts); if (nir->info.stage == MESA_SHADER_FRAGMENT) diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index d337229eca4..5d21d0f70eb 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -1000,6 +1000,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir, .lower_relative_shuffle = true, .lower_quad_broadcast_dynamic = true, .lower_elect = true, + .lower_inverse_ballot = true, }; OPT(nir_lower_subgroups, &subgroups_options); diff --git a/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c b/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c index 8b76fd4ebdd..353a9aee729 100644 --- a/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c +++ b/src/microsoft/spirv_to_dxil/dxil_spirv_nir.c @@ -971,6 +971,7 @@ dxil_spirv_nir_passes(nir_shader *nir, .lower_subgroup_masks = true, .lower_to_scalar = true, .lower_relative_shuffle = true, + .lower_inverse_ballot = true, }; if (nir->info.stage != MESA_SHADER_FRAGMENT && nir->info.stage != MESA_SHADER_COMPUTE) diff --git a/src/nouveau/codegen/nv50_ir_from_nir.cpp b/src/nouveau/codegen/nv50_ir_from_nir.cpp index 077a9d7dcf6..9bdc45b9cce 100644 --- a/src/nouveau/codegen/nv50_ir_from_nir.cpp +++ b/src/nouveau/codegen/nv50_ir_from_nir.cpp @@ -3242,6 +3242,7 @@ Converter::run() subgroup_options.ballot_bit_size = 32; subgroup_options.ballot_components = 1; subgroup_options.lower_elect = true; + subgroup_options.lower_inverse_ballot = true; unsigned lower_flrp = (nir->options->lower_flrp16 ? 16 : 0) | (nir->options->lower_flrp32 ? 32 : 0) |