From 5e178a07a0708e9007c99dc8577f3b7844377296 Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Tue, 13 Feb 2024 11:25:54 +0100 Subject: [PATCH] llvmpipe: Use full subgroups when possible Fixes computeFullSubgroups on lavapipe. cc: mesa-stable Reviewed-by: Mike Blumenkrantz Part-of: (cherry picked from commit eb3c96d5ed4fe8e57d8d225fa6e740282b510a8f) --- .pick_status.json | 2 +- src/gallium/drivers/llvmpipe/lp_state_cs.c | 141 ++++++++------------- 2 files changed, 51 insertions(+), 92 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index 74ad374d64a..a3b55fb2315 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -1344,7 +1344,7 @@ "description": "llvmpipe: Use full subgroups when possible", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/gallium/drivers/llvmpipe/lp_state_cs.c b/src/gallium/drivers/llvmpipe/lp_state_cs.c index c4661ced025..46cc5ffb06f 100644 --- a/src/gallium/drivers/llvmpipe/lp_state_cs.c +++ b/src/gallium/drivers/llvmpipe/lp_state_cs.c @@ -95,7 +95,7 @@ enum { CS_ARG_VERTEX_DATA, CS_ARG_PER_THREAD_DATA, CS_ARG_OUTER_COUNT, - CS_ARG_CORO_X_LOOPS = CS_ARG_OUTER_COUNT, + CS_ARG_CORO_SUBGROUP_COUNT = CS_ARG_OUTER_COUNT, CS_ARG_CORO_PARTIALS, CS_ARG_CORO_BLOCK_X_SIZE, CS_ARG_CORO_BLOCK_Y_SIZE, @@ -374,7 +374,7 @@ generate_compute(struct llvmpipe_context *lp, else arg_types[CS_ARG_VERTEX_DATA] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* mesh shaders only */ arg_types[CS_ARG_PER_THREAD_DATA] = variant->jit_cs_thread_data_ptr_type; /* per thread data */ - arg_types[CS_ARG_CORO_X_LOOPS] = int32_type; /* coro only - num X loops */ + arg_types[CS_ARG_CORO_SUBGROUP_COUNT] = int32_type; /* coro only - subgroup count */ arg_types[CS_ARG_CORO_PARTIALS] = int32_type; /* coro only - partials */ arg_types[CS_ARG_CORO_BLOCK_X_SIZE] = int32_type; /* coro block_x_size */ arg_types[CS_ARG_CORO_BLOCK_Y_SIZE] = int32_type; /* coro block_y_size */ @@ -560,23 +560,24 @@ generate_compute(struct llvmpipe_context *lp, output_array = lp_build_array_alloca(gallivm, output_type, lp_build_const_int32(gallivm, align(MAX2(nir->info.mesh.max_primitives_out, nir->info.mesh.max_vertices_out), 8)), "outputs"); } - struct lp_build_loop_state loop_state[4]; - LLVMValueRef num_x_loop; - LLVMValueRef vec_length = lp_build_const_int32(gallivm, cs_type.length); - num_x_loop = LLVMBuildAdd(gallivm->builder, block_x_size_arg, vec_length, ""); - num_x_loop = LLVMBuildSub(gallivm->builder, num_x_loop, lp_build_const_int32(gallivm, 1), ""); - num_x_loop = LLVMBuildUDiv(gallivm->builder, num_x_loop, vec_length, ""); - LLVMValueRef partials = LLVMBuildURem(gallivm->builder, block_x_size_arg, vec_length, ""); + struct lp_build_loop_state loop_state[2]; - LLVMValueRef coro_num_hdls = LLVMBuildMul(gallivm->builder, num_x_loop, block_y_size_arg, ""); - coro_num_hdls = LLVMBuildMul(gallivm->builder, coro_num_hdls, block_z_size_arg, ""); + LLVMValueRef vec_length = lp_build_const_int32(gallivm, cs_type.length); + + LLVMValueRef invocation_count = LLVMBuildMul(gallivm->builder, block_x_size_arg, block_y_size_arg, ""); + invocation_count = LLVMBuildMul(gallivm->builder, invocation_count, block_z_size_arg, ""); + + LLVMValueRef partials = LLVMBuildURem(gallivm->builder, invocation_count, vec_length, ""); + + LLVMValueRef num_subgroup_loop = LLVMBuildAdd(gallivm->builder, invocation_count, lp_build_const_int32(gallivm, cs_type.length - 1), ""); + num_subgroup_loop = LLVMBuildUDiv(gallivm->builder, num_subgroup_loop, vec_length, ""); /* build a ptr in memory to store all the frames in later. */ LLVMTypeRef hdl_ptr_type = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); LLVMValueRef coro_mem = LLVMBuildAlloca(gallivm->builder, hdl_ptr_type, "coro_mem"); LLVMBuildStore(builder, LLVMConstNull(hdl_ptr_type), coro_mem); - LLVMValueRef coro_hdls = LLVMBuildArrayAlloca(gallivm->builder, hdl_ptr_type, coro_num_hdls, "coro_hdls"); + LLVMValueRef coro_hdls = LLVMBuildArrayAlloca(gallivm->builder, hdl_ptr_type, num_subgroup_loop, "coro_hdls"); unsigned end_coroutine = INT_MAX; @@ -585,22 +586,17 @@ generate_compute(struct llvmpipe_context *lp, * and calls the coroutine main entrypoint on the first pass, but in subsequent * passes it checks if the coroutine has completed and resumes it if not. */ - /* take x_width - round up to type.length width */ - lp_build_loop_begin(&loop_state[3], gallivm, - lp_build_const_int32(gallivm, 0)); /* coroutine reentry loop */ - lp_build_loop_begin(&loop_state[2], gallivm, - lp_build_const_int32(gallivm, 0)); /* z loop */ lp_build_loop_begin(&loop_state[1], gallivm, - lp_build_const_int32(gallivm, 0)); /* y loop */ + lp_build_const_int32(gallivm, 0)); /* coroutine reentry loop */ lp_build_loop_begin(&loop_state[0], gallivm, - lp_build_const_int32(gallivm, 0)); /* x loop */ + lp_build_const_int32(gallivm, 0)); /* subgroup loop */ { LLVMValueRef args[CS_ARG_MAX]; args[CS_ARG_CONTEXT] = context_ptr; args[CS_ARG_RESOURCES] = resources_ptr; - args[CS_ARG_BLOCK_X_SIZE] = loop_state[0].counter; - args[CS_ARG_BLOCK_Y_SIZE] = loop_state[1].counter; - args[CS_ARG_BLOCK_Z_SIZE] = loop_state[2].counter; + args[CS_ARG_BLOCK_X_SIZE] = LLVMGetUndef(int32_type); + args[CS_ARG_BLOCK_Y_SIZE] = LLVMGetUndef(int32_type); + args[CS_ARG_BLOCK_Z_SIZE] = LLVMGetUndef(int32_type); args[CS_ARG_GRID_X] = grid_x_arg; args[CS_ARG_GRID_Y] = grid_y_arg; args[CS_ARG_GRID_Z] = grid_z_arg; @@ -611,34 +607,25 @@ generate_compute(struct llvmpipe_context *lp, args[CS_ARG_DRAW_ID] = draw_id_arg; args[CS_ARG_VERTEX_DATA] = io_ptr; args[CS_ARG_PER_THREAD_DATA] = thread_data_ptr; - args[CS_ARG_CORO_X_LOOPS] = num_x_loop; + args[CS_ARG_CORO_SUBGROUP_COUNT] = num_subgroup_loop; args[CS_ARG_CORO_PARTIALS] = partials; args[CS_ARG_CORO_BLOCK_X_SIZE] = block_x_size_arg; args[CS_ARG_CORO_BLOCK_Y_SIZE] = block_y_size_arg; args[CS_ARG_CORO_BLOCK_Z_SIZE] = block_z_size_arg; - /* idx = (z * (size_x * size_y) + y * size_x + x */ - LLVMValueRef coro_hdl_idx = LLVMBuildMul(gallivm->builder, loop_state[2].counter, - LLVMBuildMul(gallivm->builder, num_x_loop, block_y_size_arg, ""), ""); - coro_hdl_idx = LLVMBuildAdd(gallivm->builder, coro_hdl_idx, - LLVMBuildMul(gallivm->builder, loop_state[1].counter, - num_x_loop, ""), ""); - coro_hdl_idx = LLVMBuildAdd(gallivm->builder, coro_hdl_idx, - loop_state[0].counter, ""); - - args[CS_ARG_CORO_IDX] = coro_hdl_idx; + args[CS_ARG_CORO_IDX] = loop_state[0].counter; args[CS_ARG_CORO_MEM] = coro_mem; if (is_mesh) args[CS_ARG_CORO_OUTPUTS] = output_array; - LLVMValueRef coro_entry = LLVMBuildGEP2(gallivm->builder, hdl_ptr_type, coro_hdls, &coro_hdl_idx, 1, ""); + LLVMValueRef coro_entry = LLVMBuildGEP2(gallivm->builder, hdl_ptr_type, coro_hdls, &loop_state[0].counter, 1, ""); LLVMValueRef coro_hdl = LLVMBuildLoad2(gallivm->builder, hdl_ptr_type, coro_entry, "coro_hdl"); struct lp_build_if_state ifstate; - LLVMValueRef cmp = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, loop_state[3].counter, + LLVMValueRef cmp = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, loop_state[1].counter, lp_build_const_int32(gallivm, 0), ""); /* first time here - call the coroutine function entry point */ lp_build_if(&ifstate, gallivm, cmp); @@ -651,24 +638,18 @@ generate_compute(struct llvmpipe_context *lp, lp_build_if(&ifstate2, gallivm, coro_done); /* if done destroy and force loop exit */ lp_build_coro_destroy(gallivm, coro_hdl); - lp_build_loop_force_set_counter(&loop_state[3], lp_build_const_int32(gallivm, end_coroutine - 1)); + lp_build_loop_force_set_counter(&loop_state[1], lp_build_const_int32(gallivm, end_coroutine - 1)); lp_build_else(&ifstate2); /* otherwise resume the coroutine */ lp_build_coro_resume(gallivm, coro_hdl); lp_build_endif(&ifstate2); lp_build_endif(&ifstate); - lp_build_loop_force_reload_counter(&loop_state[3]); + lp_build_loop_force_reload_counter(&loop_state[1]); } lp_build_loop_end_cond(&loop_state[0], - num_x_loop, + num_subgroup_loop, NULL, LLVMIntUGE); lp_build_loop_end_cond(&loop_state[1], - block_y_size_arg, - NULL, LLVMIntUGE); - lp_build_loop_end_cond(&loop_state[2], - block_z_size_arg, - NULL, LLVMIntUGE); - lp_build_loop_end_cond(&loop_state[3], lp_build_const_int32(gallivm, end_coroutine), NULL, LLVMIntEQ); @@ -680,12 +661,8 @@ generate_compute(struct llvmpipe_context *lp, LLVMBuildRetVoid(builder); /* This is stage (b) - generate the compute shader code inside the coroutine. */ - LLVMValueRef x_size_arg, y_size_arg, z_size_arg; context_ptr = LLVMGetParam(coro, CS_ARG_CONTEXT); resources_ptr = LLVMGetParam(coro, CS_ARG_RESOURCES); - x_size_arg = LLVMGetParam(coro, CS_ARG_BLOCK_X_SIZE); - y_size_arg = LLVMGetParam(coro, CS_ARG_BLOCK_Y_SIZE); - z_size_arg = LLVMGetParam(coro, CS_ARG_BLOCK_Z_SIZE); grid_x_arg = LLVMGetParam(coro, CS_ARG_GRID_X); grid_y_arg = LLVMGetParam(coro, CS_ARG_GRID_Y); grid_z_arg = LLVMGetParam(coro, CS_ARG_GRID_Z); @@ -696,12 +673,12 @@ generate_compute(struct llvmpipe_context *lp, draw_id_arg = LLVMGetParam(coro, CS_ARG_DRAW_ID); io_ptr = LLVMGetParam(coro, CS_ARG_VERTEX_DATA); thread_data_ptr = LLVMGetParam(coro, CS_ARG_PER_THREAD_DATA); - num_x_loop = LLVMGetParam(coro, CS_ARG_CORO_X_LOOPS); + num_subgroup_loop = LLVMGetParam(coro, CS_ARG_CORO_SUBGROUP_COUNT); partials = LLVMGetParam(coro, CS_ARG_CORO_PARTIALS); block_x_size_arg = LLVMGetParam(coro, CS_ARG_CORO_BLOCK_X_SIZE); block_y_size_arg = LLVMGetParam(coro, CS_ARG_CORO_BLOCK_Y_SIZE); block_z_size_arg = LLVMGetParam(coro, CS_ARG_CORO_BLOCK_Z_SIZE); - LLVMValueRef coro_idx = LLVMGetParam(coro, CS_ARG_CORO_IDX); + LLVMValueRef subgroup_id = LLVMGetParam(coro, CS_ARG_CORO_IDX); coro_mem = LLVMGetParam(coro, CS_ARG_CORO_MEM); if (is_mesh) output_array = LLVMGetParam(coro, CS_ARG_CORO_OUTPUTS); @@ -730,27 +707,32 @@ generate_compute(struct llvmpipe_context *lp, variant->jit_cs_thread_data_type, thread_data_ptr); - LLVMValueRef coro_num_hdls = LLVMBuildMul(gallivm->builder, num_x_loop, block_y_size_arg, ""); - coro_num_hdls = LLVMBuildMul(gallivm->builder, coro_num_hdls, block_z_size_arg, ""); - /* these are coroutine entrypoint necessities */ LLVMValueRef coro_id = lp_build_coro_id(gallivm); - LLVMValueRef coro_entry = lp_build_coro_alloc_mem_array(gallivm, coro_mem, coro_idx, coro_num_hdls); + LLVMValueRef coro_entry = lp_build_coro_alloc_mem_array(gallivm, coro_mem, subgroup_id, num_subgroup_loop); LLVMTypeRef mem_ptr_type = LLVMInt8TypeInContext(gallivm->context); LLVMValueRef alloced_ptr = LLVMBuildLoad2(gallivm->builder, hdl_ptr_type, coro_mem, ""); alloced_ptr = LLVMBuildGEP2(gallivm->builder, mem_ptr_type, alloced_ptr, &coro_entry, 1, ""); LLVMValueRef coro_hdl = lp_build_coro_begin(gallivm, coro_id, alloced_ptr); LLVMValueRef has_partials = LLVMBuildICmp(gallivm->builder, LLVMIntNE, partials, lp_build_const_int32(gallivm, 0), ""); - LLVMValueRef tids_x[LP_MAX_VECTOR_LENGTH], tids_y[LP_MAX_VECTOR_LENGTH], tids_z[LP_MAX_VECTOR_LENGTH]; - LLVMValueRef base_val = LLVMBuildMul(gallivm->builder, x_size_arg, vec_length, ""); - for (i = 0; i < cs_type.length; i++) { - tids_x[i] = LLVMBuildAdd(gallivm->builder, base_val, lp_build_const_int32(gallivm, i), ""); - tids_y[i] = y_size_arg; - tids_z[i] = z_size_arg; - } - system_values.thread_id[0] = lp_build_gather_values(gallivm, tids_x, cs_type.length); - system_values.thread_id[1] = lp_build_gather_values(gallivm, tids_y, cs_type.length); - system_values.thread_id[2] = lp_build_gather_values(gallivm, tids_z, cs_type.length); + + struct lp_build_context bld; + lp_build_context_init(&bld, gallivm, lp_uint_type(cs_type)); + + LLVMValueRef base_val = LLVMBuildMul(gallivm->builder, subgroup_id, vec_length, ""); + LLVMValueRef invocation_indices[LP_MAX_VECTOR_LENGTH]; + for (i = 0; i < cs_type.length; i++) + invocation_indices[i] = LLVMBuildAdd(gallivm->builder, base_val, lp_build_const_int32(gallivm, i), ""); + LLVMValueRef invocation_index = lp_build_gather_values(gallivm, invocation_indices, cs_type.length); + + LLVMValueRef block_x_size_vec = lp_build_broadcast_scalar(&bld, block_x_size_arg); + LLVMValueRef block_y_size_vec = lp_build_broadcast_scalar(&bld, block_y_size_arg); + + system_values.thread_id[0] = LLVMBuildURem(gallivm->builder, invocation_index, block_x_size_vec, ""); + system_values.thread_id[1] = LLVMBuildUDiv(gallivm->builder, invocation_index, block_x_size_vec, ""); + system_values.thread_id[1] = LLVMBuildURem(gallivm->builder, system_values.thread_id[1], block_y_size_vec, ""); + system_values.thread_id[2] = LLVMBuildUDiv(gallivm->builder, invocation_index, block_x_size_vec, ""); + system_values.thread_id[2] = LLVMBuildUDiv(gallivm->builder, system_values.thread_id[2], block_y_size_vec, ""); system_values.block_id[0] = grid_x_arg; system_values.block_id[1] = grid_y_arg; @@ -763,38 +745,15 @@ generate_compute(struct llvmpipe_context *lp, system_values.work_dim = work_dim_arg; system_values.draw_id = draw_id_arg; - /* subgroup_id = ((z * block_size_x * block_size_y) + (y * block_size_x) + x) / subgroup_size - * - * this breaks if z or y is zero, so distribute the division to preserve ids - * - * subgroup_id = ((z * block_size_x * block_size_y) / subgroup_size) + ((y * block_size_x) / subgroup_size) + (x / subgroup_size) - * - * except "x" is pre-divided here - * - * subgroup_id = ((z * block_size_x * block_size_y) / subgroup_size) + ((y * block_size_x) / subgroup_size) + x - */ - LLVMValueRef subgroup_id = LLVMBuildUDiv(builder, - LLVMBuildMul(gallivm->builder, z_size_arg, LLVMBuildMul(gallivm->builder, block_x_size_arg, block_y_size_arg, ""), ""), - vec_length, ""); - subgroup_id = LLVMBuildAdd(gallivm->builder, - subgroup_id, - LLVMBuildUDiv(builder, LLVMBuildMul(gallivm->builder, y_size_arg, block_x_size_arg, ""), vec_length, ""), - ""); - subgroup_id = LLVMBuildAdd(gallivm->builder, subgroup_id, x_size_arg, ""); system_values.subgroup_id = subgroup_id; - LLVMValueRef num_subgroups = LLVMBuildUDiv(builder, - LLVMBuildMul(builder, block_x_size_arg, - LLVMBuildMul(builder, block_y_size_arg, block_z_size_arg, ""), ""), - vec_length, ""); - LLVMValueRef subgroup_cmp = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, num_subgroups, lp_build_const_int32(gallivm, 0), ""); - system_values.num_subgroups = LLVMBuildSelect(builder, subgroup_cmp, lp_build_const_int32(gallivm, 1), num_subgroups, ""); + system_values.num_subgroups = num_subgroup_loop; system_values.block_size[0] = block_x_size_arg; system_values.block_size[1] = block_y_size_arg; system_values.block_size[2] = block_z_size_arg; - LLVMValueRef last_x_loop = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, x_size_arg, LLVMBuildSub(gallivm->builder, num_x_loop, lp_build_const_int32(gallivm, 1), ""), ""); - LLVMValueRef use_partial_mask = LLVMBuildAnd(gallivm->builder, last_x_loop, has_partials, ""); + LLVMValueRef last_loop = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, subgroup_id, LLVMBuildSub(gallivm->builder, num_subgroup_loop, lp_build_const_int32(gallivm, 1), ""), ""); + LLVMValueRef use_partial_mask = LLVMBuildAnd(gallivm->builder, last_loop, has_partials, ""); struct lp_build_if_state if_state; LLVMTypeRef mask_type = LLVMVectorType(int32_type, cs_type.length); LLVMValueRef mask_val = lp_build_alloca(gallivm, mask_type, "mask"); @@ -866,7 +825,7 @@ generate_compute(struct llvmpipe_context *lp, lp_int_type(cs_type), 0); struct lp_build_if_state iter0state; - LLVMValueRef is_iter0 = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, coro_idx, + LLVMValueRef is_iter0 = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, subgroup_id, lp_build_const_int32(gallivm, 0), ""); LLVMValueRef vertex_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.vertex_count, ""); LLVMValueRef prim_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.prim_count, "");