diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 194e142fdf516d1908d15a893364190552e41024..38e9ab96732542d3a926b7da8854abfb17863334 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -8022,7 +8022,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       bld.sopp(aco_opcode::s_sendmsg, bld.m0(ctx->gs_wave_id), -1, sendmsg_gs(true, false, stream));
       break;
    }
-   case nir_intrinsic_set_vertex_count: {
+   case nir_intrinsic_set_vertex_and_primitive_count: {
       /* unused, the HW keeps track of this for us */
       break;
    }
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 124edd2170d9c1764f4da776b8a16b79a9d1029d..5e4d16bfb192280b97c2debd82ee28e610f839b1 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -554,7 +554,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
 	nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
 
 	if (nir->info.stage == MESA_SHADER_GEOMETRY)
-		nir_lower_gs_intrinsics(nir, true);
+		nir_lower_gs_intrinsics(nir, nir_lower_gs_intrinsics_per_stream);
 
 	static const nir_lower_tex_options tex_options = {
 	  .lower_txp = ~0,
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index d7881d9251dacdea03de1bd238758054353d62f4..ebbed9d0887e849b14dac82889d8881b34dee661 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4726,7 +4726,12 @@ typedef enum  {
 
 bool nir_lower_to_source_mods(nir_shader *shader, nir_lower_to_source_mods_flags options);
 
-bool nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream);
+typedef enum {
+   nir_lower_gs_intrinsics_per_stream = 1 << 0, /* keep separate counters for each active stream */
+   nir_lower_gs_intrinsics_count_primitives = 1 << 1, /* also maintain and report primitive counts */
+} nir_lower_gs_intrinsics_flags;
+
+bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);
 
 typedef unsigned (*nir_lower_bit_size_callback)(const nir_alu_instr *, void *);
 
diff --git a/src/compiler/nir/nir_gs_count_vertices.c b/src/compiler/nir/nir_gs_count_vertices.c
index 06b9cdf73763e8d72ad296a6f4b88cbcdb52bc77..3520a1f2cb450c7f17e443f67e550116d66fb217 100644
--- a/src/compiler/nir/nir_gs_count_vertices.c
+++ b/src/compiler/nir/nir_gs_count_vertices.c
@@ -38,9 +38,9 @@ as_intrinsic(nir_instr *instr, nir_intrinsic_op op)
 }
 
 static nir_intrinsic_instr *
