From aba020ef91f4cbef12a5aca9cac779b8d127096c Mon Sep 17 00:00:00 2001
From: Emma Anholt <emma@anholt.net>
Date: Fri, 21 Feb 2025 10:54:04 -0800
Subject: [PATCH] intel: Use the common NIR lowering for fquantize2f16.

This generates one extra instruction to set the rounding mode to
round-to-nearest-even (RTE), due to the f2f16_rtne in the lowering.  This
changes the result for fquantize2f16(65505.0) from 65536 to 65504, which
is perhaps better but still incorrect (should be +inf).
---
 src/intel/compiler/brw_compiler.c        |  1 +
 src/intel/compiler/brw_from_nir.cpp      | 24 ----------------------
 src/intel/compiler/elk/elk_fs_nir.cpp    | 26 ------------------------
 src/intel/compiler/elk/elk_nir_options.c |  1 +
 src/intel/compiler/elk/elk_vec4_nir.cpp  | 24 ----------------------
 5 files changed, 2 insertions(+), 74 deletions(-)

diff --git a/src/intel/compiler/brw_compiler.c b/src/intel/compiler/brw_compiler.c
index fe44a6b09f1f7..4f2c1724f2c8a 100644
--- a/src/intel/compiler/brw_compiler.c
+++ b/src/intel/compiler/brw_compiler.c
@@ -52,6 +52,7 @@ const struct nir_shader_compiler_options brw_scalar_nir_options = {
    .lower_flrp16 = true,
    .lower_flrp64 = true,
    .lower_fmod = true,
+   .lower_fquantize2f16 = true,
    .lower_hadd64 = true,
    .lower_insert_byte = true,
    .lower_insert_word = true,
diff --git a/src/intel/compiler/brw_from_nir.cpp b/src/intel/compiler/brw_from_nir.cpp
index e2498339546c9..ebbbd5733184c 100644
--- a/src/intel/compiler/brw_from_nir.cpp
+++ b/src/intel/compiler/brw_from_nir.cpp
@@ -1526,30 +1526,6 @@ brw_from_nir_emit_alu(nir_to_brw_state &ntb, nir_alu_instr *instr,
       bld.RNDE(result, op[0]);
       break;
 
-   case nir_op_fquantize2f16: {
-      brw_reg tmp16 = bld.vgrf(BRW_TYPE_D);
-      brw_reg tmp32 = bld.vgrf(BRW_TYPE_F);
-
-      /* The destination stride must be at least as big as the source stride. */
-      tmp16 = subscript(tmp16, BRW_TYPE_HF, 0);
-
-      /* Check for denormal */
-      brw_reg abs_src0 = op[0];
-      abs_src0.abs = true;
-      bld.CMP(bld.null_reg_f(), abs_src0, brw_imm_f(ldexpf(1.0, -14)),
-              BRW_CONDITIONAL_L);
-      /* Get the appropriately signed zero */
-      brw_reg zero = retype(bld.AND(retype(op[0], BRW_TYPE_UD),
-                                   brw_imm_ud(0x80000000)), BRW_TYPE_F);
-      /* Do the actual F32 -> F16 -> F32 conversion */
-      bld.MOV(tmp16, op[0]);
-      bld.MOV(tmp32, tmp16);
-      /* Select that or zero based on normal status */
-      inst = bld.SEL(result, zero, tmp32);
-      inst->predicate = BRW_PREDICATE_NORMAL;
-      break;
-   }
-
    case nir_op_imin:
    case nir_op_umin:
    case nir_op_fmin:
diff --git a/src/intel/compiler/elk/elk_fs_nir.cpp b/src/intel/compiler/elk/elk_fs_nir.cpp
index 6b78d6cf9f8eb..a6cb321a2be5d 100644
--- a/src/intel/compiler/elk/elk_fs_nir.cpp
+++ b/src/intel/compiler/elk/elk_fs_nir.cpp
@@ -1545,32 +1545,6 @@ fs_nir_emit_alu(nir_to_elk_state &ntb, nir_alu_instr *instr,
       }
       break;
 
-   case nir_op_fquantize2f16: {
-      elk_fs_reg tmp16 = bld.vgrf(ELK_REGISTER_TYPE_D);
-      elk_fs_reg tmp32 = bld.vgrf(ELK_REGISTER_TYPE_F);
-      elk_fs_reg zero = bld.vgrf(ELK_REGISTER_TYPE_F);
-
-      /* The destination stride must be at least as big as the source stride. */
-      tmp16 = subscript(tmp16, ELK_REGISTER_TYPE_HF, 0);
-
-      /* Check for denormal */
-      elk_fs_reg abs_src0 = op[0];
-      abs_src0.abs = true;
-      bld.CMP(bld.null_reg_f(), abs_src0, elk_imm_f(ldexpf(1.0, -14)),
-              ELK_CONDITIONAL_L);
-      /* Get the appropriately signed zero */
-      bld.AND(retype(zero, ELK_REGISTER_TYPE_UD),
-              retype(op[0], ELK_REGISTER_TYPE_UD),
-              elk_imm_ud(0x80000000));
-      /* Do the actual F32 -> F16 -> F32 conversion */
-      bld.F32TO16(tmp16, op[0]);
-      bld.F16TO32(tmp32, tmp16);
-      /* Select that or zero based on normal status */
-      inst = bld.SEL(result, zero, tmp32);
-      inst->predicate = ELK_PREDICATE_NORMAL;
-      break;
-   }
-
    case nir_op_imin:
    case nir_op_umin:
    case nir_op_fmin:
diff --git a/src/intel/compiler/elk/elk_nir_options.c b/src/intel/compiler/elk/elk_nir_options.c
index 59141d82a18ae..81c0362804b9b 100644
--- a/src/intel/compiler/elk/elk_nir_options.c
+++ b/src/intel/compiler/elk/elk_nir_options.c
@@ -18,6 +18,7 @@
    .lower_usub_borrow = true,                                                 \
    .lower_flrp64 = true,                                                      \
    .lower_fisnormal = true,                                                   \
+   .lower_fquantize2f16 = true,                                               \
    .lower_isign = true,                                                       \
    .lower_ldexp = true,                                                       \
    .lower_bitfield_extract = true,                                            \
diff --git a/src/intel/compiler/elk/elk_vec4_nir.cpp b/src/intel/compiler/elk/elk_vec4_nir.cpp
index abe0ba2b962aa..d258ec2b5c46e 100644
--- a/src/intel/compiler/elk/elk_vec4_nir.cpp
+++ b/src/intel/compiler/elk/elk_vec4_nir.cpp
@@ -1355,30 +1355,6 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       }
       break;
 
-   case nir_op_fquantize2f16: {
-      /* See also vec4_visitor::emit_pack_half_2x16() */
-      src_reg tmp16 = src_reg(this, glsl_uvec4_type());
-      src_reg tmp32 = src_reg(this, glsl_vec4_type());
-      src_reg zero = src_reg(this, glsl_vec4_type());
-
-      /* Check for denormal */
-      src_reg abs_src0 = op[0];
-      abs_src0.abs = true;
-      emit(CMP(dst_null_f(), abs_src0, elk_imm_f(ldexpf(1.0, -14)),
-               ELK_CONDITIONAL_L));
-      /* Get the appropriately signed zero */
-      emit(AND(retype(dst_reg(zero), ELK_REGISTER_TYPE_UD),
-               retype(op[0], ELK_REGISTER_TYPE_UD),
-               elk_imm_ud(0x80000000)));
-      /* Do the actual F32 -> F16 -> F32 conversion */
-      emit(F32TO16(dst_reg(tmp16), op[0]));
-      emit(F16TO32(dst_reg(tmp32), tmp16));
-      /* Select that or zero based on normal status */
-      inst = emit(ELK_OPCODE_SEL, dst, zero, tmp32);
-      inst->predicate = ELK_PREDICATE_NORMAL;
-      break;
-   }
-
    case nir_op_imin:
    case nir_op_umin:
       assert(instr->def.bit_size < 64);
-- 
GitLab