diff --git a/src/gallium/drivers/llvmpipe/lp_state_cs.c b/src/gallium/drivers/llvmpipe/lp_state_cs.c index ac155e8c737..45ca96821c2 100644 --- a/src/gallium/drivers/llvmpipe/lp_state_cs.c +++ b/src/gallium/drivers/llvmpipe/lp_state_cs.c @@ -442,6 +442,111 @@ generate_compute(struct llvmpipe_context *lp, lp_build_name(thread_data_ptr, "thread_data"); lp_build_name(io_ptr, "vertex_io"); + lp_build_nir_prepasses(nir); + struct hash_table *fns = _mesa_pointer_hash_table_create(NULL); + + if (exec_list_length(&nir->functions) > 1) { + LLVMTypeRef call_context_type = lp_build_cs_func_call_context(gallivm, cs_type.length, + variant->jit_cs_context_type, + variant->jit_resources_type); + nir_foreach_function(func, nir) { + if (func->is_entrypoint) + continue; + + LLVMTypeRef args[32]; + int num_args; + + num_args = func->num_params + LP_RESV_FUNC_ARGS; + + args[0] = LLVMVectorType(LLVMInt32TypeInContext(gallivm->context), cs_type.length); /* mask */ + args[1] = LLVMPointerType(call_context_type, 0); + for (int i = 0; i < func->num_params; i++) { + args[i + LP_RESV_FUNC_ARGS] = LLVMVectorType(LLVMIntTypeInContext(gallivm->context, func->params[i].bit_size), cs_type.length); + if (func->params[i].num_components > 1) + args[i + LP_RESV_FUNC_ARGS] = LLVMArrayType(args[i + LP_RESV_FUNC_ARGS], func->params[i].num_components); + } + + LLVMTypeRef func_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context), + args, num_args, 0); + LLVMValueRef lfunc = LLVMAddFunction(gallivm->module, func->name, func_type); + LLVMSetFunctionCallConv(lfunc, LLVMCCallConv); + + struct lp_build_fn *new_fn = ralloc(fns, struct lp_build_fn); + new_fn->fn_type = func_type; + new_fn->fn = lfunc; + _mesa_hash_table_insert(fns, func, new_fn); + } + + nir_foreach_function(func, nir) { + if (func->is_entrypoint) + continue; + + struct hash_entry *entry = _mesa_hash_table_search(fns, func); + assert(entry); + struct lp_build_fn *new_fn = entry->data; + LLVMValueRef lfunc = new_fn->fn; + block = LLVMAppendBasicBlockInContext(gallivm->context, lfunc, "entry"); + + builder = gallivm->builder; + LLVMPositionBuilderAtEnd(builder, block); + LLVMValueRef mask_param = LLVMGetParam(lfunc, 0); + LLVMValueRef call_context_ptr = LLVMGetParam(lfunc, 1); + LLVMValueRef call_context = LLVMBuildLoad2(builder, call_context_type, call_context_ptr, ""); + struct lp_build_mask_context mask; + struct lp_bld_tgsi_system_values system_values; + + memset(&system_values, 0, sizeof(system_values)); + + lp_build_mask_begin(&mask, gallivm, cs_type, mask_param); + lp_build_mask_check(&mask); + + struct lp_build_tgsi_params params; + memset(¶ms, 0, sizeof(params)); + params.type = cs_type; + params.mask = &mask; + params.fns = fns; + params.current_func = lfunc; + params.context_type = variant->jit_cs_context_type; + params.resources_type = variant->jit_resources_type; + params.call_context_ptr = call_context_ptr; + params.context_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_CONTEXT, ""); + params.resources_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_RESOURCES, ""); + params.shared_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_SHARED, ""); + params.scratch_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_SCRATCH, ""); + system_values.work_dim = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_WORK_DIM, ""); + system_values.thread_id[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_0, ""); + system_values.thread_id[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_1, ""); + system_values.thread_id[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_2, ""); + system_values.block_id[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_0, ""); + system_values.block_id[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_1, ""); + system_values.block_id[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_2, ""); + system_values.grid_size[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_0, ""); + system_values.grid_size[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_1, ""); + system_values.grid_size[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_2, ""); + system_values.block_size[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0, ""); + system_values.block_size[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1, ""); + system_values.block_size[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2, ""); + + params.system_values = &system_values; + + params.consts_ptr = lp_jit_resources_constants(gallivm, + variant->jit_resources_type, + params.resources_ptr); + params.ssbo_ptr = lp_jit_resources_ssbos(gallivm, + variant->jit_resources_type, + params.resources_ptr); + lp_build_nir_soa_func(gallivm, shader->base.ir.nir, + func->impl, + ¶ms, + NULL); + + lp_build_mask_end(&mask); + + LLVMBuildRetVoid(builder); + gallivm_verify_function(gallivm, lfunc); + } + } + block = LLVMAppendBasicBlockInContext(gallivm->context, function, "entry"); builder = gallivm->builder; assert(builder); @@ -750,8 +855,11 @@ generate_compute(struct llvmpipe_context *lp, resources_ptr); params.mesh_iface = &mesh_iface.base; - lp_build_nir_soa(gallivm, shader->base.ir.nir, ¶ms, - NULL); + params.current_func = NULL; + params.fns = fns; + lp_build_nir_soa_func(gallivm, nir, + nir_shader_get_entrypoint(nir), + ¶ms, NULL); if (is_mesh) { LLVMTypeRef i32t = LLVMInt32TypeInContext(gallivm->context); @@ -833,6 +941,7 @@ generate_compute(struct llvmpipe_context *lp, lp_bld_llvm_sampler_soa_destroy(sampler); lp_bld_llvm_image_soa_destroy(image); + _mesa_hash_table_destroy(fns, NULL); gallivm_verify_function(gallivm, coro); gallivm_verify_function(gallivm, function);