mirror of
https://github.com/catchorg/Catch2.git
synced 2025-08-01 21:05:39 +02:00
Optimize 64x64 extended multiplication implementation
We now use intrinsics when possible, and fall back to an optimized implementation in portable C++ otherwise. The speedup is about 4x when we can use intrinsics and about 2x when we cannot. This should nicely speed up our implementation of Lemire's algorithm.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/internal/catch_random_integer_helpers.hpp>
|
||||
#include <random>
|
||||
|
||||
namespace {
|
||||
template <typename Int>
|
||||
@@ -20,6 +21,58 @@ namespace {
|
||||
CHECK( extendedMult( b, a ) ==
|
||||
ExtendedMultResult<Int>{ upper_result, lower_result } );
|
||||
}
|
||||
|
||||
// Simple (and slow) implmentation of extended multiplication for tests
|
||||
constexpr Catch::Detail::ExtendedMultResult<std::uint64_t>
|
||||
extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) {
|
||||
// This is a simple long multiplication, where we split lhs and rhs
|
||||
// into two 32-bit "digits", so that we can do ops with carry in 64-bits.
|
||||
//
|
||||
// 32b 32b 32b 32b
|
||||
// lhs L1 L2
|
||||
// * rhs R1 R2
|
||||
// ------------------------
|
||||
// | R2 * L2 |
|
||||
// | R2 * L1 |
|
||||
// | R1 * L2 |
|
||||
// | R1 * L1 |
|
||||
// -------------------------
|
||||
// | a | b | c | d |
|
||||
|
||||
#define CarryBits( x ) ( x >> 32 )
|
||||
#define Digits( x ) ( x & 0xFF'FF'FF'FF )
|
||||
|
||||
auto r2l2 = Digits( rhs ) * Digits( lhs );
|
||||
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
|
||||
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
|
||||
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );
|
||||
|
||||
// Sum to columns first
|
||||
auto d = Digits( r2l2 );
|
||||
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
|
||||
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
|
||||
auto a = CarryBits( r1l1 );
|
||||
|
||||
// Propagate carries between columns
|
||||
c += CarryBits( d );
|
||||
b += CarryBits( c );
|
||||
a += CarryBits( b );
|
||||
|
||||
// Remove the used carries
|
||||
c = Digits( c );
|
||||
b = Digits( b );
|
||||
a = Digits( a );
|
||||
|
||||
#undef CarryBits
|
||||
#undef Digits
|
||||
|
||||
return {
|
||||
a << 32 | b, // upper 64 bits
|
||||
c << 32 | d // lower 64 bits
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
||||
@@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
|
||||
0xdf44'2d22'ce48'59b9 );
|
||||
}
|
||||
|
||||
// Cross-checks every 64x64 extended multiplication implementation against
// the naive reference implementation, using randomly generated operands.
TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") {
    using Catch::Detail::fillBitsFrom;
    using Catch::Detail::extendedMultPortable;
    using Catch::Detail::extendedMult;

    std::random_device rng;

    // Run 100 rounds, each with a fresh pair of random operands
    for ( int rounds_left = 100; rounds_left > 0; --rounds_left ) {
        auto a = fillBitsFrom<std::uint64_t>( rng );
        auto b = fillBitsFrom<std::uint64_t>( rng );
        // Record the operands so that a failing assertion below reports
        // which inputs triggered it
        CAPTURE( a, b );

        // Reference result that everything else is compared against
        auto naive_ab = extendedMultNaive( a, b );

        // Each implementation is checked with both operand orders; the
        // result must be the same either way
        REQUIRE( naive_ab == extendedMultNaive( b, a ) );
        REQUIRE( naive_ab == extendedMultPortable( a, b ) );
        REQUIRE( naive_ab == extendedMultPortable( b, a ) );
        REQUIRE( naive_ab == extendedMult( a, b ) );
        REQUIRE( naive_ab == extendedMult( b, a ) );
    }
}
|
||||
|
||||
TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
|
||||
using Catch::Detail::SizedUnsignedType_t;
|
||||
using Catch::Detail::DoubleWidthUnsignedType_t;
|
||||
|
Reference in New Issue
Block a user