-as_set_vertex_count(nir_instr *instr)
+as_set_vertex_and_primitive_count(nir_instr *instr)
 {
-   return as_intrinsic(instr, nir_intrinsic_set_vertex_count);
+   return as_intrinsic(instr, nir_intrinsic_set_vertex_and_primitive_count);
 }
 
 /**
@@ -59,14 +59,14 @@ nir_gs_count_vertices(const nir_shader *shader)
       if (!function->impl)
          continue;
 
-      /* set_vertex_count intrinsics only appear in predecessors of the
+      /* set_vertex_and_primitive_count intrinsics only appear in predecessors of the
        * end block.  So we don't need to walk all of them.
        */
       set_foreach(function->impl->end_block->predecessors, entry) {
          nir_block *block = (nir_block *) entry->key;
 
          nir_foreach_instr_reverse(instr, block) {
-            nir_intrinsic_instr *intrin = as_set_vertex_count(instr);
+            nir_intrinsic_instr *intrin = as_set_vertex_and_primitive_count(instr);
             if (!intrin)
                continue;
 
@@ -77,7 +77,7 @@ nir_gs_count_vertices(const nir_shader *shader)
             if (count == -1)
                count = nir_src_as_int(intrin->src[0]);
 
-            /* We've found contradictory set_vertex_count intrinsics.
+            /* We've found contradictory set_vertex_and_primitive_count intrinsics.
              * This can happen if there are early-returns in main() and
              * different paths emit different numbers of vertices.
              */
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index 957eeffcae1505cf6e123722da6a9e077e2a3050..049af6f96c16df5e56b69cd548b3cb9ab4daab56 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -352,7 +352,8 @@ intrinsic("end_primitive", indices=[STREAM_ID])
 # unsigned integer source.
 intrinsic("emit_vertex_with_counter", src_comp=[1], indices=[STREAM_ID])
 intrinsic("end_primitive_with_counter", src_comp=[1], indices=[STREAM_ID])
-intrinsic("set_vertex_count", src_comp=[1])
+# Sources: final total vertex count, final total primitive count (may be undef).
+intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
 
 # Atomic counters
 #
diff --git a/src/compiler/nir/nir_lower_gs_intrinsics.c b/src/compiler/nir/nir_lower_gs_intrinsics.c
index 6e0f4da360e931888c51f65eb1414b1c59516563..07a17de50dbddc59066ae4b30647c93879b19803 100644
--- a/src/compiler/nir/nir_lower_gs_intrinsics.c
+++ b/src/compiler/nir/nir_lower_gs_intrinsics.c
@@ -57,6 +57,9 @@
 struct state {
    nir_builder *builder;
    nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
+   nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS]; /* only filled in when count_prims is set */
+   bool per_stream; /* emit a final-count intrinsic for each active stream, not just stream 0 */
+   bool count_prims; /* track primitive counts in addition to vertex counts */
    bool progress;
 };
 
@@ -98,7 +101,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
 
    /* Increment the vertex count by 1 */
    nir_store_var(b, state->vertex_count_vars[stream],
-                 nir_iadd(b, count, nir_imm_int(b, 1)),
+                 nir_iadd_imm(b, count, 1),
                  0x1); /* .x */
 
    nir_pop_if(b, NULL);
@@ -128,6 +131,14 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
    lowered->src[0] = nir_src_for_ssa(count);
    nir_builder_instr_insert(b, &lowered->instr);
 
+   if (state->count_prims) {
+      /* Increment the primitive count by 1 */
+      nir_ssa_def *prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
+      nir_store_var(b, state->primitive_count_vars[stream],
+                    nir_iadd_imm(b, prim_cnt, 1),
+                    0x1); /* .x */
+   }
+
    nir_instr_remove(&intrin->instr);
 
    state->progress = true;
@@ -158,11 +169,11 @@ rewrite_intrinsics(nir_block *block, struct state *state)
 }
 
 /**
- * Add a set_vertex_count intrinsic at the end of the program
- * (representing the final vertex count).
+ * Add set_vertex_and_primitive_count intrinsics at the end of the program,
+ * one per counted stream, carrying the final vertex and primitive counts.
  */
 static void
-append_set_vertex_count(nir_block *end_block, struct state *state)
+append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
 {
    nir_builder *b = state->builder;
    nir_shader *shader = state->builder->shader;
@@ -174,21 +185,44 @@ append_set_vertex_count(nir_block *end_block, struct state *state)
       nir_block *pred = (nir_block *) entry->key;
       b->cursor = nir_after_block_before_jump(pred);
 
-      nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[0]);
-
-      nir_intrinsic_instr *set_vertex_count =
-         nir_intrinsic_instr_create(shader, nir_intrinsic_set_vertex_count);
-      set_vertex_count->src[0] = nir_src_for_ssa(count);
-
-      nir_builder_instr_insert(b, &set_vertex_count->instr);
+      for (unsigned stream = 0; stream < NIR_MAX_XFB_STREAMS; ++stream) {
+         /* When it's not per-stream, we only need to write one variable. */
+         if (!state->per_stream && stream != 0)
+            continue;
+         /* When it's per-stream, make sure not to use inactive streams. */
+         if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream)))
+            continue;
+
+         nir_ssa_def *vtx_cnt = nir_load_var(b, state->vertex_count_vars[stream]);
+         nir_ssa_def *prim_cnt;
+
+         if (state->count_prims)
+            prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
+         else
+            prim_cnt = nir_ssa_undef(b, 1, 32);
+
+         nir_intrinsic_instr *set_cnt_intrin =
+            nir_intrinsic_instr_create(shader,
+               nir_intrinsic_set_vertex_and_primitive_count);
+
+         nir_intrinsic_set_stream_id(set_cnt_intrin, stream);
+         set_cnt_intrin->src[0] = nir_src_for_ssa(vtx_cnt);
+         set_cnt_intrin->src[1] = nir_src_for_ssa(prim_cnt);
+         nir_builder_instr_insert(b, &set_cnt_intrin->instr);
+      }
    }
 }
 
 bool
-nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
+nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options)
 {
+   bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
+   bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
+
    struct state state;
    state.progress = false;
+   state.count_prims = count_primitives;
+   state.per_stream = per_stream;
 
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
@@ -197,8 +231,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
    nir_builder_init(&b, impl);
    state.builder = &b;
 
-   /* Create the counter variables */
    b.cursor = nir_before_cf_list(&impl->body);
+
    for (unsigned i = 0; i < NIR_MAX_XFB_STREAMS; i++) {
       if (per_stream && !(shader->info.gs.active_stream_mask & (1 << i)))
          continue;
@@ -208,12 +242,22 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
             nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
          /* initialize to 0 */
          nir_store_var(&b, state.vertex_count_vars[i], nir_imm_int(&b, 0), 0x1);
+
+         if (count_primitives) {
+            state.primitive_count_vars[i] =
+               nir_local_variable_create(impl, glsl_uint_type(), "primitive_count");
+            /* initialize to 1 -- NOTE(review): presumably for the implicitly ended last primitive; confirm */
+            nir_store_var(&b, state.primitive_count_vars[i], nir_imm_int(&b, 1), 0x1);
+         }
       } else {
-         /* If per_stream is false, we only have one counter which we want to use
-          * for all streams.  Duplicate the counter pointer so all streams use the
-          * same counter.
+         /* If per_stream is false, we only have one counter of each kind which we
+          * want to use for all streams. Duplicate the counter pointers so all
+          * streams use the same counters.
           */
          state.vertex_count_vars[i] = state.vertex_count_vars[0];
+
+         if (count_primitives)
+            state.primitive_count_vars[i] = state.primitive_count_vars[0];
       }
    }
 
@@ -221,8 +265,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
       rewrite_intrinsics(block, &state);
 
    /* This only works because we have a single main() function. */
-   if (!per_stream)
-      append_set_vertex_count(impl->end_block, &state);
+   append_set_vertex_and_primitive_count(impl->end_block, &state);
 
    nir_metadata_preserve(impl, 0);
 
diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp
index 47063178a48a046087255916af856bcad6f83b49..81e7f4ad8d665d68fafef151e5f0bac0b5a79488 100644
--- a/src/intel/compiler/brw_fs_nir.cpp
+++ b/src/intel/compiler/brw_fs_nir.cpp
@@ -3149,7 +3149,7 @@ fs_visitor::nir_emit_gs_intrinsic(const fs_builder &bld,
       emit_gs_end_primitive(instr->src[0]);
       break;
 
-   case nir_intrinsic_set_vertex_count:
+   case nir_intrinsic_set_vertex_and_primitive_count:
       bld.MOV(this->final_gs_vertex_count, get_nir_src(instr->src[0]));
       break;
 
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index 0de36e6798427513bee84c488d6cfd02bdb22701..2b9194b1292a9115dc39552da2da9e4b046a5c67 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -692,7 +692,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
    }
 
    if (nir->info.stage == MESA_SHADER_GEOMETRY)
-      OPT(nir_lower_gs_intrinsics, false);
+      OPT(nir_lower_gs_intrinsics, 0);
 
    /* See also brw_nir_trig_workarounds.py */
    if (compiler->precise_trig &&
diff --git a/src/intel/compiler/brw_vec4_gs_nir.cpp b/src/intel/compiler/brw_vec4_gs_nir.cpp
index d4b5582010bcdec416f88a049bf6a1c191c35552..04ee1d9d27f37b6209f9b3589b117cb6841e83f7 100644
--- a/src/intel/compiler/brw_vec4_gs_nir.cpp
+++ b/src/intel/compiler/brw_vec4_gs_nir.cpp
@@ -78,7 +78,7 @@ vec4_gs_visitor::nir_emit_intrinsic(nir_intrinsic_instr *instr)
       gs_end_primitive();
       break;
 
-   case nir_intrinsic_set_vertex_count:
+   case nir_intrinsic_set_vertex_and_primitive_count:
       this->vertex_count =
          retype(get_nir_src(instr->src[0], 1), BRW_REGISTER_TYPE_UD);
       break;