mirror of
https://github.com/catchorg/Catch2.git
synced 2025-01-22 00:43:28 +01:00
Optimize 64x64 extended multiplication implementation
Now we use intrinsics when possible, and fallback to optimized implementation in portable C++. The difference is about 4x when we can use intrinsics and about 2x when we cannot. This should speed up our Lemire's algorithm implementation nicely.
This commit is contained in:
parent
f181de9df4
commit
d99eb8bec8
@ -14,6 +14,32 @@
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
// Note: We use the usual enable-disable-autodetect dance here even though
|
||||
// we do not support these in CMake configuration options (yet?).
|
||||
// It is highly unlikely that we will need to make these actually
|
||||
// user-configurable, but this will make it simpler if weend up needing
|
||||
// it, and it provides an escape hatch to the users who need it.
|
||||
#if defined( __SIZEOF_INT128__ )
|
||||
# define CATCH_CONFIG_INTERNAL_UINT128
|
||||
#elif defined( _MSC_VER ) && ( defined( _WIN64 ) || defined( _M_ARM64 ) )
|
||||
# define CATCH_CONFIG_INTERNAL_MSVC_UMUL128
|
||||
#endif
|
||||
|
||||
#if defined( CATCH_CONFIG_INTERNAL_UINT128 ) && \
|
||||
!defined( CATCH_CONFIG_NO_UINT128 ) && \
|
||||
!defined( CATCH_CONFIG_UINT128 )
|
||||
#define CATCH_CONFIG_UINT128
|
||||
#endif
|
||||
|
||||
#if defined( CATCH_CONFIG_INTERNAL_MSVC_UMUL128 ) && \
|
||||
!defined( CATCH_CONFIG_NO_MSVC_UMUL128 ) && \
|
||||
!defined( CATCH_CONFIG_MSVC_UMUL128 )
|
||||
# define CATCH_CONFIG_MSVC_UMUL128
|
||||
# include <intrin.h>
|
||||
# pragma intrinsic( _umul128 )
|
||||
#endif
|
||||
|
||||
|
||||
namespace Catch {
|
||||
namespace Detail {
|
||||
|
||||
@ -46,59 +72,52 @@ namespace Catch {
|
||||
}
|
||||
};
|
||||
|
||||
// Returns 128 bit result of multiplying lhs and rhs
|
||||
/**
|
||||
* Returns 128 bit result of lhs * rhs using portable C++ code
|
||||
*
|
||||
* This implementation is almost twice as fast as naive long multiplication,
|
||||
* and unlike intrinsic-based approach, it supports constexpr evaluation.
|
||||
*/
|
||||
constexpr ExtendedMultResult<std::uint64_t>
|
||||
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||
// We use the simple long multiplication approach for
|
||||
// correctness, we can use platform specific builtins
|
||||
// for performance later.
|
||||
|
||||
// Split the lhs and rhs into two 32bit "digits", so that we can
|
||||
// do 64 bit arithmetic to handle carry bits.
|
||||
// 32b 32b 32b 32b
|
||||
// lhs L1 L2
|
||||
// * rhs R1 R2
|
||||
// ------------------------
|
||||
// | R2 * L2 |
|
||||
// | R2 * L1 |
|
||||
// | R1 * L2 |
|
||||
// | R1 * L1 |
|
||||
// -------------------------
|
||||
// | a | b | c | d |
|
||||
|
||||
extendedMultPortable(std::uint64_t lhs, std::uint64_t rhs) {
|
||||
#define CarryBits( x ) ( x >> 32 )
|
||||
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
||||
std::uint64_t lhs_low = Digits( lhs );
|
||||
std::uint64_t rhs_low = Digits( rhs );
|
||||
std::uint64_t low_low = ( lhs_low * rhs_low );
|
||||
std::uint64_t high_high = CarryBits( lhs ) * CarryBits( rhs );
|
||||
|
||||
auto r2l2 = Digits( rhs ) * Digits( lhs );
|
||||
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
|
||||
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
|
||||
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
|
||||
|
||||
// Sum to columns first
|
||||
auto d = Digits( r2l2 );
|
||||
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
|
||||
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
|
||||
auto a = CarryBits( r1l1 );
|
||||
|
||||
// Propagate carries between columns
|
||||
c += CarryBits( d );
|
||||
b += CarryBits( c );
|
||||
a += CarryBits( b );
|
||||
|
||||
// Remove the used carries
|
||||
c = Digits( c );
|
||||
b = Digits( b );
|
||||
a = Digits( a );
|
||||
// We add in carry bits from low-low already
|
||||
std::uint64_t high_low =
|
||||
( CarryBits( lhs ) * rhs_low ) + CarryBits( low_low );
|
||||
// Note that we can add only low bits from high_low, to avoid
|
||||
// overflow with large inputs
|
||||
std::uint64_t low_high =
|
||||
( lhs_low * CarryBits( rhs ) ) + Digits( high_low );
|
||||
|
||||
return { high_high + CarryBits( high_low ) + CarryBits( low_high ),
|
||||
( low_high << 32 ) | Digits( low_low ) };
|
||||
#undef CarryBits
|
||||
#undef Digits
|
||||
|
||||
return {
|
||||
a << 32 | b, // upper 64 bits
|
||||
c << 32 | d // lower 64 bits
|
||||
};
|
||||
}
|
||||
|
||||
//! Returns 128 bit result of lhs * rhs
|
||||
inline ExtendedMultResult<std::uint64_t>
|
||||
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||
#if defined( CATCH_CONFIG_UINT128 )
|
||||
auto result = __uint128_t( lhs ) * __uint128_t( rhs );
|
||||
return { static_cast<std::uint64_t>( result >> 64 ),
|
||||
static_cast<std::uint64_t>( result ) };
|
||||
#elif defined( CATCH_CONFIG_MSVC_UMUL128 )
|
||||
std::uint64_t high;
|
||||
std::uint64_t low = _umul128( lhs, rhs, &high );
|
||||
return { high, low };
|
||||
#else
|
||||
return extendedMultPortable( lhs, rhs );
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
template <typename UInt>
|
||||
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
|
||||
static_assert( std::is_unsigned<UInt>::value,
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/internal/catch_random_integer_helpers.hpp>
|
||||
#include <random>
|
||||
|
||||
namespace {
|
||||
template <typename Int>
|
||||
@ -20,6 +21,58 @@ namespace {
|
||||
CHECK( extendedMult( b, a ) ==
|
||||
ExtendedMultResult<Int>{ upper_result, lower_result } );
|
||||
}
|
||||
|
||||
// Simple (and slow) implmentation of extended multiplication for tests
|
||||
constexpr Catch::Detail::ExtendedMultResult<std::uint64_t>
|
||||
extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||
// This is a simple long multiplication, where we split lhs and rhs
|
||||
// into two 32-bit "digits", so that we can do ops with carry in 64-bits.
|
||||
//
|
||||
// 32b 32b 32b 32b
|
||||
// lhs L1 L2
|
||||
// * rhs R1 R2
|
||||
// ------------------------
|
||||
// | R2 * L2 |
|
||||
// | R2 * L1 |
|
||||
// | R1 * L2 |
|
||||
// | R1 * L1 |
|
||||
// -------------------------
|
||||
// | a | b | c | d |
|
||||
|
||||
#define CarryBits( x ) ( x >> 32 )
|
||||
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
||||
|
||||
auto r2l2 = Digits( rhs ) * Digits( lhs );
|
||||
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
|
||||
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
|
||||
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
|
||||
|
||||
// Sum to columns first
|
||||
auto d = Digits( r2l2 );
|
||||
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
|
||||
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
|
||||
auto a = CarryBits( r1l1 );
|
||||
|
||||
// Propagate carries between columns
|
||||
c += CarryBits( d );
|
||||
b += CarryBits( c );
|
||||
a += CarryBits( b );
|
||||
|
||||
// Remove the used carries
|
||||
c = Digits( c );
|
||||
b = Digits( b );
|
||||
a = Digits( a );
|
||||
|
||||
#undef CarryBits
|
||||
#undef Digits
|
||||
|
||||
return {
|
||||
a << 32 | b, // upper 64 bits
|
||||
c << 32 | d // lower 64 bits
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
||||
@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
||||
0xdf44'2d22'ce48'59b9 );
|
||||
}
|
||||
|
||||
TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") {
|
||||
using Catch::Detail::extendedMult;
|
||||
using Catch::Detail::extendedMultPortable;
|
||||
using Catch::Detail::fillBitsFrom;
|
||||
|
||||
std::random_device rng;
|
||||
for (size_t i = 0; i < 100; ++i) {
|
||||
auto a = fillBitsFrom<std::uint64_t>( rng );
|
||||
auto b = fillBitsFrom<std::uint64_t>( rng );
|
||||
CAPTURE( a, b );
|
||||
|
||||
auto naive_ab = extendedMultNaive( a, b );
|
||||
|
||||
REQUIRE( naive_ab == extendedMultNaive( b, a ) );
|
||||
REQUIRE( naive_ab == extendedMultPortable( a, b ) );
|
||||
REQUIRE( naive_ab == extendedMultPortable( b, a ) );
|
||||
REQUIRE( naive_ab == extendedMult( a, b ) );
|
||||
REQUIRE( naive_ab == extendedMult( b, a ) );
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
|
||||
using Catch::Detail::SizedUnsignedType_t;
|
||||
using Catch::Detail::DoubleWidthUnsignedType_t;
|
||||
|
Loading…
Reference in New Issue
Block a user