Commit 05af952a authored by Jason Ekstrand's avatar Jason Ekstrand

nir/algebraic: Add support for unsized conversion opcodes

All conversion opcodes require a destination size but this makes
constructing certain algebraic expressions rather cumbersome.  This
commit adds support to nir_search and nir_algebraic for writing
conversion opcodes without a size.  These meta-opcodes match any
conversion of that type regardless of destination size and the size gets
inferred from the sizes of the things being matched or from other
opcodes in the expression.
Reviewed-by: Connor Abbott's avatarConnor Abbott <cwabbott0@gmail.com>
parent 4925290a
......@@ -33,7 +33,19 @@ import mako.template
import re
import traceback
from nir_opcodes import opcodes
from nir_opcodes import opcodes, type_sizes
# These opcodes are only employed by nir_search. This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
'i2f' : 'float',
'u2f' : 'float',
'f2f' : 'float',
'f2u' : 'uint',
'f2i' : 'int',
'u2u' : 'uint',
'i2i' : 'int',
}
if sys.version_info < (3, 0):
integer_types = (int, long)
......@@ -98,7 +110,7 @@ static const ${val.c_type} ${val.name} = {
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
${val.c_opcode()},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
......@@ -276,6 +288,18 @@ class Expression(Value):
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
for (i, src) in enumerate(expr[1:]) ]
if self.opcode in conv_opcode_types:
assert self._bit_size is None, \
'Expression cannot use an unsized conversion opcode with ' \
'an explicit size; that\'s silly.'
def c_opcode(self):
if self.opcode in conv_opcode_types:
return 'nir_search_op_' + self.opcode
else:
return 'nir_op_' + self.opcode
def render(self):
srcs = "\n".join(src.render() for src in self.sources)
return srcs + super(Expression, self).render()
......@@ -462,6 +486,17 @@ class BitSizeValidator(object):
if not isinstance(val, Expression):
return
# Generic conversion ops are special in that they have a single unsized
# source and an unsized destination and the two don't have to match.
# This means there's no validation or unioning to do here besides the
# len(val.sources) check.
if val.opcode in conv_opcode_types:
assert len(val.sources) == 1, \
"Expression {} has {} sources, expected 1".format(
val, len(val.sources))
self.validate_value(val.sources[0])
return
nir_op = opcodes[val.opcode]
assert len(val.sources) == nir_op.num_inputs, \
"Expression {} has {} sources, expected {}".format(
......@@ -732,7 +767,13 @@ class AlgebraicPass(object):
continue
self.xforms.append(xform)
self.opcode_xforms[xform.search.opcode].append(xform)
if xform.search.opcode in conv_opcode_types:
dst_type = conv_opcode_types[xform.search.opcode]
for size in type_sizes(dst_type):
sized_opcode = xform.search.opcode + str(size)
self.opcode_xforms[sized_opcode].append(xform)
else:
self.opcode_xforms[xform.search.opcode].append(xform)
if error:
sys.exit(1)
......
......@@ -89,6 +89,82 @@ src_is_type(nir_src src, nir_alu_type type)
return false;
}
static bool
nir_op_matches_search_op(nir_op nop, uint16_t sop)
{
if (sop <= nir_last_opcode)
return nop == sop;
#define MATCH_FCONV_CASE(op) \
case nir_search_op_##op: \
return nop == nir_op_##op##16 || \
nop == nir_op_##op##32 || \
nop == nir_op_##op##64;
#define MATCH_ICONV_CASE(op) \
case nir_search_op_##op: \
return nop == nir_op_##op##8 || \
nop == nir_op_##op##16 || \
nop == nir_op_##op##32 || \
nop == nir_op_##op##64;
switch (sop) {
MATCH_FCONV_CASE(i2f)
MATCH_FCONV_CASE(u2f)
MATCH_FCONV_CASE(f2f)
MATCH_ICONV_CASE(f2u)
MATCH_ICONV_CASE(f2i)
MATCH_ICONV_CASE(u2u)
MATCH_ICONV_CASE(i2i)
default:
unreachable("Invalid nir_search_op");
}
#undef MATCH_FCONV_CASE
#undef MATCH_ICONV_CASE
}
static nir_op
nir_op_for_search_op(uint16_t sop, unsigned bit_size)
{
if (sop <= nir_last_opcode)
return sop;
#define RET_FCONV_CASE(op) \
case nir_search_op_##op: \
switch (bit_size) { \
case 16: return nir_op_##op##16; \
case 32: return nir_op_##op##32; \
case 64: return nir_op_##op##64; \
default: unreachable("Invalid bit size"); \
}
#define RET_ICONV_CASE(op) \
case nir_search_op_##op: \
switch (bit_size) { \
case 8: return nir_op_##op##8; \
case 16: return nir_op_##op##16; \
case 32: return nir_op_##op##32; \
case 64: return nir_op_##op##64; \
default: unreachable("Invalid bit size"); \
}
switch (sop) {
RET_FCONV_CASE(i2f)
RET_FCONV_CASE(u2f)
RET_FCONV_CASE(f2f)
RET_ICONV_CASE(f2u)
RET_ICONV_CASE(f2i)
RET_ICONV_CASE(u2u)
RET_ICONV_CASE(i2i)
default:
unreachable("Invalid nir_search_op");
}
#undef RET_FCONV_CASE
#undef RET_ICONV_CASE
}
static bool
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
unsigned num_components, const uint8_t *swizzle,
......@@ -223,7 +299,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
if (expr->cond && !expr->cond(instr))
return false;
if (instr->op != expr->opcode)
if (!nir_op_matches_search_op(instr->op, expr->opcode))
return false;
assert(instr->dest.dest.is_ssa);
......@@ -311,13 +387,15 @@ construct_value(nir_builder *build,
switch (value->type) {
case nir_search_value_expression: {
const nir_search_expression *expr = nir_search_value_as_expression(value);
unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
if (nir_op_infos[expr->opcode].output_size != 0)
num_components = nir_op_infos[expr->opcode].output_size;
if (nir_op_infos[op].output_size != 0)
num_components = nir_op_infos[op].output_size;
nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
replace_bitsize(value, search_bitsize, state), NULL);
dst_bit_size, NULL);
alu->dest.write_mask = (1 << num_components) - 1;
alu->dest.saturate = false;
......@@ -328,7 +406,7 @@ construct_value(nir_builder *build,
*/
alu->exact = state->has_exact_alu;
for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
/* If the source is an explicitly sized source, then we need to reset
* the number of components to match.
*/
......
......@@ -109,6 +109,16 @@ typedef struct {
} data;
} nir_search_constant;
enum nir_search_op {
nir_search_op_i2f = nir_last_opcode + 1,
nir_search_op_u2f,
nir_search_op_f2f,
nir_search_op_f2u,
nir_search_op_f2i,
nir_search_op_u2u,
nir_search_op_i2i,
};
typedef struct {
nir_search_value value;
......@@ -118,7 +128,8 @@ typedef struct {
*/
bool inexact;
nir_op opcode;
/* One of nir_op or nir_search_op */
uint16_t opcode;
const nir_search_value *srcs[4];
/** Optional condition fxn ptr
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment