Don't store right end of the interval in uniform_integer_distribution

This commit is contained in:
Martin Hořeňovský 2023-12-21 18:39:18 +01:00
parent 3acb8b30f1
commit 680064d391
No known key found for this signature in database
GPG Key ID: DE48307B8B0D381A

View File

@ -46,12 +46,12 @@ class uniform_integer_distribution {
using UnsignedIntegerType = Detail::make_unsigned_t<IntegerType>;
// We store the left range bound converted to internal representation,
// because it will be used in computation in the () operator.
// Only the left bound is stored, and we store it converted to its
// unsigned image. This avoids having to do the conversions inside
// the operator(), at the cost of having to do the conversion in
// the a() getter. The right bound is only needed in the b() getter,
// so we recompute it there from other stored data.
UnsignedIntegerType m_a;
// After initialization, right bound is only used for the b() getter,
// so we keep it in the original type.
IntegerType m_b;
// How many different values are there in [a, b]. a == b => 1, can be 0 for distribution over all values in the type.
UnsignedIntegerType m_ab_distance;
@ -64,11 +64,10 @@ class uniform_integer_distribution {
// distribution will be reused many times and this is an optimization.
UnsignedIntegerType m_rejection_threshold = 0;
// Assumes m_b and m_a are already filled
UnsignedIntegerType computeDistance() const {
// This overflows and returns 0 if ua == 0 and ub == TYPE_MAX.
UnsignedIntegerType computeDistance(IntegerType a, IntegerType b) const {
// This overflows and returns 0 if a == 0 and b == TYPE_MAX.
// We handle that later when generating the number.
return transposeTo(m_b) - m_a + 1;
return transposeTo(b) - transposeTo(a) + 1;
}
static UnsignedIntegerType computeRejectionThreshold(UnsignedIntegerType ab_distance) {
@ -92,8 +91,7 @@ public:
uniform_integer_distribution( IntegerType a, IntegerType b ):
m_a( transposeTo(a) ),
m_b( b ),
m_ab_distance( computeDistance() ),
m_ab_distance( computeDistance(a, b) ),
m_rejection_threshold( computeRejectionThreshold(m_ab_distance) ) {
assert( a <= b );
}
@ -118,7 +116,7 @@ public:
}
result_type a() const { return transposeBack(m_a); }
result_type b() const { return m_b; }
result_type b() const { return transposeBack(m_ab_distance + m_a - 1); }
};
} // end namespace Catch