radv/rt: pass radv_ray_tracing_pipeline to RT shader creation

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>
This commit is contained in:
Daniel Schürmann 2023-06-01 12:20:31 +02:00 committed by Marge Bot
parent 8fb7df92c0
commit 62b4380acb
4 changed files with 34 additions and 40 deletions

View file

@ -467,11 +467,7 @@ radv_rt_pipeline_compile(struct radv_ray_tracing_pipeline *pipeline,
if (result != VK_SUCCESS)
return result;
VkRayTracingPipelineCreateInfoKHR local_create_info =
radv_create_merged_rt_create_info(pCreateInfo);
rt_stage.internal_nir = create_rt_shader(device, &local_create_info, pipeline->stages,
pipeline->groups, pipeline_key);
rt_stage.internal_nir = create_rt_shader(device, pipeline, pCreateInfo, pipeline_key);
/* Compile SPIR-V shader to NIR. */
rt_stage.nir =

View file

@ -2144,6 +2144,7 @@ struct radv_event {
#define RADV_HASH_SHADER_NGG_STREAMOUT (1 << 20)
struct radv_pipeline_key;
struct radv_ray_tracing_group;
void radv_pipeline_stage_init(const VkPipelineShaderStageCreateInfo *sinfo,
struct radv_pipeline_stage *out_stage, gl_shader_stage stage);

View file

@ -1186,28 +1186,25 @@ init_traversal_vars(nir_builder *b)
struct traversal_data {
struct radv_device *device;
const VkRayTracingPipelineCreateInfoKHR *createInfo;
struct rt_variables *vars;
struct rt_traversal_vars *trav_vars;
nir_variable *barycentrics;
struct radv_ray_tracing_group *groups;
struct radv_ray_tracing_stage *stages;
struct radv_ray_tracing_pipeline *pipeline;
const struct radv_pipeline_key *key;
};
static void
visit_any_hit_shaders(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
struct traversal_data *data, struct rt_variables *vars)
visit_any_hit_shaders(struct radv_device *device, nir_builder *b, struct traversal_data *data,
struct rt_variables *vars)
{
nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
struct radv_ray_tracing_group *group = &data->groups[i];
for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
uint32_t shader_id = VK_SHADER_UNUSED_KHR;
switch (group->type) {
@ -1223,18 +1220,19 @@ visit_any_hit_shaders(struct radv_device *device,
/* Avoid emitting stages with the same shaders/handles multiple times. */
bool is_dup = false;
for (unsigned j = 0; j < i; ++j)
if (data->groups[j].handle.any_hit_index == data->groups[i].handle.any_hit_index)
if (data->pipeline->groups[j].handle.any_hit_index ==
data->pipeline->groups[i].handle.any_hit_index)
is_dup = true;
if (is_dup)
continue;
nir_shader *nir_stage =
radv_pipeline_cache_handle_to_nir(device, data->stages[shader_id].shader);
radv_pipeline_cache_handle_to_nir(device, data->pipeline->stages[shader_id].shader);
assert(nir_stage);
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index,
shader_id, data->stages);
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->pipeline->groups[i].handle.any_hit_index,
shader_id, data->pipeline->stages);
ralloc_free(nir_stage);
}
@ -1279,7 +1277,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
visit_any_hit_shaders(data->device, data->createInfo, b, args->data, &inner_vars);
visit_any_hit_shaders(data->device, b, args->data, &inner_vars);
nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
{
@ -1341,8 +1339,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
if (!(data->vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
struct radv_ray_tracing_group *group = &data->groups[i];
for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
uint32_t shader_id = VK_SHADER_UNUSED_KHR;
uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
@ -1360,31 +1358,33 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
/* Avoid emitting stages with the same shaders/handles multiple times. */
bool is_dup = false;
for (unsigned j = 0; j < i; ++j)
if (data->groups[j].handle.intersection_index == data->groups[i].handle.intersection_index)
if (data->pipeline->groups[j].handle.intersection_index ==
data->pipeline->groups[i].handle.intersection_index)
is_dup = true;
if (is_dup)
continue;
nir_shader *nir_stage =
radv_pipeline_cache_handle_to_nir(data->device, data->stages[shader_id].shader);
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[shader_id].shader);
assert(nir_stage);
nir_shader *any_hit_stage = NULL;
if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
any_hit_stage =
radv_pipeline_cache_handle_to_nir(data->device, data->stages[any_hit_shader_id].shader);
any_hit_stage = radv_pipeline_cache_handle_to_nir(
data->device, data->pipeline->stages[any_hit_shader_id].shader);
assert(any_hit_stage);
/* reserve stack size for any_hit before it is inlined */
data->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
data->pipeline->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
nir_lower_intersection_shader(nir_stage, any_hit_stage);
ralloc_free(any_hit_stage);
}
insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0,
data->groups[i].handle.intersection_index, shader_id, data->stages);
data->pipeline->groups[i].handle.intersection_index, shader_id,
data->pipeline->stages);
ralloc_free(nir_stage);
}
@ -1428,9 +1428,9 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
}
static nir_shader *
build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage *stages,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key)
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_pipeline_key *key)
{
/* Create the traversal shader as an intersection shader to prevent validation failures due to
* invalid variable modes.*/
@ -1517,12 +1517,10 @@ build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage
struct traversal_data data = {
.device = device,
.createInfo = pCreateInfo,
.vars = &vars,
.trav_vars = &trav_vars,
.barycentrics = barycentrics,
.groups = groups,
.stages = stages,
.pipeline = pipeline,
.key = key,
};
@ -1626,8 +1624,8 @@ move_rt_instructions(nir_shader *shader)
}
nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups,
create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_pipeline_key *key)
{
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined");
@ -1644,12 +1642,14 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_ssa_def *idx = nir_load_var(&b, vars.idx);
/* Insert traversal shader */
nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, groups, key);
nir_shader *traversal = radv_build_traversal_shader(device, pipeline, pCreateInfo, key);
b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
assert(b.shader->info.shared_size <= 32768);
insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, NULL);
ralloc_free(traversal);
struct radv_ray_tracing_group *groups = pipeline->groups;
struct radv_ray_tracing_stage *stages = pipeline->stages;
unsigned call_idx_base = 1;
for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
unsigned stage_idx = groups[i].recursive_shader;

View file

@ -43,8 +43,7 @@
struct radv_physical_device;
struct radv_device;
struct radv_pipeline;
struct radv_ray_tracing_stage;
struct radv_ray_tracing_group;
struct radv_ray_tracing_pipeline;
struct radv_pipeline_key;
struct radv_shader_args;
struct radv_vs_input_state;
@ -787,10 +786,8 @@ bool radv_consider_culling(const struct radv_physical_device *pdevice, struct ni
void radv_get_nir_options(struct radv_physical_device *device);
nir_shader *create_rt_shader(struct radv_device *device,
nir_shader *create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_stage *stages,
struct radv_ray_tracing_group *groups,
const struct radv_pipeline_key *key);
#endif