Commit a8ec4082 authored by Rob Clark, committed by Karol Herbst
Browse files

nir+vtn: vec8+vec16 support



This introduces new vec8 and vec16 instructions (which are the only
instructions taking more than 4 sources), in order to construct 8 and 16
component vectors.

In order to avoid fixing up the non-autogenerated nir_build_alu() sites
and making them pass 16 src args for the benefit of the two instructions
that take more than 4 srcs (i.e. vec8 and vec16), nir_build_alu() has
nir_build_alu_tail() split out and re-used by nir_build_alu2() (which is
used for the > 4 src args case).

v2 (Karol Herbst):
  use nir_build_alu2 for vec8 and vec16
  use python's array multiplication syntax
  add nir_op_vec helper
  simplify nir_vec
  nir_build_alu_tail -> nir_builder_alu_instr_finish_and_insert
  use nir_build_alu for opcodes with <= 4 sources
v3 (Karol Herbst):
  fix nir_serialize
v4 (Dave Airlie):
  fix serialization of glsl_type
  handle vec8/16 in lowering of bools
v5 (Karol Herbst):
  fix load store vectorizer
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Dave Airlie <airlied@redhat.com>
parent b35e583c
......@@ -2630,9 +2630,13 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
case GLSL_TYPE_INT64:
case GLSL_TYPE_BOOL:
encoded.basic.interface_row_major = type->interface_row_major;
assert(type->vector_elements < 8);
assert(type->matrix_columns < 8);
encoded.basic.vector_elements = type->vector_elements;
if (type->vector_elements <= 4)
encoded.basic.vector_elements = type->vector_elements;
else if (type->vector_elements == 8)
encoded.basic.vector_elements = 5;
else if (type->vector_elements == 16)
encoded.basic.vector_elements = 6;
encoded.basic.matrix_columns = type->matrix_columns;
encoded.basic.explicit_stride = MIN2(type->explicit_stride, 0xfffff);
blob_write_uint32(blob, encoded.u32);
......@@ -2741,6 +2745,11 @@ decode_type_from_blob(struct blob_reader *blob)
unsigned explicit_stride = encoded.basic.explicit_stride;
if (explicit_stride == 0xfffff)
explicit_stride = blob_read_uint32(blob);
uint32_t vector_elements = encoded.basic.vector_elements;
if (vector_elements == 5)
vector_elements = 8;
else if (vector_elements == 6)
vector_elements = 16;
return glsl_type::get_instance(base_type, vector_elements,
encoded.basic.matrix_columns,
explicit_stride,
......
......@@ -58,10 +58,19 @@ extern "C" {
#define NIR_FALSE 0u
#define NIR_TRUE (~0u)
#define NIR_MAX_VEC_COMPONENTS 4
#define NIR_MAX_VEC_COMPONENTS 16
#define NIR_MAX_MATRIX_COLUMNS 4
#define NIR_STREAM_PACKED (1 << 8)
typedef uint8_t nir_component_mask_t;
typedef uint16_t nir_component_mask_t;
/* Returns whether \p num_components is a vector size NIR supports:
 * scalars through vec4, plus the wide vec8 and vec16 sizes.
 */
static inline bool
nir_num_components_valid(unsigned num_components)
{
   switch (num_components) {
   case 1:
   case 2:
   case 3:
   case 4:
   case 8:
   case 16:
      return true;
   default:
      return false;
   }
}
/** Defines a cast function
*
......@@ -1030,6 +1039,8 @@ nir_op_vec(unsigned components)
case 2: return nir_op_vec2;
case 3: return nir_op_vec3;
case 4: return nir_op_vec4;
case 8: return nir_op_vec8;
case 16: return nir_op_vec16;
default: unreachable("bad component count");
}
}
......
......@@ -874,7 +874,7 @@ nir_ssa_for_src(nir_builder *build, nir_src src, int num_components)
static inline nir_ssa_def *
nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
{
static uint8_t trivial_swizzle[] = { 0, 1, 2, 3 };
static uint8_t trivial_swizzle[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
STATIC_ASSERT(ARRAY_SIZE(trivial_swizzle) == NIR_MAX_VEC_COMPONENTS);
nir_alu_src *src = &instr->src[srcn];
......
......@@ -31,14 +31,22 @@ def src_decl_list(num_srcs):
return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs))
def src_list(num_srcs):
return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
if num_srcs <= 4:
return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
else:
return ', '.join('src' + str(i) for i in range(num_srcs))
%>
% for name, opcode in sorted(opcodes.items()):
static inline nir_ssa_def *
nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)})
{
% if opcode.num_inputs <= 4:
return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)});
% else:
nir_ssa_def *srcs[${opcode.num_inputs}] = {${src_list(opcode.num_inputs)}};
return nir_build_alu_src_arr(build, nir_op_${name}, srcs);
% endif
}
% endfor
......
......@@ -292,6 +292,18 @@ struct ${type}${width}_vec {
${type}${width}_t y;
${type}${width}_t z;
${type}${width}_t w;
${type}${width}_t e;
${type}${width}_t f;
${type}${width}_t g;
${type}${width}_t h;
${type}${width}_t i;
${type}${width}_t j;
${type}${width}_t k;
${type}${width}_t l;
${type}${width}_t m;
${type}${width}_t n;
${type}${width}_t o;
${type}${width}_t p;
};
% endfor
% endfor
......@@ -324,7 +336,7 @@ struct ${type}${width}_vec {
_src[${j}][${k}].${get_const_field(input_types[j])},
% endif
% endfor
% for k in range(op.input_sizes[j], 4):
% for k in range(op.input_sizes[j], 16):
0,
% endfor
};
......@@ -418,18 +430,18 @@ struct ${type}${width}_vec {
% for k in range(op.output_size):
% if output_type == "int1" or output_type == "uint1":
/* 1-bit integers get truncated */
_dst_val[${k}].b = dst.${"xyzw"[k]} & 1;
_dst_val[${k}].b = dst.${"xyzwefghijklmnop"[k]} & 1;
% elif output_type.startswith("bool"):
## Sanitize the C value to a proper NIR 0/-1 bool
_dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzw"[k]};
_dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzwefghijklmnop"[k]};
% elif output_type == "float16":
if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
_dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzw"[k]});
_dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzwefghijklmnop"[k]});
} else {
_dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzw"[k]});
_dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzwefghijklmnop"[k]});
}
% else:
_dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzw"[k]};
_dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzwefghijklmnop"[k]};
% endif
% if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
......
......@@ -117,6 +117,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
return lower_reduction(alu, chan, merge, b); \
switch (alu->op) {
case nir_op_vec16:
case nir_op_vec8:
case nir_op_vec4:
case nir_op_vec3:
case nir_op_vec2:
......
......@@ -56,6 +56,8 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
case nir_op_vec2:
case nir_op_vec3:
case nir_op_vec4:
case nir_op_vec8:
case nir_op_vec16:
/* These we expect to have booleans but the opcode doesn't change */
break;
......
......@@ -53,6 +53,8 @@ lower_alu_instr(nir_alu_instr *alu)
case nir_op_vec2:
case nir_op_vec3:
case nir_op_vec4:
case nir_op_vec8:
case nir_op_vec16:
case nir_op_inot:
case nir_op_iand:
case nir_op_ior:
......
......@@ -75,7 +75,7 @@ class Opcode(object):
assert isinstance(algebraic_properties, str)
assert isinstance(const_expr, str)
assert len(input_sizes) == len(input_types)
assert 0 <= output_size <= 4
assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
for size in input_sizes:
assert 0 <= size <= 4
if output_size != 0:
......@@ -1057,6 +1057,40 @@ dst.z = src2.x;
dst.w = src3.x;
""")
# vec8: constructs an 8-component vector from 8 scalar sources.  Components
# beyond the classic x/y/z/w are named e..h in the constant-folding structs.
opcode("vec8", 8, tuint,
[1] * 8, [tuint] * 8,
False, "", """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
dst.e = src4.x;
dst.f = src5.x;
dst.g = src6.x;
dst.h = src7.x;
""")
# vec16: constructs a 16-component vector from 16 scalar sources.  Components
# beyond the classic x/y/z/w are named e..p in the constant-folding structs.
opcode("vec16", 16, tuint,
[1] * 16, [tuint] * 16,
False, "", """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
dst.e = src4.x;
dst.f = src5.x;
dst.g = src6.x;
dst.h = src7.x;
dst.i = src8.x;
dst.j = src9.x;
dst.k = src10.x;
dst.l = src11.x;
dst.m = src12.x;
dst.n = src13.x;
dst.o = src14.x;
dst.p = src15.x;
""")
# An integer multiply instruction for address calculation. This is
# similar to imul, except that the results are undefined in case of
# overflow. Overflow is defined according to the size of the variable
......
......@@ -643,7 +643,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
return false;
unsigned new_num_components = size / new_bit_size;
if (new_num_components > NIR_MAX_VEC_COMPONENTS)
if (!nir_num_components_valid(new_num_components))
return false;
unsigned high_offset = high->offset_signed - low->offset_signed;
......
......@@ -171,6 +171,12 @@ print_dest(nir_dest *dest, print_state *state)
print_reg_dest(&dest->reg, state);
}
/* Pick the per-channel name alphabet used when printing swizzles and
 * writemasks: classic "xyzw" for values up to 4 components, and the
 * letters 'a'..'p' for wide (vec8/vec16) values so each of the 16
 * channels gets a distinct character.
 */
static const char *
comp_mask_string(unsigned num_components)
{
   if (num_components <= 4)
      return "xyzw";
   return "abcdefghijklmnop";
}
static void
print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
{
......@@ -206,7 +212,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
if (!nir_alu_instr_channel_used(instr, src, i))
continue;
fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]);
fprintf(fp, "%c", comp_mask_string(live_channels)[instr->src[src].swizzle[i]]);
}
}
......@@ -224,10 +230,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state)
if (!dest->dest.is_ssa &&
dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) {
unsigned live_channels = dest->dest.reg.reg->num_components;
fprintf(fp, ".");
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
if ((dest->write_mask >> i) & 1)
fprintf(fp, "%c", "xyzw"[i]);
fprintf(fp, "%c", comp_mask_string(live_channels)[i]);
}
}
......@@ -569,8 +576,8 @@ print_var_decl(nir_variable *var, print_state *state)
switch (var->data.mode) {
case nir_var_shader_in:
case nir_var_shader_out:
if (num_components < 4 && num_components != 0) {
const char *xyzw = "xyzw";
if (num_components < 16 && num_components != 0) {
const char *xyzw = comp_mask_string(num_components);
for (int i = 0; i < num_components; i++)
components_local[i + 1] = xyzw[i + var->data.location_frac];
......@@ -816,9 +823,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
/* special case wrmask to show it as a writemask.. */
unsigned wrmask = nir_intrinsic_write_mask(instr);
fprintf(fp, " wrmask=");
for (unsigned i = 0; i < 4; i++)
for (unsigned i = 0; i < instr->num_components; i++)
if ((wrmask >> i) & 1)
fprintf(fp, "%c", "xyzw"[i]);
fprintf(fp, "%c", comp_mask_string(instr->num_components)[i]);
break;
}
......
......@@ -56,7 +56,13 @@ static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
const struct per_op_table *pass_op_table);
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
/* Identity swizzle table: maps channel i to itself for all
 * NIR_MAX_VEC_COMPONENTS (16) channels, used when a source needs no
 * component reordering.
 */
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
{
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
};
/**
* Check if a source produces a value of the given type.
......
......@@ -128,8 +128,7 @@ static void validate_src(nir_src *src, validate_state *state,
static void
validate_num_components(validate_state *state, unsigned num_components)
{
validate_assert(state, num_components >= 1 &&
num_components <= 4);
validate_assert(state, nir_num_components_valid(num_components));
}
static void
......
......@@ -3819,10 +3819,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityInputAttachment:
case SpvCapabilityImageGatherExtended:
case SpvCapabilityStorageImageExtendedFormats:
case SpvCapabilityVector16:
break;
case SpvCapabilityLinkage:
case SpvCapabilityVector16:
case SpvCapabilityFloat16Buffer:
case SpvCapabilitySparseResidency:
vtn_warn("Unsupported SPIR-V capability: %s",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment