Commit 29a1450e authored by Connor Abbott's avatar Connor Abbott

nir/algebraic: Rewrite bit-size inference

Before this commit, there were two copies of the algorithm: one in C,
that we would use to figure out what bit-size to give the replacement
expression, and one in Python, that emulated the C one and tried to
prove that the C algorithm would never fail to correctly assign
bit-sizes. That seemed pretty fragile, and likely to fall over if we
make any changes. Furthermore, the C code was really just recomputing
more-or-less the same thing as the Python code every time. Instead, we
can just store the results of the Python algorithm in the C
datastructure, and consult it to compute the bitsize of each value,
moving the "brains" entirely into Python. Since the Python algorithm no
longer has to match C, it's also a lot easier to change it to something
more closely approximating an actual type-inference algorithm. The
algorithm used is based on Hindley-Milner, although deliberately
weakened a little. It's a few more lines than the old one, judging by
the diffstat, but I think it's easier to verify that it's correct while
being as general as possible.

We could split this up into two changes, first making the C code use the
results of the Python code and then rewriting the Python algorithm, but
since the old algorithm never tracked which variable each equivalence
class, it would mean we'd have to add some non-trivial code which would
then get thrown away. I think it's better to see the final state all at
once, although I could also try splitting it up.

v2:
- Replace instances of "== None" and "!= None" with "is None" and
"is not None".
- Rename first_src to first_unsized_src
- Only merge the destination with the first unsized source, since the
sources have already been merged.
- Add a comment explaining what nir_search_value::bit_size now means.
v3:
- Fix one last instance to use "is not" instead of !=
- Don't try to be so clever when choosing which error message to print
based on whether we're in the search or replace expression.
- Fix trailing whitespace.
Reviewed-by: Jason Ekstrand's avatarJason Ekstrand <jason@jlekstrand.net>
Reviewed-by: Dylan Baker's avatarDylan Baker <dylan@pnwbakers.com>
parent 49ef8907
This diff is collapsed.
......@@ -118,7 +118,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
/* If the value has a specific bit size and it doesn't match, bail */
if (value->bit_size &&
if (value->bit_size > 0 &&
nir_src_bit_size(instr->src[src].src) != value->bit_size)
return false;
......@@ -228,7 +228,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
assert(instr->dest.dest.is_ssa);
if (expr->value.bit_size &&
if (expr->value.bit_size > 0 &&
instr->dest.dest.ssa.bit_size != expr->value.bit_size)
return false;
......@@ -290,128 +290,21 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
}
}
typedef struct bitsize_tree {
unsigned num_srcs;
struct bitsize_tree *srcs[4];
unsigned common_size;
bool is_src_sized[4];
bool is_dest_sized;
unsigned dest_size;
unsigned src_size[4];
} bitsize_tree;
static bitsize_tree *
build_bitsize_tree(void *mem_ctx, struct match_state *state,
const nir_search_value *value)
{
bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
switch (value->type) {
case nir_search_value_expression: {
nir_search_expression *expr = nir_search_value_as_expression(value);
nir_op_info info = nir_op_infos[expr->opcode];
tree->num_srcs = info.num_inputs;
tree->common_size = 0;
for (unsigned i = 0; i < info.num_inputs; i++) {
tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
if (tree->is_src_sized[i])
tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
}
tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
if (tree->is_dest_sized)
tree->dest_size = nir_alu_type_get_type_size(info.output_type);
break;
}
case nir_search_value_variable: {
nir_search_variable *var = nir_search_value_as_variable(value);
tree->num_srcs = 0;
tree->is_dest_sized = true;
tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
break;
}
case nir_search_value_constant: {
tree->num_srcs = 0;
tree->is_dest_sized = false;
tree->common_size = 0;
break;
}
}
if (value->bit_size) {
assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
tree->common_size = value->bit_size;
}
return tree;
}
static unsigned
bitsize_tree_filter_up(bitsize_tree *tree)
replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
struct match_state *state)
{
for (unsigned i = 0; i < tree->num_srcs; i++) {
unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
if (src_size == 0)
continue;
if (tree->is_src_sized[i]) {
assert(src_size == tree->src_size[i]);
} else if (tree->common_size != 0) {
assert(src_size == tree->common_size);
tree->src_size[i] = src_size;
} else {
tree->common_size = src_size;
tree->src_size[i] = src_size;
}
}
if (tree->num_srcs && tree->common_size) {
if (tree->dest_size == 0)
tree->dest_size = tree->common_size;
else if (!tree->is_dest_sized)
assert(tree->dest_size == tree->common_size);
for (unsigned i = 0; i < tree->num_srcs; i++) {
if (!tree->src_size[i])
tree->src_size[i] = tree->common_size;
}
}
return tree->dest_size;
}
static void
bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
{
if (tree->dest_size)
assert(tree->dest_size == size);
else
tree->dest_size = size;
if (!tree->is_dest_sized) {
if (tree->common_size)
assert(tree->common_size == size);
else
tree->common_size = size;
}
for (unsigned i = 0; i < tree->num_srcs; i++) {
if (!tree->src_size[i]) {
assert(tree->common_size);
tree->src_size[i] = tree->common_size;
}
bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
}
if (value->bit_size > 0)
return value->bit_size;
if (value->bit_size < 0)
return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
return search_bitsize;
}
static nir_alu_src
construct_value(nir_builder *build,
const nir_search_value *value,
unsigned num_components, bitsize_tree *bitsize,
unsigned num_components, unsigned search_bitsize,
struct match_state *state,
nir_instr *instr)
{
......@@ -424,7 +317,7 @@ construct_value(nir_builder *build,
nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
bitsize->dest_size, NULL);
replace_bitsize(value, search_bitsize, state), NULL);
alu->dest.write_mask = (1 << num_components) - 1;
alu->dest.saturate = false;
......@@ -443,7 +336,7 @@ construct_value(nir_builder *build,
num_components = nir_op_infos[alu->op].input_sizes[i];
alu->src[i] = construct_value(build, expr->srcs[i],
num_components, bitsize->srcs[i],
num_components, search_bitsize,
state, instr);
}
......@@ -472,16 +365,17 @@ construct_value(nir_builder *build,
case nir_search_value_constant: {
const nir_search_constant *c = nir_search_value_as_constant(value);
unsigned bit_size = replace_bitsize(value, search_bitsize, state);
nir_ssa_def *cval;
switch (c->type) {
case nir_type_float:
cval = nir_imm_floatN_t(build, c->data.d, bitsize->dest_size);
cval = nir_imm_floatN_t(build, c->data.d, bit_size);
break;
case nir_type_int:
case nir_type_uint:
cval = nir_imm_intN_t(build, c->data.i, bitsize->dest_size);
cval = nir_imm_intN_t(build, c->data.i, bit_size);
break;
case nir_type_bool:
......@@ -526,16 +420,12 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
swizzle, &state))
return NULL;
void *bitsize_ctx = ralloc_context(NULL);
bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
bitsize_tree_filter_up(tree);
bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
build->cursor = nir_before_instr(&instr->instr);
nir_alu_src val = construct_value(build, replace,
instr->dest.dest.ssa.num_components,
tree, &state, &instr->instr);
instr->dest.dest.ssa.bit_size,
&state, &instr->instr);
/* Inserting a mov may be unnecessary. However, it's much easier to
* simply let copy propagation clean this up than to try to go through
......@@ -551,7 +441,5 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
*/
nir_instr_remove(&instr->instr);
ralloc_free(bitsize_ctx);
return ssa_val;
}
......@@ -43,7 +43,22 @@ typedef enum {
typedef struct {
nir_search_value_type type;
unsigned bit_size;
/**
* Bit size of the value. It is interpreted as follows:
*
* For a search expression:
* - If bit_size > 0, then the value only matches an SSA value with the
* given bit size.
* - If bit_size <= 0, then the value matches any size SSA value.
*
* For a replace expression:
* - If bit_size > 0, then the value is constructed with the given bit size.
* - If bit_size == 0, then the value is constructed with the same bit size
* as the search value.
* - If bit_size < 0, then the value is constructed with the same bit size
* as variable (-bit_size - 1).
*/
int bit_size;
} nir_search_value;
typedef struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment