mirror of
https://github.com/catchorg/Catch2.git
synced 2024-11-22 05:16:10 +01:00
Optimize 64x64 extended multiplication implementation
Now we use intrinsics when possible, and fallback to optimized implementation in portable C++. The difference is about 4x when we can use intrinsics and about 2x when we cannot. This should speed up our Lemire's algorithm implementation nicely.
This commit is contained in:
parent
f181de9df4
commit
d99eb8bec8
@ -14,6 +14,32 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
// Note: We use the usual enable-disable-autodetect dance here even though
|
||||||
|
// we do not support these in CMake configuration options (yet?).
|
||||||
|
// It is highly unlikely that we will need to make these actually
|
||||||
|
// user-configurable, but this will make it simpler if weend up needing
|
||||||
|
// it, and it provides an escape hatch to the users who need it.
|
||||||
|
#if defined( __SIZEOF_INT128__ )
|
||||||
|
# define CATCH_CONFIG_INTERNAL_UINT128
|
||||||
|
#elif defined( _MSC_VER ) && ( defined( _WIN64 ) || defined( _M_ARM64 ) )
|
||||||
|
# define CATCH_CONFIG_INTERNAL_MSVC_UMUL128
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined( CATCH_CONFIG_INTERNAL_UINT128 ) && \
|
||||||
|
!defined( CATCH_CONFIG_NO_UINT128 ) && \
|
||||||
|
!defined( CATCH_CONFIG_UINT128 )
|
||||||
|
#define CATCH_CONFIG_UINT128
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined( CATCH_CONFIG_INTERNAL_MSVC_UMUL128 ) && \
|
||||||
|
!defined( CATCH_CONFIG_NO_MSVC_UMUL128 ) && \
|
||||||
|
!defined( CATCH_CONFIG_MSVC_UMUL128 )
|
||||||
|
# define CATCH_CONFIG_MSVC_UMUL128
|
||||||
|
# include <intrin.h>
|
||||||
|
# pragma intrinsic( _umul128 )
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
namespace Catch {
|
namespace Catch {
|
||||||
namespace Detail {
|
namespace Detail {
|
||||||
|
|
||||||
@ -46,59 +72,52 @@ namespace Catch {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns 128 bit result of multiplying lhs and rhs
|
/**
|
||||||
|
* Returns 128 bit result of lhs * rhs using portable C++ code
|
||||||
|
*
|
||||||
|
* This implementation is almost twice as fast as naive long multiplication,
|
||||||
|
* and unlike intrinsic-based approach, it supports constexpr evaluation.
|
||||||
|
*/
|
||||||
constexpr ExtendedMultResult<std::uint64_t>
|
constexpr ExtendedMultResult<std::uint64_t>
|
||||||
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
|
extendedMultPortable(std::uint64_t lhs, std::uint64_t rhs) {
|
||||||
// We use the simple long multiplication approach for
|
|
||||||
// correctness, we can use platform specific builtins
|
|
||||||
// for performance later.
|
|
||||||
|
|
||||||
// Split the lhs and rhs into two 32bit "digits", so that we can
|
|
||||||
// do 64 bit arithmetic to handle carry bits.
|
|
||||||
// 32b 32b 32b 32b
|
|
||||||
// lhs L1 L2
|
|
||||||
// * rhs R1 R2
|
|
||||||
// ------------------------
|
|
||||||
// | R2 * L2 |
|
|
||||||
// | R2 * L1 |
|
|
||||||
// | R1 * L2 |
|
|
||||||
// | R1 * L1 |
|
|
||||||
// -------------------------
|
|
||||||
// | a | b | c | d |
|
|
||||||
|
|
||||||
#define CarryBits( x ) ( x >> 32 )
|
#define CarryBits( x ) ( x >> 32 )
|
||||||
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
||||||
|
std::uint64_t lhs_low = Digits( lhs );
|
||||||
|
std::uint64_t rhs_low = Digits( rhs );
|
||||||
|
std::uint64_t low_low = ( lhs_low * rhs_low );
|
||||||
|
std::uint64_t high_high = CarryBits( lhs ) * CarryBits( rhs );
|
||||||
|
|
||||||
auto r2l2 = Digits( rhs ) * Digits( lhs );
|
// We add in carry bits from low-low already
|
||||||
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
|
std::uint64_t high_low =
|
||||||
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
|
( CarryBits( lhs ) * rhs_low ) + CarryBits( low_low );
|
||||||
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
|
// Note that we can add only low bits from high_low, to avoid
|
||||||
|
// overflow with large inputs
|
||||||
// Sum to columns first
|
std::uint64_t low_high =
|
||||||
auto d = Digits( r2l2 );
|
( lhs_low * CarryBits( rhs ) ) + Digits( high_low );
|
||||||
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
|
|
||||||
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
|
|
||||||
auto a = CarryBits( r1l1 );
|
|
||||||
|
|
||||||
// Propagate carries between columns
|
|
||||||
c += CarryBits( d );
|
|
||||||
b += CarryBits( c );
|
|
||||||
a += CarryBits( b );
|
|
||||||
|
|
||||||
// Remove the used carries
|
|
||||||
c = Digits( c );
|
|
||||||
b = Digits( b );
|
|
||||||
a = Digits( a );
|
|
||||||
|
|
||||||
|
return { high_high + CarryBits( high_low ) + CarryBits( low_high ),
|
||||||
|
( low_high << 32 ) | Digits( low_low ) };
|
||||||
#undef CarryBits
|
#undef CarryBits
|
||||||
#undef Digits
|
#undef Digits
|
||||||
|
|
||||||
return {
|
|
||||||
a << 32 | b, // upper 64 bits
|
|
||||||
c << 32 | d // lower 64 bits
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//! Returns 128 bit result of lhs * rhs
|
||||||
|
inline ExtendedMultResult<std::uint64_t>
|
||||||
|
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||||
|
#if defined( CATCH_CONFIG_UINT128 )
|
||||||
|
auto result = __uint128_t( lhs ) * __uint128_t( rhs );
|
||||||
|
return { static_cast<std::uint64_t>( result >> 64 ),
|
||||||
|
static_cast<std::uint64_t>( result ) };
|
||||||
|
#elif defined( CATCH_CONFIG_MSVC_UMUL128 )
|
||||||
|
std::uint64_t high;
|
||||||
|
std::uint64_t low = _umul128( lhs, rhs, &high );
|
||||||
|
return { high, low };
|
||||||
|
#else
|
||||||
|
return extendedMultPortable( lhs, rhs );
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename UInt>
|
template <typename UInt>
|
||||||
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
|
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
|
||||||
static_assert( std::is_unsigned<UInt>::value,
|
static_assert( std::is_unsigned<UInt>::value,
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include <catch2/catch_test_macros.hpp>
|
#include <catch2/catch_test_macros.hpp>
|
||||||
#include <catch2/internal/catch_random_integer_helpers.hpp>
|
#include <catch2/internal/catch_random_integer_helpers.hpp>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename Int>
|
template <typename Int>
|
||||||
@ -20,6 +21,58 @@ namespace {
|
|||||||
CHECK( extendedMult( b, a ) ==
|
CHECK( extendedMult( b, a ) ==
|
||||||
ExtendedMultResult<Int>{ upper_result, lower_result } );
|
ExtendedMultResult<Int>{ upper_result, lower_result } );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simple (and slow) implmentation of extended multiplication for tests
|
||||||
|
constexpr Catch::Detail::ExtendedMultResult<std::uint64_t>
|
||||||
|
extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||||
|
// This is a simple long multiplication, where we split lhs and rhs
|
||||||
|
// into two 32-bit "digits", so that we can do ops with carry in 64-bits.
|
||||||
|
//
|
||||||
|
// 32b 32b 32b 32b
|
||||||
|
// lhs L1 L2
|
||||||
|
// * rhs R1 R2
|
||||||
|
// ------------------------
|
||||||
|
// | R2 * L2 |
|
||||||
|
// | R2 * L1 |
|
||||||
|
// | R1 * L2 |
|
||||||
|
// | R1 * L1 |
|
||||||
|
// -------------------------
|
||||||
|
// | a | b | c | d |
|
||||||
|
|
||||||
|
#define CarryBits( x ) ( x >> 32 )
|
||||||
|
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
||||||
|
|
||||||
|
auto r2l2 = Digits( rhs ) * Digits( lhs );
|
||||||
|
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
|
||||||
|
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
|
||||||
|
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
|
||||||
|
|
||||||
|
// Sum to columns first
|
||||||
|
auto d = Digits( r2l2 );
|
||||||
|
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
|
||||||
|
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
|
||||||
|
auto a = CarryBits( r1l1 );
|
||||||
|
|
||||||
|
// Propagate carries between columns
|
||||||
|
c += CarryBits( d );
|
||||||
|
b += CarryBits( c );
|
||||||
|
a += CarryBits( b );
|
||||||
|
|
||||||
|
// Remove the used carries
|
||||||
|
c = Digits( c );
|
||||||
|
b = Digits( b );
|
||||||
|
a = Digits( a );
|
||||||
|
|
||||||
|
#undef CarryBits
|
||||||
|
#undef Digits
|
||||||
|
|
||||||
|
return {
|
||||||
|
a << 32 | b, // upper 64 bits
|
||||||
|
c << 32 | d // lower 64 bits
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
||||||
@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
|||||||
0xdf44'2d22'ce48'59b9 );
|
0xdf44'2d22'ce48'59b9 );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") {
|
||||||
|
using Catch::Detail::extendedMult;
|
||||||
|
using Catch::Detail::extendedMultPortable;
|
||||||
|
using Catch::Detail::fillBitsFrom;
|
||||||
|
|
||||||
|
std::random_device rng;
|
||||||
|
for (size_t i = 0; i < 100; ++i) {
|
||||||
|
auto a = fillBitsFrom<std::uint64_t>( rng );
|
||||||
|
auto b = fillBitsFrom<std::uint64_t>( rng );
|
||||||
|
CAPTURE( a, b );
|
||||||
|
|
||||||
|
auto naive_ab = extendedMultNaive( a, b );
|
||||||
|
|
||||||
|
REQUIRE( naive_ab == extendedMultNaive( b, a ) );
|
||||||
|
REQUIRE( naive_ab == extendedMultPortable( a, b ) );
|
||||||
|
REQUIRE( naive_ab == extendedMultPortable( b, a ) );
|
||||||
|
REQUIRE( naive_ab == extendedMult( a, b ) );
|
||||||
|
REQUIRE( naive_ab == extendedMult( b, a ) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
|
TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
|
||||||
using Catch::Detail::SizedUnsignedType_t;
|
using Catch::Detail::SizedUnsignedType_t;
|
||||||
using Catch::Detail::DoubleWidthUnsignedType_t;
|
using Catch::Detail::DoubleWidthUnsignedType_t;
|
||||||
|
Loading…
Reference in New Issue
Block a user