Optimize 64x64 extended multiplication implementation

Now we use intrinsics when possible, and fallback to optimized
implementation in portable C++. The difference is about 4x when
we can use intrinsics and about 2x when we cannot.

This should speed up our Lemire's algorithm implementation nicely.
This commit is contained in:
Martin Hořeňovský 2024-04-02 18:09:34 +02:00
parent f181de9df4
commit d99eb8bec8
No known key found for this signature in database
GPG Key ID: DE48307B8B0D381A
2 changed files with 137 additions and 44 deletions

View File

@ -14,6 +14,32 @@
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
// Note: We use the usual enable-disable-autodetect dance here even though
// we do not support these in CMake configuration options (yet?).
// It is highly unlikely that we will need to make these actually
// user-configurable, but this will make it simpler if weend up needing
// it, and it provides an escape hatch to the users who need it.
#if defined( __SIZEOF_INT128__ )
# define CATCH_CONFIG_INTERNAL_UINT128
#elif defined( _MSC_VER ) && ( defined( _WIN64 ) || defined( _M_ARM64 ) )
# define CATCH_CONFIG_INTERNAL_MSVC_UMUL128
#endif
#if defined( CATCH_CONFIG_INTERNAL_UINT128 ) && \
!defined( CATCH_CONFIG_NO_UINT128 ) && \
!defined( CATCH_CONFIG_UINT128 )
#define CATCH_CONFIG_UINT128
#endif
#if defined( CATCH_CONFIG_INTERNAL_MSVC_UMUL128 ) && \
!defined( CATCH_CONFIG_NO_MSVC_UMUL128 ) && \
!defined( CATCH_CONFIG_MSVC_UMUL128 )
# define CATCH_CONFIG_MSVC_UMUL128
# include <intrin.h>
# pragma intrinsic( _umul128 )
#endif
namespace Catch { namespace Catch {
namespace Detail { namespace Detail {
@ -46,59 +72,52 @@ namespace Catch {
} }
}; };
// Returns 128 bit result of multiplying lhs and rhs /**
* Returns 128 bit result of lhs * rhs using portable C++ code
*
* This implementation is almost twice as fast as naive long multiplication,
* and unlike intrinsic-based approach, it supports constexpr evaluation.
*/
constexpr ExtendedMultResult<std::uint64_t> constexpr ExtendedMultResult<std::uint64_t>
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) { extendedMultPortable(std::uint64_t lhs, std::uint64_t rhs) {
// We use the simple long multiplication approach for
// correctness, we can use platform specific builtins
// for performance later.
// Split the lhs and rhs into two 32bit "digits", so that we can
// do 64 bit arithmetic to handle carry bits.
// 32b 32b 32b 32b
// lhs L1 L2
// * rhs R1 R2
// ------------------------
// | R2 * L2 |
// | R2 * L1 |
// | R1 * L2 |
// | R1 * L1 |
// -------------------------
// | a | b | c | d |
#define CarryBits( x ) ( x >> 32 ) #define CarryBits( x ) ( x >> 32 )
#define Digits( x ) ( x & 0xFF'FF'FF'FF ) #define Digits( x ) ( x & 0xFF'FF'FF'FF )
std::uint64_t lhs_low = Digits( lhs );
std::uint64_t rhs_low = Digits( rhs );
std::uint64_t low_low = ( lhs_low * rhs_low );
std::uint64_t high_high = CarryBits( lhs ) * CarryBits( rhs );
auto r2l2 = Digits( rhs ) * Digits( lhs ); // We add in carry bits from low-low already
auto r2l1 = Digits( rhs ) * CarryBits( lhs ); std::uint64_t high_low =
auto r1l2 = CarryBits( rhs ) * Digits( lhs ); ( CarryBits( lhs ) * rhs_low ) + CarryBits( low_low );
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs ); // Note that we can add only low bits from high_low, to avoid
// overflow with large inputs
// Sum to columns first std::uint64_t low_high =
auto d = Digits( r2l2 ); ( lhs_low * CarryBits( rhs ) ) + Digits( high_low );
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
auto a = CarryBits( r1l1 );
// Propagate carries between columns
c += CarryBits( d );
b += CarryBits( c );
a += CarryBits( b );
// Remove the used carries
c = Digits( c );
b = Digits( b );
a = Digits( a );
return { high_high + CarryBits( high_low ) + CarryBits( low_high ),
( low_high << 32 ) | Digits( low_low ) };
#undef CarryBits #undef CarryBits
#undef Digits #undef Digits
return {
a << 32 | b, // upper 64 bits
c << 32 | d // lower 64 bits
};
} }
//! Returns 128 bit result of lhs * rhs
inline ExtendedMultResult<std::uint64_t>
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
#if defined( CATCH_CONFIG_UINT128 )
auto result = __uint128_t( lhs ) * __uint128_t( rhs );
return { static_cast<std::uint64_t>( result >> 64 ),
static_cast<std::uint64_t>( result ) };
#elif defined( CATCH_CONFIG_MSVC_UMUL128 )
std::uint64_t high;
std::uint64_t low = _umul128( lhs, rhs, &high );
return { high, low };
#else
return extendedMultPortable( lhs, rhs );
#endif
}
template <typename UInt> template <typename UInt>
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) { constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
static_assert( std::is_unsigned<UInt>::value, static_assert( std::is_unsigned<UInt>::value,

View File

@ -8,6 +8,7 @@
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <catch2/internal/catch_random_integer_helpers.hpp> #include <catch2/internal/catch_random_integer_helpers.hpp>
#include <random>
namespace { namespace {
template <typename Int> template <typename Int>
@ -20,6 +21,58 @@ namespace {
CHECK( extendedMult( b, a ) == CHECK( extendedMult( b, a ) ==
ExtendedMultResult<Int>{ upper_result, lower_result } ); ExtendedMultResult<Int>{ upper_result, lower_result } );
} }
// Simple (and slow) implmentation of extended multiplication for tests
constexpr Catch::Detail::ExtendedMultResult<std::uint64_t>
extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) {
// This is a simple long multiplication, where we split lhs and rhs
// into two 32-bit "digits", so that we can do ops with carry in 64-bits.
//
// 32b 32b 32b 32b
// lhs L1 L2
// * rhs R1 R2
// ------------------------
// | R2 * L2 |
// | R2 * L1 |
// | R1 * L2 |
// | R1 * L1 |
// -------------------------
// | a | b | c | d |
#define CarryBits( x ) ( x >> 32 )
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
auto r2l2 = Digits( rhs ) * Digits( lhs );
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
// Sum to columns first
auto d = Digits( r2l2 );
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
auto a = CarryBits( r1l1 );
// Propagate carries between columns
c += CarryBits( d );
b += CarryBits( c );
a += CarryBits( b );
// Remove the used carries
c = Digits( c );
b = Digits( b );
a = Digits( a );
#undef CarryBits
#undef Digits
return {
a << 32 | b, // upper 64 bits
c << 32 | d // lower 64 bits
};
}
} // namespace } // namespace
TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) { TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
0xdf44'2d22'ce48'59b9 ); 0xdf44'2d22'ce48'59b9 );
} }
TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") {
using Catch::Detail::extendedMult;
using Catch::Detail::extendedMultPortable;
using Catch::Detail::fillBitsFrom;
std::random_device rng;
for (size_t i = 0; i < 100; ++i) {
auto a = fillBitsFrom<std::uint64_t>( rng );
auto b = fillBitsFrom<std::uint64_t>( rng );
CAPTURE( a, b );
auto naive_ab = extendedMultNaive( a, b );
REQUIRE( naive_ab == extendedMultNaive( b, a ) );
REQUIRE( naive_ab == extendedMultPortable( a, b ) );
REQUIRE( naive_ab == extendedMultPortable( b, a ) );
REQUIRE( naive_ab == extendedMult( a, b ) );
REQUIRE( naive_ab == extendedMult( b, a ) );
}
}
TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) { TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
using Catch::Detail::SizedUnsignedType_t; using Catch::Detail::SizedUnsignedType_t;
using Catch::Detail::DoubleWidthUnsignedType_t; using Catch::Detail::DoubleWidthUnsignedType_t;