diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index d6071e065b4b0d2b8fa1770dd5e2afcc1e922b6e..5242aff8c7ecb889e3bcd871ba82af3dc4725365 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -10692,7 +10692,21 @@ bool ngg_early_prim_export(isel_context *ctx) return true; } -void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx) +Temp ngg_max_vertex_count(isel_context *ctx) +{ + Builder bld(ctx->program, ctx->block); + return bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u))); +} + +Temp ngg_max_primitive_count(isel_context *ctx) +{ + Builder bld(ctx->program, ctx->block); + return bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u))); +} + +void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx, Temp vtx_cnt = Temp(), Temp prm_cnt = Temp()) { Builder bld(ctx->program, ctx->block); @@ -10712,12 +10726,20 @@ void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx) begin_uniform_if_else(ctx, &ic); bld.reset(ctx->block); + /* VS/TES: we infer the vertex and primitive count from arguments + * GS: the caller needs to supply them + */ + assert(ctx->shader->info.stage == MESA_SHADER_GEOMETRY + ? (vtx_cnt.id() && prm_cnt.id()) + : (!vtx_cnt.id() && !prm_cnt.id())); + /* Number of vertices output by VS/TES */ - Temp vtx_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), - get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u))); + if (vtx_cnt.id() == 0) + vtx_cnt = ngg_max_vertex_count(ctx); + /* Number of primitives output by VS/TES */ - Temp prm_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), - get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u))); + if (prm_cnt.id() == 0) + prm_cnt = ngg_max_primitive_count(ctx); /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */ Temp tmp = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), prm_cnt, Operand(12u));