Mathemagic finale: muldiv

A couple of years back I wrote a series of blog posts introducing some tricky mathemagic, but I never got around to telling the punchline. The goal is to compute

$$ \mathtt{muldiv}\p{\mathtt{a}, \mathtt{b}, \mathtt{denominator}} = \floor{\frac{\mathtt{a} ⋅ \mathtt{b}}{\mathtt{denominator}}} $$

while handling overflow correctly. Our secret weapon is the EVM's mulmod instruction. It multiplies at full precision before reducing, which is exactly what we want, except that it returns the remainder, not the quotient. So what's our strategy?

  1. Compute the 512-bit product $\mathtt{a} ⋅ \mathtt{b}$ using mulmod.
  2. Make the division exact by subtracting the remainder using mulmod.
  3. Remove powers of two from the fraction to make denominator invertible $\operatorname{mod} 2^{256}$.
  4. Compute the modular inverse of the denominator.
  5. Multiply numerator and inverse denominator $\operatorname{mod} 2^{256}$.
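The five steps can be modelled outside the EVM to check the arithmetic. What follows is a rough Python sketch, not the contract itself: Python integers are unbounded, so every 256-bit wraparound is written out with an explicit mask. The names `MASK`, `mulmod` (a stand-in for the EVM instruction) and `flipped` are assumptions made for this illustration.

```python
MASK = (1 << 256) - 1  # EVM words are 256 bits wide


def mulmod(a, b, m):
    """Stand-in for EVM mulmod: (a * b) % m at full precision, 0 when m == 0."""
    return 0 if m == 0 else (a * b) % m


def muldiv(a, b, denominator):
    """Model of floor(a * b / denominator) using only 256-bit word operations."""
    assert 0 <= a <= MASK and 0 <= b <= MASK and 0 < denominator <= MASK

    # Step 1: 512-bit product [prod1 prod0], rebuilt from the product
    # mod 2**256 and mod 2**256 - 1.
    mm = mulmod(a, b, MASK)
    prod0 = (a * b) & MASK
    prod1 = (mm - prod0 - (1 if mm < prod0 else 0)) & MASK
    assert prod1 < denominator, "result does not fit in 256 bits"

    # Step 2: subtract the remainder so the division becomes exact.
    remainder = mulmod(a, b, denominator)
    prod1 = (prod1 - (1 if remainder > prod0 else 0)) & MASK
    prod0 = (prod0 - remainder) & MASK

    # Step 3: divide numerator and denominator by the largest power of two
    # that divides the denominator.
    twos = (-denominator) & denominator
    denominator //= twos
    prod0 //= twos
    flipped = (((-twos) & MASK) // twos + 1) & MASK  # 2**256 / twos, wraps to 0 when twos == 1
    prod0 |= (prod1 * flipped) & MASK

    # Step 4: inverse of the now odd denominator mod 2**256.
    inv = ((3 * denominator) ^ 2) & MASK
    for _ in range(6):
        inv = (inv * (2 - denominator * inv)) & MASK

    # Step 5: multiply by the inverse mod 2**256.
    return (prod0 * inv) & MASK
```

For inputs that satisfy the overflow check, the result can be compared directly against Python's exact `(a * b) // denominator`.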

I will not explain the implementation of each step here; for that, please refer to the previous posts.

Implemented and optimized, the strategy looks as follows:

contract MulDiv {
    function muldiv(uint256 a, uint256 b, uint256 denominator)
    internal pure returns (uint256 result)
    {
        // Handle division by zero
        require(denominator > 0);

        // 512-bit multiply [prod1 prod0] = a * b
        // Compute the product mod 2**256 and mod 2**256 - 1
        // then use the Chinese Remainder Theorem to reconstruct
        // the 512-bit result. The result is stored in two 256-bit
        // variables such that product = prod1 * 2**256 + prod0
        uint256 prod0; // Least significant 256 bits of the product
        uint256 prod1; // Most significant 256 bits of the product
        assembly {
            let mm := mulmod(a, b, not(0))
            prod0 := mul(a, b)
            prod1 := sub(sub(mm, prod0), lt(mm, prod0))
        }
        
        // Short circuit 256 by 256 division
        // This saves gas when a * b is small, at the cost of making the
        // large case a bit more expensive. Depending on your use case you
        // may want to remove this short circuit and always go through the
        // 512 bit path.
        if (prod1 == 0) {
            assembly {
                result := div(prod0, denominator)
            }
            return result;
        }
        
        ///////////////////////////////////////////////
        // 512 by 256 division.
        ///////////////////////////////////////////////
        
        // Handle overflow, the result must be < 2**256
        require(prod1 < denominator);

        // Make division exact by subtracting the remainder from [prod1 prod0]
        // Compute remainder using mulmod
        // Note mulmod(_, _, 0) == 0
        uint256 remainder;
        assembly {
            remainder := mulmod(a, b, denominator)
        }
        // Subtract 256 bit number from 512 bit number
        assembly {
            prod1 := sub(prod1, gt(remainder, prod0))
            prod0 := sub(prod0, remainder)
        }
        
        // Factor powers of two out of denominator
        // Compute largest power of two divisor of denominator.
        // Always >= 1 unless denominator is zero, then twos is zero.
        uint256 twos = -denominator & denominator;
        // Divide denominator by power of two
        assembly {
            denominator := div(denominator, twos)
        }
        
        // Divide [prod1 prod0] by the factors of two
        assembly {
            prod0 := div(prod0, twos)
        }
        // Shift in bits from prod1 into prod0. For this we need
        // to flip `twos` such that it is 2**256 / twos.
        // If twos is zero, then it becomes one
        assembly {
            twos := add(div(sub(0, twos), twos), 1)
        }
        prod0 |= prod1 * twos;
        
        // Invert denominator mod 2**256
        // Now that denominator is an odd number, it has an inverse
        // modulo 2**256 such that denominator * inv = 1 mod 2**256.
        // Compute the inverse by starting with a seed that is
        // correct for four bits. That is, denominator * inv = 1 mod 2**4
        // If denominator is zero the inverse starts with 2
        uint256 inv = 3 * denominator ^ 2;
        // Now use Newton-Raphson iteration to improve the precision.
        // Thanks to Hensel's lifting lemma, this also works in modular
        // arithmetic, doubling the correct bits in each step.
        inv *= 2 - denominator * inv; // inverse mod 2**8
        inv *= 2 - denominator * inv; // inverse mod 2**16
        inv *= 2 - denominator * inv; // inverse mod 2**32
        inv *= 2 - denominator * inv; // inverse mod 2**64
        inv *= 2 - denominator * inv; // inverse mod 2**128
        inv *= 2 - denominator * inv; // inverse mod 2**256
        // If denominator is zero, inv is now 128
        
        // Because the division is now exact we can divide by multiplying
        // with the modular inverse of denominator. This will give us the
        // correct result modulo 2**256. Since the preconditions guarantee
        // that the outcome is less than 2**256, this is the final result.
        // We don't need to compute the high bits of the result and prod1
        // is no longer required.
        result = prod0 * inv;
        return result;
    }
}
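The inversion step is the easiest one to check in isolation. If $\mathtt{denominator} ⋅ \mathtt{inv} = 1 + e$ with $e$ divisible by $2^k$, then one update gives $\mathtt{denominator} ⋅ \mathtt{inv} ⋅ (2 - \mathtt{denominator} ⋅ \mathtt{inv}) = 1 - e^2$, which agrees with one in $2k$ bits. The Python sketch below traces this doubling for the seed and the six updates used in the contract; the helper names `inverse_trace` and `correct_bits` are made up for this illustration, and the 256-bit wraparound is again modelled with an explicit mask.

```python
MASK = (1 << 256) - 1  # model of 256-bit wraparound


def inverse_trace(d):
    """Return the seed followed by the six Newton-Raphson updates for an odd d."""
    assert d & 1, "only odd numbers are invertible mod 2**256"
    inv = ((3 * d) ^ 2) & MASK  # seed, correct in at least four bits
    trace = [inv]
    for _ in range(6):
        inv = (inv * (2 - d * inv)) & MASK
        trace.append(inv)
    return trace


def correct_bits(d, inv):
    """Count the low bits in which d * inv agrees with 1, capped at 256."""
    err = (d * inv - 1) & MASK
    if err == 0:
        return 256
    return (err & -err).bit_length() - 1  # index of the lowest set bit of err


d = 0xDEADBEEF  # arbitrary odd example value
for step, inv in enumerate(inverse_trace(d)):
    print(step, correct_bits(d, inv))
```

Each printed count is at least 4, 8, 16, 32, 64, 128 and finally 256, so the last entry of the trace is the full inverse modulo $2^{256}$.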

A slight variant of the above was briefly considered for inclusion in 0x protocol v2. It turns out that Mikhail Vladimirov independently put the pieces together and came up with the same strategy.

Remco Bloemen
Math & Engineering
https://2π.com