Commit 75dbb404 authored by Timur Kristóf's avatar Timur Kristóf
Browse files

ac/nir: Remove byte permute from prefix sum of the repack sequence.



The byte-permute instruction v_perm_b32 is not exposed by older
LLVM releases (only available on LLVM 13 and later), therefore a new
sequence is needed which we can use with these LLVM versions too.

The prefix sum is replaced by two alternatives:

1. For GPUs that support v_dot, we shift 0x01 to the wanted byte
positions and then use v_dot to sum the results.

2. For older GPUs (Navi 10), we simply shift out the unwanted bytes
and use v_sad_u8 to produce the sum.
Signed-off-by: Timur Kristóf's avatarTimur Kristóf <timur.kristof@gmail.com>
Acked-by: default avatarMarek Olšák <marek.olsak@amd.com>
Part-of: <!12786>
parent 966cff9c
Pipeline #404983 waiting for manual action with stages
......@@ -133,6 +133,70 @@ typedef struct {
nir_ssa_def *repacked_invocation_index;
} wg_repack_result;
/**
* Computes a horizontal sum of 8-bit packed values loaded from LDS.
*
* Each lane N will sum packed bytes 0 to N-1.
* We only care about the results from up to wave_id+1 lanes.
* (Other lanes are not deactivated but their calculation is not used.)
*/
static nir_ssa_def *
summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
{
/* We'll use shift to filter out the bytes not needed by the current lane.
*
* Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
* However, two shifts are needed because one can't go all the way,
* so the shift amount is half that (and in bits).
*
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
* This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
* therefore v_dot can get rid of the unneeded values.
* This sequence is preferable because it better hides the latency of the LDS.
*
* If the v_dot instruction can't be used, we left-shift the packed bytes.
* This will shift out the unneeded bytes and shift in zeroes instead,
* then we sum them using v_sad_u8.
*/
nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
nir_ssa_def *shift = nir_iadd_imm_nuw(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
bool use_dot = b->shader->options->has_dot_4x8;
if (num_lds_dwords == 1) {
nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
/* Horizontally add the packed bytes. */
if (use_dot) {
return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
} else {
nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
}
} else if (num_lds_dwords == 2) {
nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
/* Horizontally add the packed bytes. */
if (use_dot) {
nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
} else {
nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
}
} else {
unreachable("Unimplemented NGG wave count");
}
}
/**
* Repacks invocations in the current workgroup to eliminate gaps between them.
*
......@@ -208,41 +272,7 @@ repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
*/
nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
/* sel = 0x01010101 * lane_id + 0x03020100 */
nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0));
nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
nir_ssa_def *sum = NULL;
if (num_lds_dwords == 1) {
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
/* Use byte-permute to filter out the bytes not needed by the current lane. */
nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel);
/* Horizontally add the packed bytes. */
sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0));
} else if (num_lds_dwords == 2) {
/* Create selectors for the byte-permutes below. */
nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
/* Use byte-permute to filter out the bytes not needed by the current lane. */
nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector);
nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector);
/* Horizontally add the packed bytes. */
sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0));
sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
} else {
unreachable("Unimplemented NGG wave count");
}
nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment