SERVER-110208 Implement an optimization budget (#41870)

GitOrigin-RevId: 30535cac283d11eaec0419e86a1b5abe8fc9cef7
This commit is contained in:
Henri Nikku 2025-10-15 20:47:18 +01:00 committed by MongoDB Bot
parent cdf4f33848
commit b31ac5e6d1
5 changed files with 152 additions and 34 deletions

View File

@ -33,5 +33,6 @@ mongo_cc_unit_test(
":pipeline_rewriter", ":pipeline_rewriter",
"//src/mongo/db/pipeline:aggregation_context_fixture", "//src/mongo/db/pipeline:aggregation_context_fixture",
"//src/mongo/db/pipeline:expression_context_for_test", "//src/mongo/db/pipeline:expression_context_for_test",
"//src/mongo/idl:server_parameter_test_controller",
], ],
) )

View File

@ -31,6 +31,8 @@
#include "mongo/db/pipeline/document_source_limit.h" #include "mongo/db/pipeline/document_source_limit.h"
#include "mongo/db/pipeline/document_source_skip.h" #include "mongo/db/pipeline/document_source_skip.h"
#include "mongo/db/pipeline/optimization/rule_based_rewriter.h" #include "mongo/db/pipeline/optimization/rule_based_rewriter.h"
#include "mongo/db/query/query_knobs_gen.h"
#include "mongo/idl/server_parameter_test_controller.h"
#include "mongo/unittest/death_test.h" #include "mongo/unittest/death_test.h"
#include "mongo/unittest/unittest.h" #include "mongo/unittest/unittest.h"
@ -63,7 +65,8 @@ void runTest(const boost::intrusive_ptr<ExpressionContext>& expCtx,
}; };
auto pipeline = makePipeline(input); auto pipeline = makePipeline(input);
PipelineRewriteEngine engine{{*pipeline}}; PipelineRewriteEngine engine{{*pipeline},
static_cast<size_t>(internalQueryMaxPipelineRewrites.load())};
engine.applyRules(); engine.applyRules();
@ -101,6 +104,17 @@ TEST_F(PipelineRewriteEngineTest, RespectPriority) {
runTest(getExpCtx(), {"{$match: {a: 1}}"}, {}); runTest(getExpCtx(), {"{$match: {a: 1}}"}, {});
} }
// Verifies that the 'internalQueryMaxPipelineRewrites' server parameter caps the number of rules
// the engine applies. The knob is set to 2 for the duration of the test and three rules whose
// preconditions always hold are registered for $match. The expected pipeline equals the input,
// i.e. no rule may alter it.
// NOTE(review): presumably the two NOOP rules (priorities 3.0 and 2.0) are tried before
// CRASH_WHEN_RUN (priority 1.0) and exhaust the budget, so 'shouldNeverRun' is never invoked --
// confirm against the Rule ordering.
TEST_F(PipelineRewriteEngineTest, RespectMaxRewritesQueryKnob) {
// Overrides the knob to 2; the name suggests the override is undone when 'controller' goes out
// of scope.
RAIIServerParameterControllerForTest controller("internalQueryMaxPipelineRewrites", 2);
REGISTER_TEST_RULES(DocumentSourceMatch,
{"CRASH_WHEN_RUN", alwaysTrue, shouldNeverRun, 1.0},
{"NOOP1", alwaysTrue, noop, 2.0},
{"NOOP2", alwaysTrue, noop, 3.0});
runTest(getExpCtx(), {"{$match: {a: 1}}"}, {"{$match: {a: 1}}"});
}
TEST_F(PipelineRewriteEngineTest, ApplySingleRuleInPlace) { TEST_F(PipelineRewriteEngineTest, ApplySingleRuleInPlace) {
static auto setLimitTo1 = [](PipelineRewriteContext& ctx) { static auto setLimitTo1 = [](PipelineRewriteContext& ctx) {
ctx.currentAs<DocumentSourceLimit>().setLimit(1); ctx.currentAs<DocumentSourceLimit>().setLimit(1);

View File

@ -29,6 +29,7 @@
#pragma once #pragma once
#include "mongo/base/checked_cast.h"
#include "mongo/logv2/log.h" #include "mongo/logv2/log.h"
#include "mongo/util/modules.h" #include "mongo/util/modules.h"
@ -123,11 +124,11 @@ public:
*/ */
template <std::derived_from<T> SubType> template <std::derived_from<T> SubType>
SubType& currentAs() { SubType& currentAs() {
return static_cast<SubType&>(current()); return checked_cast<SubType&>(current());
} }
template <std::derived_from<T> SubType> template <std::derived_from<T> SubType>
const SubType& currentAs() const { const SubType& currentAs() const {
return static_cast<const SubType&>(current()); return checked_cast<const SubType&>(current());
} }
void setEngine(RewriteEngine<SubClass>& engine) { void setEngine(RewriteEngine<SubClass>& engine) {
@ -146,7 +147,8 @@ private:
template <typename Context> template <typename Context>
class RewriteEngine final { class RewriteEngine final {
public: public:
RewriteEngine(Context context) : _context(std::move(context)) { RewriteEngine(Context context, size_t maxRewrites = std::numeric_limits<size_t>::max())
: _context(std::move(context)), _maxRewrites(maxRewrites) {
_context.setEngine(*this); _context.setEngine(*this);
} }
@ -169,47 +171,84 @@ public:
// rules. // rules.
_context.enqueueRules(); _context.enqueueRules();
bool doAdvance = true; NextAction nextAction = rewriteCurrentPosition();
while (!_rules.empty() && _context.hasMore()) { switch (nextAction) {
const auto rule = std::move(_rules.top()); case NextAction::Requeue:
_rules.pop(); // Requeue rules that apply to the current element without advancing.
const size_t rulesBefore = _rules.size(); break;
case NextAction::Advance:
LOGV2_DEBUG(11010013, // Did not update position. Advance to the next element.
5, _context.advance();
"Trying to apply a rewrite rule", break;
"rule"_attr = rule.name, case NextAction::Bail:
"priority"_attr = rule.priority); return;
if (rule.precondition(_context)) {
const bool shouldRequeueRules = rule.transform(_context);
if (shouldRequeueRules) {
tassert(11010015,
"Should not add new rules from a rule that requires requeueing",
rulesBefore == _rules.size());
// Discard remaining rules because we changed position.
clearRules();
doAdvance = false;
break;
}
}
}
if (doAdvance) {
// Did not update position. Advance to the next element.
_context.advance();
} }
} }
} }
private: private:
/**
 * Outcome of trying the queued rules at the current position; tells applyRules() what to do next.
 */
enum class NextAction {
// A transform asked for requeueing: enqueue rules for the current element again without
// advancing.
Requeue,
// No transform asked for requeueing: advance to the next element.
Advance,
// The rewrite budget is exhausted: stop applying rules altogether.
Bail,
};
/**
 * Tries to apply the rules in the queue to the current position, in priority-queue order.
 *
 * Only rules whose precondition holds have their transform run, and only those count towards the
 * rewrite budget ('_maxRewrites'); a rule that is popped but whose precondition fails is skipped
 * for free.
 *
 * Returns:
 *  - NextAction::Bail if the budget is already used up when the next rule is about to be tried.
 *    The remaining rules are left in the queue.
 *  - NextAction::Requeue if a transform returned true. The remaining rules are discarded.
 *  - NextAction::Advance once the queue is empty or the context has nothing more to visit.
 */
NextAction rewriteCurrentPosition() {
    while (!_rules.empty() && _context.hasMore()) {
        // Check the budget before popping so that a budget of zero never runs a transform.
        if (_maxRewrites <= _rewritesApplied) {
            LOGV2_DEBUG(11020801,
                        5,
                        "Reached the maximum number of rewrites applied",
                        "limit"_attr = _maxRewrites);
            return NextAction::Bail;
        }

        // std::priority_queue::top() returns a const reference, so the element can only be
        // copied, never moved from. Copy it explicitly before pop() invalidates the reference.
        const auto rule = _rules.top();
        _rules.pop();

        // Remembered to detect transforms that enqueue rules and then ask for requeueing.
        const size_t rulesBefore = _rules.size();

        LOGV2_DEBUG(11010013,
                    5,
                    "Trying to apply a rewrite rule",
                    "rule"_attr = rule.name,
                    "priority"_attr = rule.priority);

        if (!rule.precondition(_context)) {
            // Continue to the next applicable rule.
            continue;
        }

        const bool shouldRequeueRules = rule.transform(_context);
        ++_rewritesApplied;

        LOGV2_DEBUG(11206202, 5, "Applied rule", "rule"_attr = rule.name);

        if (shouldRequeueRules) {
            tassert(11010015,
                    "Should not add new rules from a rule that requires requeueing",
                    rulesBefore == _rules.size());
            // Discard remaining rules and requeue because we changed position.
            clearRules();
            return NextAction::Requeue;
        }
    }

    return NextAction::Advance;
}
void clearRules() { void clearRules() {
_rules = {}; _rules = {};
} }
Context _context; Context _context;
std::priority_queue<Rule<Context>> _rules; std::priority_queue<Rule<Context>> _rules;
const size_t _maxRewrites;
size_t _rewritesApplied{0};
}; };
} // namespace mongo::rule_based_rewrites } // namespace mongo::rule_based_rewrites

View File

@ -88,6 +88,9 @@ bool alwaysTrue(TestRewriteContext&) {
} }
// Transforms // Transforms
bool noop(TestRewriteContext& ctx) {
return false;
}
bool shouldNeverRun(TestRewriteContext& ctx) { bool shouldNeverRun(TestRewriteContext& ctx) {
MONGO_UNREACHABLE; MONGO_UNREACHABLE;
} }
@ -110,6 +113,56 @@ TEST(RuleBasedRewriterTest, RespectPrecondition) {
ASSERT_DOES_NOT_THROW(engine.applyRules()); ASSERT_DOES_NOT_THROW(engine.applyRules());
} }
// With a budget of zero the engine must not run any transform. The only registered rule always
// passes its precondition and its transform is 'shouldNeverRun' (MONGO_UNREACHABLE), so the test
// passes only if applyRules() returns without invoking it.
TEST(RuleBasedRewriterTest, RespectZeroOptimizationBudget) {
std::vector<std::string> strings = {{"hello"}};
RewriteEngine<TestRewriteContext> engine{
{
strings,
{{"NEVER_APPLIES", alwaysTrue, shouldNeverRun, 1}},
},
0 /*maxRewrites*/,
};
ASSERT_DOES_NOT_THROW(engine.applyRules());
}
// With a budget of two, only two transforms may run. Three always-applicable rules are
// registered; the third one's transform is 'shouldNeverRun' (MONGO_UNREACHABLE).
// NOTE(review): this relies on NOOP1 (priority 3) and NOOP2 (priority 2) being tried before
// NEVER_APPLIES (priority 1) and consuming the whole budget -- confirm against Rule ordering.
TEST(RuleBasedRewriterTest, RespectNonZeroOptimizationBudget) {
std::vector<std::string> strings = {{"hello"}};
RewriteEngine<TestRewriteContext> engine{
{
strings,
{
{"NOOP1", alwaysTrue, noop, 3},
{"NOOP2", alwaysTrue, noop, 2},
{"NEVER_APPLIES", alwaysTrue, shouldNeverRun, 1},
},
},
2 /*maxRewrites*/,
};
ASSERT_DOES_NOT_THROW(engine.applyRules());
}
// With a budget equal to the number of registered rules (three), every rule gets to run,
// including the last one (SHOULD_APPLY). The assertions check that the single input string is
// still present and has become "HELLO", i.e. 'upperCaseTransform' was applied.
TEST(RuleBasedRewriterTest, ApplyAllRulesBeforeReachingOptimizationBudget) {
std::vector<std::string> strings = {{"hello"}};
RewriteEngine<TestRewriteContext> engine{
{
strings,
{
{"NOOP1", alwaysTrue, noop, 3},
{"NOOP2", alwaysTrue, noop, 2},
{"SHOULD_APPLY", alwaysTrue, upperCaseTransform, 1},
},
},
3 /*maxRewrites*/,
};
engine.applyRules();
ASSERT_EQ(strings.size(), 1U);
ASSERT_EQ(strings[0], "HELLO");
}
TEST(RuleBasedRewriterTest, ApplySingleRule) { TEST(RuleBasedRewriterTest, ApplySingleRule) {
std::vector<std::string> strings = {{"hello"}, {"world"}}; std::vector<std::string> strings = {{"hello"}, {"world"}};
RewriteEngine<TestRewriteContext> engine{{ RewriteEngine<TestRewriteContext> engine{{

View File

@ -520,6 +520,17 @@ server_parameters:
on_update: plan_cache_util::clearSbeCacheOnParameterChange on_update: plan_cache_util::clearSbeCacheOnParameterChange
redact: false redact: false
# Optimization budget for pipeline rewrites: upper bound on the number of rewrite rules applied.
# Settable at startup and at runtime; the validator accepts any value >= 0.
# NOTE(review): on_update reuses the SBE plan cache clearing hook of the preceding parameter --
# confirm that clearing the SBE plan cache is intended when this knob changes.
internalQueryMaxPipelineRewrites:
description: "Maximum number of pipeline rewrite rules to apply"
set_at: [startup, runtime]
cpp_varname: "internalQueryMaxPipelineRewrites"
cpp_vartype: AtomicWord<int>
default: 1_000_000
validator:
gte: 0
on_update: plan_cache_util::clearSbeCacheOnParameterChange
redact: false
# #
# Query execution # Query execution
# #