Commit 0db124fa authored by Connor Abbott's avatar Connor Abbott
Browse files

bifrost: Add exponent

parent 6485cd7d
......@@ -884,7 +884,7 @@ output = FMA_MSCALE.sqrt_mode a1, r1, -0, e
== Natural & Base-2 Log (G71)
As described in the patent, the argument reduction step for log depends on the fact that `log(a * b) = log(a) + log(b)`. First, we reduce the input `a` to the range [0.75, 1.5) and calculate the required fixup exponent using the `FLOG_FREPXE` instruction. This fixup exponent is converted to a floating-point number, and is then added to the result later (after multiplying by ln(2) for base-e log), using the fact that `log(m * 2^e) = e + log(m)`. We lookup an very coarse-grained reciprocal estimate `r1` to the reduced input `a1` in a table using the `FRCP_APPROX` instruction, and we also lookup `-log(r1)` using another table that stores for each input `x`, `log(FRCP_APPROX(x))`. Both these instructions take `a`, and do the reduction to [0.75, 1.5) implicitly. Since `log(a1) = log(a1 * r1) - log(r1)`, all we need to do to compute `log(a1)` is compute `log(a1 * r1)`. Since `a1 * r1` is close to 1 by design, we can do this with a polynomial approximation of `log(y + 1)` with only a few terms. The approximation chosen for base _e_ is `log(y + 1) = y * (1.0 + y * (a + y * b))` where the precise constants used are `b = 0x3eab3200 (0.33436)` and `a = 0xbf0003f0 (-0.500060)`, close to the Taylor series terms 1/3 and -1/2 but not exactly the same. The base 2 logarithm computes `y' = y * c` where `c` is a single-precision approximation to `1/ln(2)` and then `log2(y + 1) = y' + y * (c + y' * (y * a + b))` where `d = 0x32a57060 (1.29249 * 2^{-26})`. `d` holds the error due to approximating `1/ln(2)` as a single-precision floating point number, so that `c + d`, when evaluated at infinite precision, is a much better approximation to `1/ln(2)` than `c` by itself.
As described in the patent, the argument reduction step for log depends on the fact that `log(a * b) = log(a) + log(b)`. First, we reduce the input `a` to the range [0.75, 1.5) and calculate the required fixup exponent using the `FLOG_FREPXE` instruction. This fixup exponent is converted to a floating-point number, and is then added to the result later (after multiplying by ln(2) for base-e log), using the fact that `log(m * 2^e) = e + log(m)`. We lookup a very coarse-grained reciprocal estimate `r1` to the reduced input `a1` in a table using the `FRCP_APPROX` instruction, and we also lookup `-log(r1)` using another table that stores for each input `x`, `-log(FRCP_APPROX(x))`. Both these instructions take `a`, and do the reduction to [0.75, 1.5) implicitly. Since `log(a1) = log(a1 * r1) - log(r1)`, all we need to do to compute `log(a1)` is compute `log(a1 * r1)`. Since `a1 * r1` is close to 1 by design, we can do this with a polynomial approximation of `log(y + 1)` with only a few terms. The approximation chosen for base _e_ is `log(y + 1) = y * (1.0 + y * (a + y * b))` where the precise constants used are `b = 0x3eab3200 (0.33436)` and `a = 0xbf0003f0 (-0.500060)`, close to the Taylor series terms 1/3 and -1/2 but not exactly the same. The base 2 logarithm computes `y' = y * c` where `c` is a single-precision approximation to `1/ln(2)` and then `log2(y + 1) = y' + y * (c + y' * (y * a + b))` where `d = 0x32a57060 (1.29249 * 2^{-26})`. `d` holds the error due to approximating `1/ln(2)` as a single-precision floating point number, so that `c + d`, when evaluated at infinite precision, is a much better approximation to `1/ln(2)` than `c` by itself.
base 2 logarithm:
......@@ -922,4 +922,67 @@ output = FMA.f32 {R0, T0}, y, t2, x1
== Natural & Base-2 Logarithm (later)
== Exponent (G71)
To compute the exponent, we want to reduce the input latexmath:[a] into a multiple latexmath:[a_1] of latexmath:[\log_2(b)], where latexmath:[b] here is the base, plus a remainder latexmath:[a_2], since latexmath:[b^a = 2^{a_1} b^{a_2}], and as usual, the final multiplication by latexmath:[2^{a_1}] can be done easily using a `*_MSCALE` instruction. In fact, to reduce the range required for the approximation to latexmath:[b^{a_2}], we can go even farther and decompose the input into a multiple of latexmath:[2^{-4} \log_2(b)] instead. latexmath:[2^{a_1}] can be computed using a small lookup table for the 4 fractional bits, plus the usual post-correction for the integral bits. Finally, we compute latexmath:[b^{a_2}] using a polynomial approximation.
In order to do the decomposition, we use a well-known floating-point trick. Letting latexmath:[m] be the number of mantissa bits, we multiply latexmath:[a] by latexmath:[\log_2(b)^{-1}] and then add latexmath:[f = 1.5 \times 2^{m-4}]. If we assume that latexmath:[-2^{m-5} \leq a \times \log_2(b)^{-1} < 2^{m-5}], then the result will be between latexmath:[2^{m-4}] and latexmath:[2^{m-3}]. In particular, the floating-point exponent will be latexmath:[m-4], and the mantissa will be latexmath:[2^{m-3} + a_1]. We can recover latexmath:[a_1] as an integer by reinterpreting the floating-point number as an integer and subtracting latexmath:[1.5 \times 2^{m-4}] reinterpreted as an integer, since this will remove the exponent bits and remove the bias of latexmath:[2^{m-3}] from the mantissa bits. In order to extract the remainder latexmath:[a_2], we undo the addition and scaling, which will produce latexmath:[2^{-4} \times \log_2(b) \times a_1], and then subtract the result from latexmath:[a]. The result of this addition is clamped from -1 to 1, presumably to prevent too-large offsets.
There are two special cases to worry about: when latexmath:[a < -2^{m-5}], and when latexmath:[a > 2^{m-5}]. In the first case, as long as the result of adding latexmath:[f] is positive, then the exponent will be smaller, so our latexmath:[a_1] will be a large negative number, and after adding it to the exponent, we will get 0 as desired. However, if the result is negative, we will get a large positive number due to the sign bit, which is not what we want. So we need to clamp the result so that it does not go below 0 (thankfully, there is the `.clamp_0_inf` modifier which does this without any additional instructions). If latexmath:[a > 2^{m-5}], then the exponent will be larger than expected, so we'll get a large positive number for latexmath:[a_1] which will give us infinity as desired.
The blob actually clamps the result to latexmath:[2^{m + 64}], using the `ADD_MSCALE` instruction to scale by latexmath:[2^{-64}] after adding latexmath:[f] combined with the `.clamp_0_1` modifier to then clamp the result between 0 and 1. It also adjusts the subsequent steps to account for the extra scaling. However, tests indicate that this is unnecessary.
The polynomial approximation used for base 2 is latexmath:[(a_2 \times ((a_2 \times a + b) \times a_2 + c)) \times 2^{a_{1f}} + 2^{a_{1f}}], where the constants are `a = 0x3d635635 (0.05502)`, `b = 0x3e75fffa (0.240234)`, and `c = 0x3f317218 (0.693147)`. Note how the final multiplication by latexmath:[2^{a_{1f}}], the result of the table lookup on the fractional part of latexmath:[a_1], has been folded into the expression. The polynomial approximation for base _e_ is latexmath:[(a_2 \times a_2 \times (a_2 \times a + b) + a_2) \times 2^{a_{1f}} + 2^{a_{1f}}], where `a = 0x3e2aaacd (0.166667)` and `b = 0x3f00010e (0.500016)`.
The final fused multiply-add is actually an `FMA_MSCALE` instruction, with the exponent bias from the integer part of latexmath:[a_1] added in. In addition, the result is clamped from going below 0, presumably to prevent any small errors from making the result go below 0 for very negative inputs. Finally, since all the clamping done earlier flushes NaN's to zero, but we want the output to be NaN if the input is NaN, we take the maximum of the original input and the output. Normally this wouldn't do anything, since latexmath:[e^x > x] and latexmath:[2^x > x], but we use a special `.nan_wins` modifier which makes sure that the output is NaN if either input is NaN, making sure that the NaN is propagated correctly.
Base 2 exponent as implemented by the blob:
[source]
----
t1 = ADD_MSCALE.f32.clamp_0_1 a, 0x49400000 /* 786432.000000 */, -0x40
t2 = ADD_MSCALE.f32 t1, 0xa9400000, 0x40
a2 = ADD.f32.clamp_m1_1 a, -t2
a1t = EXP_TABLE t1
t3 = SUB.i32 t1, 0x29400000 /* 0.000000 */
a1i = ARSHIFT t3, 4
p1 = FMA.f32 a2, 0x3d635635 /* 0.055502 */, 0x3e75fffa /* 0.240234 */
p2 = FMA.f32 p1, a2, 0x3f317218 /* 0.693147 */
p3 = FMA.f32 a2, p2, -0
x = FMA_MSCALE.clamp_0_inf p3, a1t, a1t, a1i
x' = MAX.f32.nan_wins x, a
----
However, the following version was tested to return the same result for every possible input, and is a little simpler:
[source]
----
t1 = ADD.f32.clamp_0_inf a, 0x49400000 /* 786432.000000 */
t2 = ADD.f32 t1, 0xc9400000
a2 = ADD.f32.clamp_m1_1 a, -t2
a1t = EXP_TABLE t1
t3 = SUB.i32 t1, 0x49400000 /* 0.000000 */
a1i = ARSHIFT t3, 4
p1 = FMA.f32 a2, 0x3d635635 /* 0.055502 */, 0x3e75fffa /* 0.240234 */
p2 = FMA.f32 p1, a2, 0x3f317218 /* 0.693147 */
p3 = FMA.f32 a2, p2, -0
x = FMA_MSCALE.clamp_0_inf p3, a1t, a1t, a1i
x' = MAX.f32.nan_wins x, a
----
Base _e_ exponent:
[source]
----
t1 = FMA.f32.clamp_0_1 a, 0x31b8aa3b, 0x3b400000
t2 = ADD.f32 t1, 0xbb400000
a2 = FMA_MSCALE.clamp_m1_1 t2, 0xcd317218 /* -186065280.000000 */, a, 0
a1t = EXP_TABLE t1
a1i = ARSHIFT t3, 4
t3 = SUB.i32 t1, 0x3b400000
p1 = FMA.f32 a2, 0x3e2aaacd /* 0.166667 */, 0x3f00010e /* 0.500016 */
p2 = FMA.f32 a2, p2, -0
p3 = FMA.f32 a2, p2, a2
x = FMA_MSCALE.clamp_0_inf p3, a1t, a1t, a1i
x' = MAX.f32.nan_wins x, a
----
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment