From 80d58a791dd41c365b731469f2ce5f5b6a3ae54d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Ho=C5=99e=C5=88ovsk=C3=BD?= Date: Thu, 20 Oct 2022 21:31:04 +0200 Subject: [PATCH] Add support for Bazel's sharding env variables Closes #2491 --- src/catch2/catch_config.cpp | 114 ++++++++++++++++++++++--- tests/ExtraTests/CMakeLists.txt | 12 +++ tests/TestScripts/testBazelSharding.py | 75 ++++++++++++++++ 3 files changed, 187 insertions(+), 14 deletions(-) create mode 100755 tests/TestScripts/testBazelSharding.py diff --git a/src/catch2/catch_config.cpp b/src/catch2/catch_config.cpp index 52f4c8fd..19a7551a 100644 --- a/src/catch2/catch_config.cpp +++ b/src/catch2/catch_config.cpp @@ -8,38 +8,113 @@ #include #include #include +#include #include +#include #include #include #include #include -namespace { - static bool enableBazelEnvSupport() { -#if defined(CATCH_CONFIG_BAZEL_SUPPORT) - return true; -#elif defined(CATCH_PLATFORM_WINDOWS_UWP) - // UWP does not support environment variables - return false; +#include + +namespace Catch { + + namespace { + static bool enableBazelEnvSupport() { +#if defined( CATCH_CONFIG_BAZEL_SUPPORT ) + return true; +#elif defined( CATCH_PLATFORM_WINDOWS_UWP ) + // UWP does not support environment variables + return false; #else # if defined( _MSC_VER ) - // On Windows getenv throws a warning as there is no input validation, - // since the switch is hardcoded, this should not be an issue. + // On Windows getenv throws a warning as there is no input + // validation, since the switch is hardcoded, this should not be an + // issue. # pragma warning( push ) # pragma warning( disable : 4996 ) # endif - return std::getenv( "BAZEL_TEST" ) != nullptr; + return std::getenv( "BAZEL_TEST" ) != nullptr; # if defined( _MSC_VER ) # pragma warning( pop ) # endif #endif - } -} + } + + struct bazelShardingOptions { + unsigned int shardIndex, shardCount; + std::string shardFilePath; + }; + + static Optional readBazelShardingOptions() { +#if defined( CATCH_PLATFORM_WINDOWS_UWP ) + // We cannot read environment variables on UWP platforms + return {} +#else + +# if defined( _MSC_VER ) +# pragma warning( push ) +# pragma warning( disable : 4996 ) // use getenv_s instead of getenv +# endif + + const auto bazelShardIndex = std::getenv( "TEST_SHARD_INDEX" ); + const auto bazelShardTotal = std::getenv( "TEST_TOTAL_SHARDS" ); + const auto bazelShardInfoFile = std::getenv( "TEST_SHARD_STATUS_FILE" ); + +# if defined( _MSC_VER ) +# pragma warning( pop ) +# endif + + + const bool has_all = + bazelShardIndex && bazelShardTotal && bazelShardInfoFile; + if ( !has_all ) { + // We provide nice warning message if the input is + // misconfigured. + auto warn = []( const char* env_var ) { + Catch::cerr() + << "Warning: Bazel shard configuration is missing '" + << env_var << "'. Shard configuration is skipped.\n"; + }; + if ( !bazelShardIndex ) { + warn( "TEST_SHARD_INDEX" ); + } + if ( !bazelShardTotal ) { + warn( "TEST_TOTAL_SHARDS" ); + } + if ( !bazelShardInfoFile ) { + warn( "TEST_SHARD_STATUS_FILE" ); + } + return {}; + } + + auto shardIndex = parseUInt( bazelShardIndex ); + if ( !shardIndex ) { + Catch::cerr() + << "Warning: could not parse 'TEST_SHARD_INDEX' ('" << bazelShardIndex + << "') as unsigned int.\n"; + return {}; + } + auto shardTotal = parseUInt( bazelShardTotal ); + if ( !shardTotal ) { + Catch::cerr() + << "Warning: could not parse 'TEST_TOTAL_SHARD' ('" + << bazelShardTotal << "') as unsigned int.\n"; + return {}; + } + + return bazelShardingOptions{ + *shardIndex, *shardTotal, bazelShardInfoFile }; + +#endif + + } + } // end namespace -namespace Catch { bool operator==( ProcessedReporterSpec const& lhs, ProcessedReporterSpec const& rhs ) { @@ -184,6 +259,7 @@ namespace Catch { // This allows the XML output file to contain higher level of detail // than what is possible otherwise. const auto bazelOutputFile = std::getenv( "XML_OUTPUT_FILE" ); + if ( bazelOutputFile ) { m_data.reporterSpecifications.push_back( { "junit", std::string( bazelOutputFile ), {}, {} } ); @@ -196,11 +272,21 @@ namespace Catch { m_data.testsOrTags.clear(); m_data.testsOrTags.push_back( bazelTestSpec ); } - # if defined( _MSC_VER ) # pragma warning( pop ) # endif + const auto bazelShardOptions = readBazelShardingOptions(); + if ( bazelShardOptions ) { + std::ofstream f( bazelShardOptions->shardFilePath, + std::ios_base::out | std::ios_base::trunc ); + if ( f.is_open() ) { + f << ""; + m_data.shardIndex = bazelShardOptions->shardIndex; + m_data.shardCount = bazelShardOptions->shardCount; + } + } + #endif } diff --git a/tests/ExtraTests/CMakeLists.txt b/tests/ExtraTests/CMakeLists.txt index 780f2559..a714b230 100644 --- a/tests/ExtraTests/CMakeLists.txt +++ b/tests/ExtraTests/CMakeLists.txt @@ -161,6 +161,18 @@ set_tests_properties(BazelEnv::TESTBRIDGE_TEST_ONLY ) +add_test(NAME BazelEnv::Sharding + COMMAND + "${PYTHON_EXECUTABLE}" "${CATCH_DIR}/tests/TestScripts/testBazelSharding.py" + $ + "${CMAKE_CURRENT_BINARY_DIR}" +) +set_tests_properties(BazelEnv::Sharding + PROPERTIES + LABELS "uses-python" +) + + # The default handler on Windows leads to the just-in-time debugger firing, # which makes this test unsuitable for CI and headless runs, as it opens # up an interactive dialog. diff --git a/tests/TestScripts/testBazelSharding.py b/tests/TestScripts/testBazelSharding.py new file mode 100755 index 00000000..14747bca --- /dev/null +++ b/tests/TestScripts/testBazelSharding.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +# Copyright Catch2 Authors +# Distributed under the Boost Software License, Version 1.0. +# (See accompanying file LICENSE_1_0.txt or copy at +# https://www.boost.org/LICENSE_1_0.txt) + +# SPDX-License-Identifier: BSL-1.0 + +import os +import re +import sys +import subprocess + +""" +Test that Catch2 recognizes the three sharding-related environment variables +and responds accordingly (running only the selected shard, creating the +response file, etc). + +Requires 2 arguments, path to Catch2 binary to run and the output directory +for the output file. +""" +if len(sys.argv) != 3: + print("Wrong number of arguments: {}".format(len(sys.argv))) + print("Usage: {} test-bin-path output-dir".format(sys.argv[0])) + exit(1) + + +bin_path = os.path.abspath(sys.argv[1]) +output_dir = os.path.abspath(sys.argv[2]) +info_file_path = os.path.join(output_dir, '{}.shard-support'.format(os.path.basename(bin_path))) + +# Ensure no file exists from previous test runs +if os.path.isfile(info_file_path): + os.remove(info_file_path) + +print('bin path:', bin_path) +print('shard confirmation path:', info_file_path) + +env = os.environ.copy() +# We will run only one shard, and it should have the passing test. +# This simplifies our work a bit, and if we have any regression in this +# functionality we can make more complex tests later. +env["BAZEL_TEST"] = "1" +env["TEST_SHARD_INDEX"] = "0" +env["TEST_TOTAL_SHARDS"] = "2" +env["TEST_SHARD_STATUS_FILE"] = info_file_path + + +try: + ret = subprocess.run( + bin_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + universal_newlines=True, + env=env + ) + stdout = ret.stdout +except subprocess.SubprocessError as ex: + print('Could not run "{}"'.format(bin_path)) + print("Return code: {}".format(ex.returncode)) + print("stdout: {}".format(ex.stdout)) + print("stderr: {}".format(ex.stderr)) + raise + + +if not "All tests passed (1 assertion in 1 test case)" in stdout: + print("Did not find expected output in stdout.") + print("stdout:\n{}".format(stdout)) + exit(1) + +if not os.path.isfile(info_file_path): + print("Catch2 did not create expected file at path '{}'".format(info_file_path)) + exit(2)