zink: variable shared mem support

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24839>
This commit is contained in:
Karol Herbst 2023-09-19 14:44:26 +02:00 committed by Marge Bot
parent 566112fdf8
commit c5abb7c8d1
5 changed files with 42 additions and 7 deletions

View file

@ -101,6 +101,8 @@ struct ntv_context {
local_group_size_var,
base_vertex_var, base_instance_var, draw_id_var;
SpvId shared_mem_size;
SpvId subgroup_eq_mask_var,
subgroup_ge_mask_var,
subgroup_gt_mask_var,
@ -663,13 +665,25 @@ get_scratch_block(struct ntv_context *ctx, unsigned bit_size)
}
static void
create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned bit_size)
create_shared_block(struct ntv_context *ctx, unsigned bit_size)
{
unsigned idx = bit_size >> 4;
SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size);
unsigned block_size = shared_size / (bit_size / 8);
assert(block_size);
SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
SpvId array;
assert(gl_shader_stage_is_compute(ctx->nir->info.stage));
if (ctx->nir->info.cs.has_variable_shared_mem) {
assert(ctx->shared_mem_size);
SpvId const_shared_size = emit_uint_const(ctx, 32, ctx->nir->info.shared_size);
SpvId shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpIAdd, const_shared_size, ctx->shared_mem_size);
shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpUDiv, shared_mem_size, emit_uint_const(ctx, 32, bit_size / 8));
array = spirv_builder_type_array(&ctx->builder, type, shared_mem_size);
} else {
unsigned block_size = ctx->nir->info.shared_size / (bit_size / 8);
assert(block_size);
array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
}
spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8);
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassWorkgroup,
@ -686,7 +700,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size)
{
unsigned idx = bit_size >> 4;
if (!ctx->shared_block_var[idx])
create_shared_block(ctx, ctx->nir->info.shared_size, bit_size);
create_shared_block(ctx, bit_size);
if (ctx->sinfo->have_workgroup_memory_explicit_layout) {
spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_workgroup_memory_explicit_layout");
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityWorkgroupMemoryExplicitLayoutKHR);
@ -4591,6 +4605,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
spirv_builder_emit_builtin(&ctx.builder, ctx.local_group_size_var, SpvBuiltInWorkgroupSize);
}
}
if (s->info.cs.has_variable_shared_mem) {
ctx.shared_mem_size = spirv_builder_spec_const_uint(&ctx.builder, 32);
spirv_builder_emit_specid(&ctx.builder, ctx.shared_mem_size, ZINK_VARIABLE_SHARED_MEM);
spirv_builder_emit_name(&ctx.builder, ctx.shared_mem_size, "variable_shared_mem");
}
if (s->info.cs.derivative_group) {
SpvCapability caps[] = { 0, SpvCapabilityComputeDerivativeGroupQuadsNV, SpvCapabilityComputeDerivativeGroupLinearNV };
SpvExecutionMode modes[] = { 0, SpvExecutionModeDerivativeGroupQuadsNV, SpvExecutionModeDerivativeGroupLinearNV };

View file

@ -29,6 +29,7 @@
#define ZINK_WORKGROUP_SIZE_X 1
#define ZINK_WORKGROUP_SIZE_Y 2
#define ZINK_WORKGROUP_SIZE_Z 3
#define ZINK_VARIABLE_SHARED_MEM 4
#define ZINK_INLINE_VAL_FLAT_MASK 0
#define ZINK_INLINE_VAL_PV_LAST_VERT 1

View file

@ -457,8 +457,8 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
stage.pName = "main";
VkSpecializationInfo sinfo = {0};
VkSpecializationMapEntry me[3];
uint32_t data[3];
VkSpecializationMapEntry me[4];
uint32_t data[4];
if (state) {
int i = 0;
@ -475,6 +475,16 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
}
}
if (comp->has_variable_shared_mem) {
sinfo.mapEntryCount += 1;
sinfo.dataSize += sizeof(uint32_t);
data[i] = state->variable_shared_mem;
me[i].size = sizeof(uint32_t);
me[i].constantID = ZINK_VARIABLE_SHARED_MEM;
me[i].offset = i * sizeof(uint32_t);
i++;
}
if (sinfo.dataSize) {
stage.pSpecializationInfo = &sinfo;
sinfo.pData = data;

View file

@ -1304,6 +1304,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink
ctx->compute_pipeline_state.local_size[i] = info->block[i];
}
}
if (ctx->compute_pipeline_state.variable_shared_mem != info->variable_shared_mem) {
ctx->compute_pipeline_state.dirty = true;
ctx->compute_pipeline_state.variable_shared_mem = info->variable_shared_mem;
}
}
static bool

View file

@ -937,6 +937,7 @@ struct zink_compute_pipeline_state {
uint32_t final_hash;
bool dirty;
uint32_t local_size[3];
uint32_t variable_shared_mem;
uint32_t module_hash;
VkShaderModule module;