Don't store right end of the interval in uniform_integer_distribution

This commit is contained in:
Martin Hořeňovský 2023-12-21 18:39:18 +01:00
parent 3acb8b30f1
commit 680064d391
No known key found for this signature in database
GPG Key ID: DE48307B8B0D381A

View File

@ -46,12 +46,12 @@ class uniform_integer_distribution {
using UnsignedIntegerType = Detail::make_unsigned_t<IntegerType>; using UnsignedIntegerType = Detail::make_unsigned_t<IntegerType>;
// We store the left range bound converted to internal representation, // Only the left bound is stored, and we store it converted to its
// because it will be used in computation in the () operator. // unsigned image. This avoids having to do the conversions inside
// the operator(), at the cost of having to do the conversion in
// the a() getter. The right bound is only needed in the b() getter,
// so we recompute it there from other stored data.
UnsignedIntegerType m_a; UnsignedIntegerType m_a;
// After initialization, right bound is only used for the b() getter,
// so we keep it in the original type.
IntegerType m_b;
// How many different values are there in [a, b]. a == b => 1, can be 0 for distribution over all values in the type. // How many different values are there in [a, b]. a == b => 1, can be 0 for distribution over all values in the type.
UnsignedIntegerType m_ab_distance; UnsignedIntegerType m_ab_distance;
@ -64,11 +64,10 @@ class uniform_integer_distribution {
// distribution will be reused many times and this is an optimization. // distribution will be reused many times and this is an optimization.
UnsignedIntegerType m_rejection_threshold = 0; UnsignedIntegerType m_rejection_threshold = 0;
// Assumes m_b and m_a are already filled UnsignedIntegerType computeDistance(IntegerType a, IntegerType b) const {
UnsignedIntegerType computeDistance() const { // This overflows and returns 0 if a == 0 and b == TYPE_MAX.
// This overflows and returns 0 if ua == 0 and ub == TYPE_MAX.
// We handle that later when generating the number. // We handle that later when generating the number.
return transposeTo(m_b) - m_a + 1; return transposeTo(b) - transposeTo(a) + 1;
} }
static UnsignedIntegerType computeRejectionThreshold(UnsignedIntegerType ab_distance) { static UnsignedIntegerType computeRejectionThreshold(UnsignedIntegerType ab_distance) {
@ -92,8 +91,7 @@ public:
uniform_integer_distribution( IntegerType a, IntegerType b ): uniform_integer_distribution( IntegerType a, IntegerType b ):
m_a( transposeTo(a) ), m_a( transposeTo(a) ),
m_b( b ), m_ab_distance( computeDistance(a, b) ),
m_ab_distance( computeDistance() ),
m_rejection_threshold( computeRejectionThreshold(m_ab_distance) ) { m_rejection_threshold( computeRejectionThreshold(m_ab_distance) ) {
assert( a <= b ); assert( a <= b );
} }
@ -118,7 +116,7 @@ public:
} }
result_type a() const { return transposeBack(m_a); } result_type a() const { return transposeBack(m_a); }
result_type b() const { return m_b; } result_type b() const { return transposeBack(m_ab_distance + m_a - 1); }
}; };
} // end namespace Catch } // end namespace Catch