diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index e2624b268c92a179294572e6defcc6e2de278baa..7af0e5a8fc334416f41bebce52db467478040dc6 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -6895,6 +6895,7 @@ void ngg_visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *
 }
 
 void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream);
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr);
 
 void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_instr *instr)
 {
@@ -6908,7 +6909,7 @@ void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_i
       ngg_gs_clear_primflags(ctx, vtx_cnt, stream);
    }
 
-   /* TODO: also take the primitive count into use */
+   ngg_gs_write_shader_query(ctx, instr);
 }
 
 void visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *instr)
@@ -11165,6 +11166,66 @@ void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream)
    end_loop(ctx, &lc);
 }
 
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   /* Each subgroup uses a single GDS atomic to collect the total number of primitives.
+    * TODO: Consider using primitive compaction at the end instead.
+    */
+
+   unsigned total_vtx_per_prim = gs_outprim_vertices(ctx->shader->info.gs.output_primitive);
+   if_context ic_shader_query;
+   Builder bld(ctx->program, ctx->block);
+
+   Temp shader_query = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), get_arg(ctx, ctx->args->ngg_gs_state), Operand(0u));
+   begin_uniform_if_then(ctx, &ic_shader_query, shader_query);
+   bld.reset(ctx->block);
+
+   Temp gs_vtx_cnt = get_ssa_temp(ctx, instr->src[0].ssa);
+   Temp gs_prm_cnt = get_ssa_temp(ctx, instr->src[1].ssa);
+   Temp sg_prm_cnt;
+
+   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
+    * GS emits points, line strips or triangle strips.
+    * Real primitives are points, lines or triangles.
+    */
+   if (nir_src_is_const(instr->src[0]) && nir_src_is_const(instr->src[1])) {
+      unsigned gs_vtx_cnt = nir_src_as_uint(instr->src[0]);
+      unsigned gs_prm_cnt = nir_src_as_uint(instr->src[1]);
+      Temp prm_cnt = bld.copy(bld.def(s1), Operand(gs_vtx_cnt - gs_prm_cnt * (total_vtx_per_prim - 1u)));
+      Temp thread_cnt = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), Operand(exec, bld.lm));
+      sg_prm_cnt = bld.sop2(aco_opcode::s_mul_i32, bld.def(s1), prm_cnt, thread_cnt);
+   } else {
+      Temp prm_cnt = gs_vtx_cnt;
+      if (total_vtx_per_prim > 1)
+         prm_cnt = bld.vop3(aco_opcode::v_mad_i32_i24, bld.def(v1), gs_prm_cnt, Operand(-1u * (total_vtx_per_prim - 1)), gs_vtx_cnt);
+
+      /* Reduction calculates the primitive count for the entire subgroup. */
+      sg_prm_cnt = bld.tmp(s1);
+      aco_ptr<Pseudo_reduction_instruction> red_instr
+         {create_reduction_instr(ctx, aco_opcode::p_reduce, ReduceOp::iadd32, Definition(sg_prm_cnt), prm_cnt)};
+      red_instr->cluster_size = ctx->program->wave_size;
+      bld.insert(std::move(red_instr));
+   }
+
+   Temp first_lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
+   Temp is_first_lane = bld.sop2(Builder::s_lshl, bld.def(bld.lm), bld.def(s1, scc),
+                                 Operand(1u, ctx->program->wave_size == 64), first_lane);
+
+   if_context ic_last_lane;
+   begin_divergent_if_then(ctx, &ic_last_lane, is_first_lane);
+   bld.reset(ctx->block);
+
+   Temp gds_addr = bld.copy(bld.def(v1), Operand(0u));
+   Operand m = bld.m0((Temp)bld.sopk(aco_opcode::s_movk_i32, bld.def(s1, m0), 0x100));
+   bld.ds(aco_opcode::ds_add_u32, gds_addr, as_vgpr(ctx, sg_prm_cnt), m, 0u, 0u, true);
+
+   begin_divergent_if_else(ctx, &ic_last_lane);
+   end_divergent_if(ctx, &ic_last_lane);
+
+   begin_uniform_if_else(ctx, &ic_shader_query);
+   end_uniform_if(ctx, &ic_shader_query);
+}
+
 Temp ngg_gs_load_prim_flag_0(isel_context *ctx, Temp tid_in_tg, Temp max_vtxcnt, Temp vertex_lds_addr)
 {
    if_context ic;