mirror of
https://github.com/catchorg/Catch2.git
synced 2025-08-03 05:45:39 +02:00
Allow test sharding for e.g. Bazel test sharding feature
This greatly simplifies running Catch2 tests in single binary in parallel from external test runners. Instead of having to shard the tests by tags/test names, an external test runner can now just ask for test shard 2 (out of X), and execute that in single process, without having to know what tests are actually in the shard. Note that sharding also applies to test listing, and happens after tests were ordered according to the `--order` feature.
This commit is contained in:

committed by
Martin Hořeňovský

parent
6456ee8b01
commit
3087e19cc7
@@ -85,6 +85,7 @@
|
||||
#include <catch2/internal/catch_result_type.hpp>
|
||||
#include <catch2/internal/catch_run_context.hpp>
|
||||
#include <catch2/internal/catch_section.hpp>
|
||||
#include <catch2/internal/catch_sharding.hpp>
|
||||
#include <catch2/internal/catch_singletons.hpp>
|
||||
#include <catch2/internal/catch_source_line_info.hpp>
|
||||
#include <catch2/internal/catch_startup_exception_registry.hpp>
|
||||
|
@@ -73,6 +73,8 @@ namespace Catch {
|
||||
double Config::minDuration() const { return m_data.minDuration; }
|
||||
TestRunOrder Config::runOrder() const { return m_data.runOrder; }
|
||||
uint32_t Config::rngSeed() const { return m_data.rngSeed; }
|
||||
unsigned int Config::shardCount() const { return m_data.shardCount; }
|
||||
unsigned int Config::shardIndex() const { return m_data.shardIndex; }
|
||||
UseColour Config::useColour() const { return m_data.useColour; }
|
||||
bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; }
|
||||
int Config::abortAfter() const { return m_data.abortAfter; }
|
||||
|
@@ -37,6 +37,9 @@ namespace Catch {
|
||||
int abortAfter = -1;
|
||||
uint32_t rngSeed = generateRandomSeed(GenerateFrom::Default);
|
||||
|
||||
unsigned int shardCount = 1;
|
||||
unsigned int shardIndex = 0;
|
||||
|
||||
bool benchmarkNoAnalysis = false;
|
||||
unsigned int benchmarkSamples = 100;
|
||||
double benchmarkConfidenceInterval = 0.95;
|
||||
@@ -99,6 +102,8 @@ namespace Catch {
|
||||
double minDuration() const override;
|
||||
TestRunOrder runOrder() const override;
|
||||
uint32_t rngSeed() const override;
|
||||
unsigned int shardCount() const override;
|
||||
unsigned int shardIndex() const override;
|
||||
UseColour useColour() const override;
|
||||
bool shouldDebugBreak() const override;
|
||||
int abortAfter() const override;
|
||||
|
@@ -16,6 +16,7 @@
|
||||
#include <catch2/catch_version.hpp>
|
||||
#include <catch2/interfaces/catch_interfaces_reporter.hpp>
|
||||
#include <catch2/internal/catch_startup_exception_registry.hpp>
|
||||
#include <catch2/internal/catch_sharding.hpp>
|
||||
#include <catch2/internal/catch_textflow.hpp>
|
||||
#include <catch2/internal/catch_windows_h_proxy.hpp>
|
||||
#include <catch2/reporters/catch_reporter_listening.hpp>
|
||||
@@ -72,6 +73,8 @@ namespace Catch {
|
||||
for (auto const& match : m_matches)
|
||||
m_tests.insert(match.tests.begin(), match.tests.end());
|
||||
}
|
||||
|
||||
m_tests = createShard(m_tests, m_config->shardCount(), m_config->shardIndex());
|
||||
}
|
||||
|
||||
Totals execute() {
|
||||
@@ -171,6 +174,7 @@ namespace Catch {
|
||||
return 1;
|
||||
|
||||
auto result = m_cli.parse( Clara::Args( argc, argv ) );
|
||||
|
||||
if( !result ) {
|
||||
config();
|
||||
getCurrentMutableContext().setConfig(m_config.get());
|
||||
@@ -253,6 +257,12 @@ namespace Catch {
|
||||
if( m_startupExceptions )
|
||||
return 1;
|
||||
|
||||
|
||||
if( m_configData.shardIndex >= m_configData.shardCount ) {
|
||||
Catch::cerr() << "The shard count (" << m_configData.shardCount << ") must be greater than the shard index (" << m_configData.shardIndex << ")\n" << std::flush;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (m_configData.showHelp || m_configData.libIdentify) {
|
||||
return 0;
|
||||
}
|
||||
|
@@ -74,6 +74,8 @@ namespace Catch {
|
||||
virtual std::vector<std::string> const& getTestsOrTags() const = 0;
|
||||
virtual TestRunOrder runOrder() const = 0;
|
||||
virtual uint32_t rngSeed() const = 0;
|
||||
virtual unsigned int shardCount() const = 0;
|
||||
virtual unsigned int shardIndex() const = 0;
|
||||
virtual UseColour useColour() const = 0;
|
||||
virtual std::vector<std::string> const& getSectionsToRun() const = 0;
|
||||
virtual Verbosity verbosity() const = 0;
|
||||
|
@@ -149,6 +149,15 @@ namespace Catch {
|
||||
return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" );
|
||||
return ParserResult::ok( ParseResultType::Matched );
|
||||
};
|
||||
auto const setShardCount = [&]( std::string const& shardCount ) {
|
||||
auto result = Clara::Detail::convertInto( shardCount, config.shardCount );
|
||||
|
||||
if (config.shardCount == 0) {
|
||||
return ParserResult::runtimeError( "The shard count must be greater than 0" );
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
auto cli
|
||||
= ExeName( config.processName )
|
||||
@@ -240,6 +249,12 @@ namespace Catch {
|
||||
| Opt( config.benchmarkWarmupTime, "benchmarkWarmupTime" )
|
||||
["--benchmark-warmup-time"]
|
||||
( "amount of time in milliseconds spent on warming up each test (default: 100)" )
|
||||
| Opt( setShardCount, "shard count" )
|
||||
["--shard-count"]
|
||||
( "split the tests to execute into this many groups" )
|
||||
| Opt( config.shardIndex, "shard index" )
|
||||
["--shard-index"]
|
||||
( "index of the group of tests to execute (see --shard-count)" )
|
||||
| Arg( config.testsOrTags, "test name|pattern|tags" )
|
||||
( "which test or tests to use" );
|
||||
|
||||
|
41
src/catch2/internal/catch_sharding.hpp
Normal file
41
src/catch2/internal/catch_sharding.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
|
||||
// Copyright Catch2 Authors
|
||||
// Distributed under the Boost Software License, Version 1.0.
|
||||
// (See accompanying file LICENSE_1_0.txt or copy at
|
||||
// https://www.boost.org/LICENSE_1_0.txt)
|
||||
|
||||
// SPDX-License-Identifier: BSL-1.0
|
||||
#ifndef CATCH_SHARDING_HPP_INCLUDED
|
||||
#define CATCH_SHARDING_HPP_INCLUDED
|
||||
|
||||
#include <catch2/catch_session.hpp>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace Catch {
|
||||
|
||||
template<typename Container>
|
||||
Container createShard(Container const& container, std::size_t const shardCount, std::size_t const shardIndex) {
|
||||
assert(shardCount > shardIndex);
|
||||
|
||||
if (shardCount == 1) {
|
||||
return container;
|
||||
}
|
||||
|
||||
const std::size_t totalTestCount = container.size();
|
||||
|
||||
const std::size_t shardSize = totalTestCount / shardCount;
|
||||
const std::size_t leftoverTests = totalTestCount % shardCount;
|
||||
|
||||
const std::size_t startIndex = shardIndex * shardSize + (std::min)(shardIndex, leftoverTests);
|
||||
const std::size_t endIndex = (shardIndex + 1) * shardSize + (std::min)(shardIndex + 1, leftoverTests);
|
||||
|
||||
auto startIterator = std::next(container.begin(), startIndex);
|
||||
auto endIterator = std::next(container.begin(), endIndex);
|
||||
|
||||
return Container(startIterator, endIterator);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif // CATCH_SHARDING_HPP_INCLUDED
|
@@ -12,6 +12,7 @@
|
||||
#include <catch2/interfaces/catch_interfaces_registry_hub.hpp>
|
||||
#include <catch2/internal/catch_random_number_generator.hpp>
|
||||
#include <catch2/internal/catch_run_context.hpp>
|
||||
#include <catch2/internal/catch_sharding.hpp>
|
||||
#include <catch2/catch_test_case_info.hpp>
|
||||
#include <catch2/catch_test_spec.hpp>
|
||||
#include <catch2/internal/catch_move_and_forward.hpp>
|
||||
@@ -135,7 +136,7 @@ namespace {
|
||||
filtered.push_back(testCase);
|
||||
}
|
||||
}
|
||||
return filtered;
|
||||
return createShard(filtered, config.shardCount(), config.shardIndex());
|
||||
}
|
||||
std::vector<TestCaseHandle> const& getAllTestCasesSorted( IConfig const& config ) {
|
||||
return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );
|
||||
|
Reference in New Issue
Block a user