gallivm: Consider the initial mask when terminating loops

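Without this, partial subgroups can lead to infinite loops: lanes that are
disabled in the initial mask but still set in the exec mask keep the loop
condition true, so the loop never terminates.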

cc: mesa-stable

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27603>
(cherry picked from commit 4d7beb22fae3fe10aed86066ee9d2d9536625a72)
This commit is contained in:
Konstantin Seurer 2024-02-13 20:37:42 +01:00 committed by Eric Engestrom
parent 5e178a07a0
commit 6aa24ea086
5 changed files with 31 additions and 22 deletions

View file

@ -1334,7 +1334,7 @@
"description": "gallivm: Consider the initial mask when terminating loops", "description": "gallivm: Consider the initial mask when terminating loops",
"nominated": true, "nominated": true,
"nomination_type": 0, "nomination_type": 0,
"resolution": 0, "resolution": 1,
"main_sha": null, "main_sha": null,
"because_sha": null, "because_sha": null,
"notes": null "notes": null

View file

@ -27,6 +27,7 @@
**************************************************************************/ **************************************************************************/
#include "util/u_memory.h" #include "util/u_memory.h"
#include "lp_bld_const.h"
#include "lp_bld_type.h" #include "lp_bld_type.h"
#include "lp_bld_init.h" #include "lp_bld_init.h"
#include "lp_bld_flow.h" #include "lp_bld_flow.h"
@ -271,18 +272,17 @@ void lp_exec_bgnloop(struct lp_exec_mask *mask, bool load)
} }
void lp_exec_endloop(struct gallivm_state *gallivm, void lp_exec_endloop(struct gallivm_state *gallivm,
struct lp_exec_mask *mask) struct lp_exec_mask *exec_mask,
struct lp_build_mask_context *mask)
{ {
LLVMBuilderRef builder = mask->bld->gallivm->builder; LLVMBuilderRef builder = exec_mask->bld->gallivm->builder;
struct function_ctx *ctx = func_ctx(mask); struct function_ctx *ctx = func_ctx(exec_mask);
LLVMBasicBlockRef endloop; LLVMBasicBlockRef endloop;
LLVMTypeRef int_type = LLVMInt32TypeInContext(mask->bld->gallivm->context); LLVMTypeRef int_type = LLVMInt32TypeInContext(exec_mask->bld->gallivm->context);
LLVMTypeRef reg_type = LLVMIntTypeInContext(gallivm->context, LLVMTypeRef mask_type = LLVMIntTypeInContext(exec_mask->bld->gallivm->context, exec_mask->bld->type.length);
mask->bld->type.width *
mask->bld->type.length);
LLVMValueRef i1cond, i2cond, icond, limiter; LLVMValueRef i1cond, i2cond, icond, limiter;
assert(mask->break_mask); assert(exec_mask->break_mask);
assert(ctx->loop_stack_size); assert(ctx->loop_stack_size);
if (ctx->loop_stack_size > LP_MAX_TGSI_NESTING) { if (ctx->loop_stack_size > LP_MAX_TGSI_NESTING) {
@ -294,14 +294,14 @@ void lp_exec_endloop(struct gallivm_state *gallivm,
/* /*
* Restore the cont_mask, but don't pop * Restore the cont_mask, but don't pop
*/ */
mask->cont_mask = ctx->loop_stack[ctx->loop_stack_size - 1].cont_mask; exec_mask->cont_mask = ctx->loop_stack[ctx->loop_stack_size - 1].cont_mask;
lp_exec_mask_update(mask); lp_exec_mask_update(exec_mask);
/* /*
* Unlike the continue mask, the break_mask must be preserved across loop * Unlike the continue mask, the break_mask must be preserved across loop
* iterations * iterations
*/ */
LLVMBuildStore(builder, mask->break_mask, ctx->break_var); LLVMBuildStore(builder, exec_mask->break_mask, ctx->break_var);
/* Decrement the loop limiter */ /* Decrement the loop limiter */
limiter = LLVMBuildLoad2(builder, int_type, ctx->loop_limiter, ""); limiter = LLVMBuildLoad2(builder, int_type, ctx->loop_limiter, "");
@ -314,12 +314,18 @@ void lp_exec_endloop(struct gallivm_state *gallivm,
LLVMBuildStore(builder, limiter, ctx->loop_limiter); LLVMBuildStore(builder, limiter, ctx->loop_limiter);
/* i1cond = (mask != 0) */ LLVMValueRef end_mask = exec_mask->exec_mask;
if (mask)
end_mask = LLVMBuildAnd(builder, exec_mask->exec_mask, lp_build_mask_value(mask), "");
end_mask = LLVMBuildICmp(builder, LLVMIntNE, end_mask, lp_build_zero(gallivm, exec_mask->bld->type), "");
end_mask = LLVMBuildBitCast(builder, end_mask, mask_type, "");
/* i1cond = (end_mask != 0) */
i1cond = LLVMBuildICmp( i1cond = LLVMBuildICmp(
builder, builder,
LLVMIntNE, LLVMIntNE,
LLVMBuildBitCast(builder, mask->exec_mask, reg_type, ""), end_mask,
LLVMConstNull(reg_type), "i1cond"); LLVMConstNull(mask_type), "i1cond");
/* i2cond = (looplimiter > 0) */ /* i2cond = (looplimiter > 0) */
i2cond = LLVMBuildICmp( i2cond = LLVMBuildICmp(
@ -331,7 +337,7 @@ void lp_exec_endloop(struct gallivm_state *gallivm,
/* if( i1cond && i2cond ) */ /* if( i1cond && i2cond ) */
icond = LLVMBuildAnd(builder, i1cond, i2cond, ""); icond = LLVMBuildAnd(builder, i1cond, i2cond, "");
endloop = lp_build_insert_new_block(mask->bld->gallivm, "endloop"); endloop = lp_build_insert_new_block(exec_mask->bld->gallivm, "endloop");
LLVMBuildCondBr(builder, LLVMBuildCondBr(builder,
icond, ctx->loop_block, endloop); icond, ctx->loop_block, endloop);
@ -341,14 +347,14 @@ void lp_exec_endloop(struct gallivm_state *gallivm,
assert(ctx->loop_stack_size); assert(ctx->loop_stack_size);
--ctx->loop_stack_size; --ctx->loop_stack_size;
--ctx->bgnloop_stack_size; --ctx->bgnloop_stack_size;
mask->cont_mask = ctx->loop_stack[ctx->loop_stack_size].cont_mask; exec_mask->cont_mask = ctx->loop_stack[ctx->loop_stack_size].cont_mask;
mask->break_mask = ctx->loop_stack[ctx->loop_stack_size].break_mask; exec_mask->break_mask = ctx->loop_stack[ctx->loop_stack_size].break_mask;
ctx->loop_block = ctx->loop_stack[ctx->loop_stack_size].loop_block; ctx->loop_block = ctx->loop_stack[ctx->loop_stack_size].loop_block;
ctx->break_var = ctx->loop_stack[ctx->loop_stack_size].break_var; ctx->break_var = ctx->loop_stack[ctx->loop_stack_size].break_var;
ctx->break_type = ctx->break_type_stack[ctx->loop_stack_size + ctx->break_type = ctx->break_type_stack[ctx->loop_stack_size +
ctx->switch_stack_size]; ctx->switch_stack_size];
lp_exec_mask_update(mask); lp_exec_mask_update(exec_mask);
} }
void lp_exec_mask_cond_push(struct lp_exec_mask *mask, void lp_exec_mask_cond_push(struct lp_exec_mask *mask,

View file

@ -101,6 +101,8 @@ struct lp_exec_mask {
int function_stack_size; int function_stack_size;
}; };
struct lp_build_mask_context;
void lp_exec_mask_function_init(struct lp_exec_mask *mask, int function_idx); void lp_exec_mask_function_init(struct lp_exec_mask *mask, int function_idx);
void lp_exec_mask_init(struct lp_exec_mask *mask, struct lp_build_context *bld); void lp_exec_mask_init(struct lp_exec_mask *mask, struct lp_build_context *bld);
void lp_exec_mask_fini(struct lp_exec_mask *mask); void lp_exec_mask_fini(struct lp_exec_mask *mask);
@ -112,7 +114,8 @@ void lp_exec_mask_update(struct lp_exec_mask *mask);
void lp_exec_bgnloop_post_phi(struct lp_exec_mask *mask); void lp_exec_bgnloop_post_phi(struct lp_exec_mask *mask);
void lp_exec_bgnloop(struct lp_exec_mask *mask, bool load_mask); void lp_exec_bgnloop(struct lp_exec_mask *mask, bool load_mask);
void lp_exec_endloop(struct gallivm_state *gallivm, void lp_exec_endloop(struct gallivm_state *gallivm,
struct lp_exec_mask *mask); struct lp_exec_mask *exec_mask,
struct lp_build_mask_context *mask);
void lp_exec_mask_cond_push(struct lp_exec_mask *mask, void lp_exec_mask_cond_push(struct lp_exec_mask *mask,
LLVMValueRef val); LLVMValueRef val);
void lp_exec_mask_cond_invert(struct lp_exec_mask *mask); void lp_exec_mask_cond_invert(struct lp_exec_mask *mask);

View file

@ -2024,7 +2024,7 @@ static void bgnloop(struct lp_build_nir_context *bld_base)
static void endloop(struct lp_build_nir_context *bld_base) static void endloop(struct lp_build_nir_context *bld_base)
{ {
struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
lp_exec_endloop(bld_base->base.gallivm, &bld->exec_mask); lp_exec_endloop(bld_base->base.gallivm, &bld->exec_mask, bld->mask);
} }
static void if_cond(struct lp_build_nir_context *bld_base, LLVMValueRef cond) static void if_cond(struct lp_build_nir_context *bld_base, LLVMValueRef cond)

View file

@ -4268,7 +4268,7 @@ endloop_emit(
{ {
struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base); struct lp_build_tgsi_soa_context * bld = lp_soa_context(bld_base);
lp_exec_endloop(bld_base->base.gallivm, &bld->exec_mask); lp_exec_endloop(bld_base->base.gallivm, &bld->exec_mask, bld->mask);
} }
static void static void