Allow test sharding for e.g. Bazel test sharding feature

This greatly simplifies running Catch2 tests in single binary
in parallel from external test runners. Instead of having to
shard the tests by tags/test names, an external test runner
can now just ask for test shard 2 (out of X), and execute that
in single process, without having to know what tests are actually
in the shard.

Note that sharding also applies to test listing, and happens after
tests were ordered according to the `--order` feature.
This commit is contained in:
Ben Dunkin
2021-07-11 12:46:05 -07:00
committed by Martin Hořeňovský
parent 6456ee8b01
commit 3087e19cc7
21 changed files with 415 additions and 6 deletions

View File

@@ -85,6 +85,7 @@
#include <catch2/internal/catch_result_type.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_section.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_singletons.hpp>
#include <catch2/internal/catch_source_line_info.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>

View File

@@ -73,6 +73,8 @@ namespace Catch {
double Config::minDuration() const { return m_data.minDuration; }
TestRunOrder Config::runOrder() const { return m_data.runOrder; }
uint32_t Config::rngSeed() const { return m_data.rngSeed; }
unsigned int Config::shardCount() const { return m_data.shardCount; }
unsigned int Config::shardIndex() const { return m_data.shardIndex; }
UseColour Config::useColour() const { return m_data.useColour; }
bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; }
int Config::abortAfter() const { return m_data.abortAfter; }

View File

@@ -37,6 +37,9 @@ namespace Catch {
int abortAfter = -1;
uint32_t rngSeed = generateRandomSeed(GenerateFrom::Default);
unsigned int shardCount = 1;
unsigned int shardIndex = 0;
bool benchmarkNoAnalysis = false;
unsigned int benchmarkSamples = 100;
double benchmarkConfidenceInterval = 0.95;
@@ -99,6 +102,8 @@ namespace Catch {
double minDuration() const override;
TestRunOrder runOrder() const override;
uint32_t rngSeed() const override;
unsigned int shardCount() const override;
unsigned int shardIndex() const override;
UseColour useColour() const override;
bool shouldDebugBreak() const override;
int abortAfter() const override;

View File

@@ -16,6 +16,7 @@
#include <catch2/catch_version.hpp>
#include <catch2/interfaces/catch_interfaces_reporter.hpp>
#include <catch2/internal/catch_startup_exception_registry.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/internal/catch_textflow.hpp>
#include <catch2/internal/catch_windows_h_proxy.hpp>
#include <catch2/reporters/catch_reporter_listening.hpp>
@@ -72,6 +73,8 @@ namespace Catch {
for (auto const& match : m_matches)
m_tests.insert(match.tests.begin(), match.tests.end());
}
m_tests = createShard(m_tests, m_config->shardCount(), m_config->shardIndex());
}
Totals execute() {
@@ -171,6 +174,7 @@ namespace Catch {
return 1;
auto result = m_cli.parse( Clara::Args( argc, argv ) );
if( !result ) {
config();
getCurrentMutableContext().setConfig(m_config.get());
@@ -253,6 +257,12 @@ namespace Catch {
if( m_startupExceptions )
return 1;
if( m_configData.shardIndex >= m_configData.shardCount ) {
Catch::cerr() << "The shard count (" << m_configData.shardCount << ") must be greater than the shard index (" << m_configData.shardIndex << ")\n" << std::flush;
return 1;
}
if (m_configData.showHelp || m_configData.libIdentify) {
return 0;
}

View File

@@ -74,6 +74,8 @@ namespace Catch {
virtual std::vector<std::string> const& getTestsOrTags() const = 0;
virtual TestRunOrder runOrder() const = 0;
virtual uint32_t rngSeed() const = 0;
virtual unsigned int shardCount() const = 0;
virtual unsigned int shardIndex() const = 0;
virtual UseColour useColour() const = 0;
virtual std::vector<std::string> const& getSectionsToRun() const = 0;
virtual Verbosity verbosity() const = 0;

View File

@@ -149,6 +149,15 @@ namespace Catch {
return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" );
return ParserResult::ok( ParseResultType::Matched );
};
auto const setShardCount = [&]( std::string const& shardCount ) {
auto result = Clara::Detail::convertInto( shardCount, config.shardCount );
if (config.shardCount == 0) {
return ParserResult::runtimeError( "The shard count must be greater than 0" );
} else {
return result;
}
};
auto cli
= ExeName( config.processName )
@@ -240,6 +249,12 @@ namespace Catch {
| Opt( config.benchmarkWarmupTime, "benchmarkWarmupTime" )
["--benchmark-warmup-time"]
( "amount of time in milliseconds spent on warming up each test (default: 100)" )
| Opt( setShardCount, "shard count" )
["--shard-count"]
( "split the tests to execute into this many groups" )
| Opt( config.shardIndex, "shard index" )
["--shard-index"]
( "index of the group of tests to execute (see --shard-count)" )
| Arg( config.testsOrTags, "test name|pattern|tags" )
( "which test or tests to use" );

View File

@@ -0,0 +1,41 @@
// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)
// SPDX-License-Identifier: BSL-1.0
#ifndef CATCH_SHARDING_HPP_INCLUDED
#define CATCH_SHARDING_HPP_INCLUDED
#include <catch2/catch_session.hpp>
#include <cmath>
namespace Catch {
template<typename Container>
Container createShard(Container const& container, std::size_t const shardCount, std::size_t const shardIndex) {
assert(shardCount > shardIndex);
if (shardCount == 1) {
return container;
}
const std::size_t totalTestCount = container.size();
const std::size_t shardSize = totalTestCount / shardCount;
const std::size_t leftoverTests = totalTestCount % shardCount;
const std::size_t startIndex = shardIndex * shardSize + (std::min)(shardIndex, leftoverTests);
const std::size_t endIndex = (shardIndex + 1) * shardSize + (std::min)(shardIndex + 1, leftoverTests);
auto startIterator = std::next(container.begin(), startIndex);
auto endIterator = std::next(container.begin(), endIndex);
return Container(startIterator, endIterator);
}
}
#endif // CATCH_SHARDING_HPP_INCLUDED

View File

@@ -12,6 +12,7 @@
#include <catch2/interfaces/catch_interfaces_registry_hub.hpp>
#include <catch2/internal/catch_random_number_generator.hpp>
#include <catch2/internal/catch_run_context.hpp>
#include <catch2/internal/catch_sharding.hpp>
#include <catch2/catch_test_case_info.hpp>
#include <catch2/catch_test_spec.hpp>
#include <catch2/internal/catch_move_and_forward.hpp>
@@ -135,7 +136,7 @@ namespace {
filtered.push_back(testCase);
}
}
return filtered;
return createShard(filtered, config.shardCount(), config.shardIndex());
}
std::vector<TestCaseHandle> const& getAllTestCasesSorted( IConfig const& config ) {
return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );