Add configuration option to make assertions thread-safe

All the previous refactoring to make the assertion fast paths
smaller and faster also allows us to implement the fast paths
just with thread-local and atomic variables, without full mutexes.

However, the performance overhead of thread-safe assertions is
still significant for single-threaded usage:

|  slowdown |  Debug  | Release |
|-----------|--------:|--------:|
| fast path |   1.04x |   1.43x |
| slow path |   1.16x |   1.22x |

Thus, we don't make the assertions thread-safe by default, and instead
provide a build-time configuration option that the users can set to get
thread-safe assertions.

This commit is functional, but it still needs some follow-up work:
 * We do not need full seq_cst increments for the atomic counters,
   and using weaker ones can be faster.
 * We brute-force updating the reporter-friendly totals from internal
   atomic counters by doing it everywhere. We should properly trace
   where this is needed instead.
 * Message macros (`INFO`, `UNSCOPED_INFO`, `CAPTURE`, etc.) are not
   made thread-safe in this commit, but they can be made thread-safe
   in the future, by building on top of this work.
 * Add more tests, including with thread-sanitizer, and compiled
   examples to the repository. Right now, these changes have been
   compiled with tsan manually, but these tests are not added to CI.

Closes #2948
This commit is contained in:
Martin Hořeňovský
2025-07-17 22:58:41 +02:00
parent 900a6d5516
commit 2a8a8a7210
12 changed files with 332 additions and 49 deletions

View File

@@ -139,6 +139,7 @@ set(IMPL_HEADERS
${SOURCES_DIR}/internal/catch_test_registry.hpp
${SOURCES_DIR}/internal/catch_test_spec_parser.hpp
${SOURCES_DIR}/internal/catch_textflow.hpp
${SOURCES_DIR}/internal/catch_thread_support.hpp
${SOURCES_DIR}/internal/catch_to_string.hpp
${SOURCES_DIR}/internal/catch_uncaught_exceptions.hpp
${SOURCES_DIR}/internal/catch_uniform_floating_point_distribution.hpp

View File

@@ -121,6 +121,7 @@
#include <catch2/internal/catch_test_registry.hpp>
#include <catch2/internal/catch_test_spec_parser.hpp>
#include <catch2/internal/catch_textflow.hpp>
#include <catch2/internal/catch_thread_support.hpp>
#include <catch2/internal/catch_to_string.hpp>
#include <catch2/internal/catch_uncaught_exceptions.hpp>
#include <catch2/internal/catch_uniform_floating_point_distribution.hpp>

View File

@@ -196,6 +196,14 @@
#endif
#cmakedefine CATCH_CONFIG_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS
#cmakedefine CATCH_CONFIG_NO_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS
#if defined( CATCH_CONFIG_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS ) && \
defined( CATCH_CONFIG_NO_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS )
# error Cannot force EXPERIMENTAL_THREAD_SAFE_ASSERTIONS to both ON and OFF
#endif
// ------
// Simple toggle defines

View File

@@ -165,11 +165,38 @@ namespace Catch {
} // namespace
}
namespace Detail {
// Per-thread assertion bookkeeping.
//
// Assertions are owned by the thread that is executing them.
// This allows for lock-free progress in the common cases where we
// do not need to send the assertion events to the reporter.
// This also implies that messages are owned by their respective
// threads, and should not be shared across different threads.
//
// For simplicity, we disallow messages in multi-threaded contexts,
// but in the future we can enable them under this logic.
//
// It follows that the various pieces of metadata referring to the last
// assertion result/source location/message handling, etc.
// should also be thread-local. For now we just use naked globals
// below; in the future we will want to allocate a piece of memory
// from the heap instead, to avoid consuming too much thread-local storage.

// Outcome of the last assertion evaluated on this thread.
// This is used for the "if" part of CHECKED_IF/CHECKED_ELSE.
static thread_local bool g_lastAssertionPassed = false;
// Should we clear message scopes before sending off the messages to
// the reporter? Set in `assertionPassedFastPath` to avoid doing the full
// clear there for performance reasons.
static thread_local bool g_clearMessageScopes = false;
// This is the source location of the last encountered macro. It is
// used to provide the users with a more precise location of the error
// when an unexpected exception/fatal error happens.
// Starts out as a placeholder location until a real one is recorded.
static thread_local SourceLineInfo g_lastKnownLineInfo("DummyLocation", static_cast<size_t>(-1));
}
RunContext::RunContext(IConfig const* _config, IEventListenerPtr&& reporter)
: m_runInfo(_config->name()),
m_config(_config),
m_reporter(CATCH_MOVE(reporter)),
m_lastKnownLineInfo("DummyLocation", static_cast<size_t>(-1)),
m_outputRedirect( makeOutputRedirect( m_reporter->getPreferences().shouldRedirectStdOut ) ),
m_abortAfterXFailedAssertions( m_config->abortAfter() ),
m_reportAssertionStarting( m_reporter->getPreferences().shouldReportAllAssertionStarts ),
@@ -181,10 +208,12 @@ namespace Catch {
}
RunContext::~RunContext() {
    // Refresh m_totals from the atomic assertion counters first, so the
    // statistics handed to the reporter reflect the latest counts.
    updateTotalsFromAtomics();
    TestRunStats const finalRunStats( m_runInfo, m_totals, aborting() );
    m_reporter->testRunEnded( finalRunStats );
}
Totals RunContext::runTest(TestCaseHandle const& testCase) {
updateTotalsFromAtomics();
const Totals prevTotals = m_totals;
auto const& testInfo = testCase.getTestCaseInfo();
@@ -239,6 +268,7 @@ namespace Catch {
m_reporter->testCasePartialStarting(testInfo, testRuns);
updateTotalsFromAtomics();
const auto beforeRunTotals = m_totals;
runCurrentTest();
std::string oneRunCout = m_outputRedirect->getStdout();
@@ -247,6 +277,7 @@ namespace Catch {
redirectedCout += oneRunCout;
redirectedCerr += oneRunCerr;
updateTotalsFromAtomics();
const auto singleRunTotals = m_totals.delta(beforeRunTotals);
auto statsForOneRun = TestCaseStats(testInfo, singleRunTotals, CATCH_MOVE(oneRunCout), CATCH_MOVE(oneRunCerr), aborting());
m_reporter->testCasePartialEnded(statsForOneRun, testRuns);
@@ -276,31 +307,35 @@ namespace Catch {
void RunContext::assertionEnded(AssertionResult&& result) {
Detail::g_lastKnownLineInfo = result.m_info.lineInfo;
if (result.getResultType() == ResultWas::Ok) {
m_totals.assertions.passed++;
m_lastAssertionPassed = true;
m_atomicAssertionCount.passed++;
Detail::g_lastAssertionPassed = true;
} else if (result.getResultType() == ResultWas::ExplicitSkip) {
m_totals.assertions.skipped++;
m_lastAssertionPassed = true;
m_atomicAssertionCount.skipped++;
Detail::g_lastAssertionPassed = true;
} else if (!result.succeeded()) {
m_lastAssertionPassed = false;
Detail::g_lastAssertionPassed = false;
if (result.isOk()) {
}
else if( m_activeTestCase->getTestCaseInfo().okToFail() )
m_totals.assertions.failedButOk++;
else if( m_activeTestCase->getTestCaseInfo().okToFail() ) // Read from a shared state established before the threads could start, this is fine
m_atomicAssertionCount.failedButOk++;
else
m_totals.assertions.failed++;
m_atomicAssertionCount.failed++;
}
else {
m_lastAssertionPassed = true;
Detail::g_lastAssertionPassed = true;
}
// From here, we are touching shared state and need mutex.
Detail::LockGuard lock( m_assertionMutex );
{
if ( m_clearMessageScopes ) {
if ( Detail::g_clearMessageScopes ) {
m_messageScopes.clear();
m_clearMessageScopes = false;
Detail::g_clearMessageScopes = false;
}
auto _ = scopedDeactivate( *m_outputRedirect );
updateTotalsFromAtomics();
m_reporter->assertionEnded( AssertionStats( result, m_messages, m_totals ) );
}
@@ -315,6 +350,7 @@ namespace Catch {
void RunContext::notifyAssertionStarted( AssertionInfo const& info ) {
if (m_reportAssertionStarting) {
Detail::LockGuard lock( m_assertionMutex );
auto _ = scopedDeactivate( *m_outputRedirect );
m_reporter->assertionStarting( info );
}
@@ -333,13 +369,14 @@ namespace Catch {
m_activeSections.push_back(&sectionTracker);
SectionInfo sectionInfo( sectionLineInfo, static_cast<std::string>(sectionName) );
m_lastKnownLineInfo = sectionLineInfo;
Detail::g_lastKnownLineInfo = sectionLineInfo;
{
auto _ = scopedDeactivate( *m_outputRedirect );
m_reporter->sectionStarting( sectionInfo );
}
updateTotalsFromAtomics();
assertions = m_totals.assertions;
return true;
@@ -347,12 +384,11 @@ namespace Catch {
IGeneratorTracker*
RunContext::acquireGeneratorTracker( StringRef generatorName,
SourceLineInfo const& lineInfo ) {
using namespace Generators;
GeneratorTracker* tracker = GeneratorTracker::acquire(
auto* tracker = Generators::GeneratorTracker::acquire(
m_trackerContext,
TestCaseTracking::NameAndLocationRef(
generatorName, lineInfo ) );
m_lastKnownLineInfo = lineInfo;
Detail::g_lastKnownLineInfo = lineInfo;
return tracker;
}
@@ -384,12 +420,13 @@ namespace Catch {
return false;
if (m_trackerContext.currentTracker().hasChildren())
return false;
m_totals.assertions.failed++;
m_atomicAssertionCount.failed++;
assertions.failed++;
return true;
}
void RunContext::sectionEnded(SectionEndInfo&& endInfo) {
updateTotalsFromAtomics();
Counts assertions = m_totals.assertions - endInfo.prevAssertions;
bool missingAssertions = testForMissingAssertions(assertions);
@@ -465,6 +502,18 @@ namespace Catch {
}
const AssertionResult * RunContext::getLastResult() const {
    // The assertion slow path updates m_lastResult while holding
    // m_assertionMutex, so we must hold the same mutex for the read.
    //
    // TBD: Conceptually the last result should be thread-local, because
    //      each thread has its own answer, just like the last line info
    //      or whether the last assertion passed.
    //
    //      That said, the assertion fast path never updated the last
    //      result either, so it has always been somewhat broken. As
    //      IResultCapture::getLastResult is deprecated, we keep the
    //      current behaviour until the function is finally removed.
    Detail::LockGuard lock( m_assertionMutex );
    AssertionResult const* lastResult = &( *m_lastResult );
    return lastResult;
}
@@ -473,13 +522,22 @@ namespace Catch {
}
void RunContext::handleFatalErrorCondition( StringRef message ) {
// TODO: scoped deactivate here? Just give up and do best effort?
// the deactivation can break things further, OTOH so can the
// capture
auto _ = scopedDeactivate( *m_outputRedirect );
// We lock only when touching the reporters directly, to avoid
// deadlocks when we call into other functions that also want
// to lock the mutex before touching reporters.
//
// This does mean that we allow other threads to run while handling
// a fatal error, but this is all a best effort attempt anyway.
{
Detail::LockGuard lock( m_assertionMutex );
// TODO: scoped deactivate here? Just give up and do best effort?
// the deactivation can break things further, OTOH so can the
// capture
auto _ = scopedDeactivate( *m_outputRedirect );
// First notify reporter that bad things happened
m_reporter->fatalErrorEncountered( message );
// First notify reporter that bad things happened
m_reporter->fatalErrorEncountered( message );
}
// Don't rebuild the result -- the stringification itself can cause more fatal errors
// Instead, fake a result data.
@@ -490,6 +548,13 @@ namespace Catch {
assertionEnded(CATCH_MOVE(result) );
// At this point we touch sections/test cases from this thread
// to try and end them. Technically that is not supported when
// using multiple threads, but the worst thing that can happen
// is that the process aborts harder :-D
Detail::LockGuard lock( m_assertionMutex );
// Best effort cleanup for sections that have not been destructed yet
// Since this is a fatal error, we have not had and won't have the opportunity to destruct them properly
while (!m_activeSections.empty()) {
@@ -519,32 +584,44 @@ namespace Catch {
std::string(),
false));
m_totals.testCases.failed++;
updateTotalsFromAtomics();
m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, false));
}
bool RunContext::lastAssertionPassed() {
return m_lastAssertionPassed;
return Detail::g_lastAssertionPassed;
}
void RunContext::assertionPassedFastPath(SourceLineInfo lineInfo) {
m_lastKnownLineInfo = lineInfo;
++m_totals.assertions.passed;
m_lastAssertionPassed = true;
m_clearMessageScopes = true;
// We want to save the line info for better experience with unexpected assertions
Detail::g_lastKnownLineInfo = lineInfo;
++m_atomicAssertionCount.passed;
Detail::g_lastAssertionPassed = true;
Detail::g_clearMessageScopes = true;
}
void RunContext::updateTotalsFromAtomics() {
m_totals.assertions = Counts{
m_atomicAssertionCount.passed,
m_atomicAssertionCount.failed,
m_atomicAssertionCount.failedButOk,
m_atomicAssertionCount.skipped,
};
}
bool RunContext::aborting() const {
return m_totals.assertions.failed >= m_abortAfterXFailedAssertions;
return m_atomicAssertionCount.failed >= m_abortAfterXFailedAssertions;
}
void RunContext::runCurrentTest() {
auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo();
SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name);
m_reporter->sectionStarting(testCaseSection);
updateTotalsFromAtomics();
Counts prevAssertions = m_totals.assertions;
double duration = 0;
m_shouldReportUnexpected = true;
m_lastKnownLineInfo = testCaseInfo.lineInfo;
Detail::g_lastKnownLineInfo = testCaseInfo.lineInfo;
Timer timer;
CATCH_TRY {
@@ -568,6 +645,7 @@ namespace Catch {
dummyReaction );
}
}
updateTotalsFromAtomics();
Counts assertions = m_totals.assertions - prevAssertions;
bool missingAssertions = testForMissingAssertions(assertions);
@@ -617,7 +695,6 @@ namespace Catch {
if( result ) {
if (!m_includeSuccessfulResults) {
// Fast path if neither user nor reporter asked for passing assertions
assertionPassedFastPath(info.lineInfo);
}
else {
@@ -636,7 +713,7 @@ namespace Catch {
ITransientExpression const *expr,
bool negated ) {
m_lastKnownLineInfo = info.lineInfo;
Detail::g_lastKnownLineInfo = info.lineInfo;
AssertionResultData data( resultType, LazyExpression( negated ) );
AssertionResult assertionResult{ info, CATCH_MOVE( data ) };
@@ -651,7 +728,7 @@ namespace Catch {
std::string&& message,
AssertionReaction& reaction
) {
m_lastKnownLineInfo = info.lineInfo;
Detail::g_lastKnownLineInfo = info.lineInfo;
AssertionResultData data( resultType, LazyExpression( false ) );
data.message = CATCH_MOVE( message );
@@ -682,7 +759,7 @@ namespace Catch {
std::string&& message,
AssertionReaction& reaction
) {
m_lastKnownLineInfo = info.lineInfo;
Detail::g_lastKnownLineInfo = info.lineInfo;
AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) );
data.message = CATCH_MOVE(message);
@@ -700,11 +777,12 @@ namespace Catch {
AssertionInfo RunContext::makeDummyAssertionInfo() {
const bool testCaseJustStarted =
m_lastKnownLineInfo == m_activeTestCase->getTestCaseInfo().lineInfo;
Detail::g_lastKnownLineInfo ==
m_activeTestCase->getTestCaseInfo().lineInfo;
return AssertionInfo{
testCaseJustStarted ? "TEST_CASE"_sr : StringRef(),
m_lastKnownLineInfo,
Detail::g_lastKnownLineInfo,
testCaseJustStarted ? StringRef() : "{Unknown expression after the reported line}"_sr,
ResultDisposition::Normal
};
@@ -714,7 +792,7 @@ namespace Catch {
AssertionInfo const& info
) {
using namespace std::string_literals;
m_lastKnownLineInfo = info.lineInfo;
Detail::g_lastKnownLineInfo = info.lineInfo;
AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) );
data.message = "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE"s;
@@ -727,8 +805,6 @@ namespace Catch {
ResultWas::OfType resultType,
AssertionReaction &reaction
) {
m_lastKnownLineInfo = info.lineInfo;
AssertionResultData data( resultType, LazyExpression( false ) );
AssertionResult assertionResult{ info, CATCH_MOVE( data ) };

View File

@@ -20,6 +20,7 @@
#include <catch2/catch_assertion_result.hpp>
#include <catch2/internal/catch_optional.hpp>
#include <catch2/internal/catch_move_and_forward.hpp>
#include <catch2/internal/catch_thread_support.hpp>
#include <string>
@@ -108,13 +109,14 @@ namespace Catch {
bool lastAssertionPassed() override;
void assertionPassedFastPath(SourceLineInfo lineInfo);
public:
// !TBD We need to do this another way!
bool aborting() const;
private:
void assertionPassedFastPath( SourceLineInfo lineInfo );
// Update the non-thread-safe m_totals from the atomic assertion counts.
void updateTotalsFromAtomics();
void runCurrentTest();
void invokeActiveTestCase();
@@ -138,19 +140,18 @@ namespace Catch {
private:
void handleUnfinishedSections();
mutable Detail::Mutex m_assertionMutex;
TestRunInfo m_runInfo;
TestCaseHandle const* m_activeTestCase = nullptr;
ITracker* m_testCaseTracker = nullptr;
Optional<AssertionResult> m_lastResult;
IConfig const* m_config;
Totals m_totals;
Detail::AtomicCounts m_atomicAssertionCount;
IEventListenerPtr m_reporter;
std::vector<MessageInfo> m_messages;
// Owners for the UNSCOPED_X information macro
std::vector<ScopedMessage> m_messageScopes;
SourceLineInfo m_lastKnownLineInfo;
std::vector<SectionEndInfo> m_unfinishedSections;
std::vector<ITracker*> m_activeSections;
TrackerContext m_trackerContext;
@@ -158,10 +159,6 @@ namespace Catch {
FatalConditionHandler m_fatalConditionhandler;
// Caches m_config->abortAfter() to avoid vptr calls/allow inlining
size_t m_abortAfterXFailedAssertions;
bool m_lastAssertionPassed = false;
// Should we clear message scopes before sending off the messages to reporter?
// Set in `assertionPassedFastPath` to avoid doing the full clear there.
bool m_clearMessageScopes = false;
bool m_shouldReportUnexpected = true;
// Caches whether `assertionStarting` events should be sent to the reporter.
bool m_reportAssertionStarting;

View File

@@ -0,0 +1,49 @@
// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)
// SPDX-License-Identifier: BSL-1.0
#ifndef CATCH_THREAD_SUPPORT_HPP_INCLUDED
#define CATCH_THREAD_SUPPORT_HPP_INCLUDED
#include <catch2/catch_user_config.hpp>
#if defined( CATCH_CONFIG_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS )
# include <atomic>
# include <mutex>
#endif
#include <catch2/catch_totals.hpp>
namespace Catch {
namespace Detail {
#if defined( CATCH_CONFIG_EXPERIMENTAL_THREAD_SAFE_ASSERTIONS )
using Mutex = std::mutex;
using LockGuard = std::lock_guard<std::mutex>;
// Atomic counterpart of the plain assertion counters, so that the
// individual counts can be incremented from multiple threads
// without taking a mutex.
//
// The members use direct-list-initialization on purpose:
// `std::atomic<T> x = 0;` is copy-initialization, which needs an
// accessible copy/move constructor before C++17, and std::atomic
// has neither. `x{ 0 }` is valid in every supported standard.
struct AtomicCounts {
    std::atomic<std::uint64_t> passed{ 0 };
    std::atomic<std::uint64_t> failed{ 0 };
    std::atomic<std::uint64_t> failedButOk{ 0 };
    std::atomic<std::uint64_t> skipped{ 0 };
};
#else // ^^ Use actual mutex, lock and atomics
// vv Dummy implementations for single-thread performance
// No-op stand-in for std::mutex, used when thread-safe assertions
// are disabled. Locking and unlocking do nothing.
struct Mutex {
    void lock() {}
    void unlock() {}
};

// No-op stand-in for std::lock_guard<std::mutex>.
//
// The constructor deliberately mirrors std::lock_guard: it is
// explicit and takes the mutex by non-const reference. This way
// code that compiles against the dummy types also compiles when
// the real mutex and lock guard are enabled, instead of silently
// relying on copying the dummy mutex.
struct LockGuard {
    explicit LockGuard( Mutex& ) noexcept {}
};
using AtomicCounts = Counts;
#endif
} // namespace Detail
} // namespace Catch
#endif // CATCH_THREAD_SUPPORT_HPP_INCLUDED

View File

@@ -147,6 +147,7 @@ internal_headers = [
'internal/catch_test_registry.hpp',
'internal/catch_test_spec_parser.hpp',
'internal/catch_textflow.hpp',
'internal/catch_thread_support.hpp',
'internal/catch_to_string.hpp',
'internal/catch_uncaught_exceptions.hpp',
'internal/catch_uniform_floating_point_distribution.hpp',