SERVER-109575 Implement loadable test extension for Transform stages (#44570)

GitOrigin-RevId: 8695a2d8739d2b5fa46b480d8df12d459ff7ee39
This commit is contained in:
Adithi Raghavan 2025-12-12 16:28:07 -05:00 committed by MongoDB Bot
parent 8882eea4b6
commit 53adb17228
3 changed files with 397 additions and 1 deletions

View File

@ -0,0 +1,275 @@
/**
* Tests an extension transform stage.
*
* @tags: [featureFlagExtensionsAPI]
*/
import {assertArrayEq} from "jstests/aggregation/extras/utils.js";
// Start from an empty collection named after this test and fill it with one document per bread
// type, in the order listed below.
const collName = jsTestName();
const coll = db[collName];
coll.drop();

const breadTypeOrderForDocs = ["sourdough", "rye", "whole wheat", "sourdough", "sourdough", "brioche"];
const breadDocs = breadTypeOrderForDocs.map((breadType) => ({breadType}));
assert.commandWorked(coll.insertMany(breadDocs));
/**
 * Runs 'inputPipeline' against the test collection and asserts that the output matches
 * 'expectedResults'.
 *
 * The pipeline is passed through buildPipeline() before it is run. The actual results are attached
 * to the assertion as the extra error message.
 *
 * @param {Object[]} inputPipeline - aggregation stages to run.
 * @param {Object[]} expectedResults - documents the aggregation is expected to return.
 */
const runTestcase = (inputPipeline, expectedResults) => {
    // Keep the results local so this helper neither depends on nor overwrites a variable declared
    // elsewhere in the test file.
    const actualResults = coll.aggregate(buildPipeline(inputPipeline)).toArray();
    assertArrayEq({actual: actualResults, expected: expectedResults, extraErrorMsg: tojson(actualResults)});
};
// The pieces below are assembled into 'sortFieldsRemoveId', a $replaceRoot stage that normalizes
// the output of $loaf: for every top-level field of the document, the slices object is converted
// to an array, sorted by breadType, renumbered sequentially ("slice0", "slice1", ...), and each
// slice has its _id removed.

// The current slice document without its _id field. The $filter input is the value of the $reduce
// element ("$$this.v"); inside 'cond', "$$this" names the element being filtered because $filter
// is given no 'as'.
const sliceWithoutId = {
    $arrayToObject: {
        $filter: {
            input: {$objectToArray: "$$this.v"},
            cond: {$ne: ["$$this.k", "_id"]},
        },
    },
};

// Key for the next slice: "slice" followed by the number of slices accumulated so far.
const nextSliceKey = {$concat: ["slice", {$toString: {$size: "$$value"}}]};

// The slices of the current top-level field as a k/v array, ordered by breadType.
const slicesSortedByBreadType = {
    $sortArray: {
        input: {$objectToArray: "$$loafField.v"},
        sortBy: {"v.breadType": 1},
    },
};

// Rebuilds the slices object from the sorted array, one renumbered and _id-free slice at a time.
const renumberedSlices = {
    $arrayToObject: {
        $reduce: {
            input: slicesSortedByBreadType,
            initialValue: [],
            in: {$concatArrays: ["$$value", [{k: nextSliceKey, v: sliceWithoutId}]]},
        },
    },
};

const sortFieldsRemoveId = {
    $replaceRoot: {
        newRoot: {
            $arrayToObject: {
                $map: {
                    input: {$objectToArray: "$$ROOT"},
                    as: "loafField",
                    in: {k: "$$loafField.k", v: renumberedSlices},
                },
            },
        },
    },
};
/**
 * Returns a copy of 'stages' in which every stage that has a $loaf field is immediately followed
 * by the 'sortFieldsRemoveId' stage. All other stages are passed through unchanged.
 *
 * @param {Object[]} stages - aggregation stages.
 * @returns {Object[]} a new array of stages.
 */
const buildPipeline = (stages) =>
    stages.flatMap((stage) => (stage.$loaf !== undefined ? [stage, sortFieldsRemoveId] : [stage]));
/**
 * Builds a $loaf stage specification with the given 'numSlices' argument.
 *
 * @param {number} numSlices - value for the stage's numSlices field.
 * @returns {Object} the stage, e.g. {$loaf: {numSlices: 2}}.
 */
const buildLoafStage = (numSlices) => ({$loaf: {numSlices}});
// Stages shared by several of the test cases below.
const basicLoafStage = buildLoafStage(2);
const matchWithBasicLoafStage = [{$match: {breadType: "sourdough"}}, basicLoafStage];

// Transform stage must still be run against a collection: a collectionless aggregate
// ({aggregate: 1}) containing $loaf is rejected with InvalidNamespace.
const collectionlessAggregate = {aggregate: 1, pipeline: matchWithBasicLoafStage, cursor: {}};
assert.commandFailedWithCode(db.runCommand(collectionlessAggregate), ErrorCodes.InvalidNamespace);

// EOF transform stage: with numSlices set to 0 the stage returns no documents.
let results = coll.aggregate([buildLoafStage(0)]).toArray();
assert.eq(results.length, 0, results);
// $loaf as the first and only stage. numSlices is set to the number of inserted documents so that
// every document in the collection ends up in the loaf.
{
    const everySliceSortedByBreadType = {
        slice0: {breadType: "brioche"},
        slice1: {breadType: "rye"},
        slice2: {breadType: "sourdough"},
        slice3: {breadType: "sourdough"},
        slice4: {breadType: "sourdough"},
        slice5: {breadType: "whole wheat"},
    };
    runTestcase([buildLoafStage(breadTypeOrderForDocs.length)], [{fullLoaf: everySliceSortedByBreadType}]);
}
// Top-level transform stage preceded by a $match: only sourdough documents reach $loaf, and a
// two-slice loaf is returned.
{
    const twoSourdoughSlices = {
        slice0: {breadType: "sourdough"},
        slice1: {breadType: "sourdough"},
    };
    runTestcase(matchWithBasicLoafStage, [{fullLoaf: twoSourdoughSlices}]);
}
// A partial loaf is returned (per the getNext() logic for $loaf) when the predecessor stage
// produces fewer documents than the number of slices $loaf is allowed to examine. Here only one
// document has a breadType of "rye" while numSlices is 2: the second getNext() on the predecessor
// hits EOF, so $loaf can only return a partial loaf with 1 slice instead of 2.
{
    const ryeOnlyPipeline = [{$match: {breadType: "rye"}}, buildLoafStage(2)];
    const singleRyeSlice = {slice0: {breadType: "rye"}};
    runTestcase(ryeOnlyPipeline, [{partialLoaf: singleRyeSlice}]);
}
// $loaf can appear more than once in a pipeline. The first $loaf produces a single full loaf of
// two sourdough slices; the second $loaf receives only that one document, so it returns a partial
// loaf whose single slice is the first loaf.
{
    const innerLoaf = {
        fullLoaf: {
            slice0: {breadType: "sourdough"},
            slice1: {breadType: "sourdough"},
        },
    };
    const doubleLoafPipeline = [{$match: {breadType: "sourdough"}}, buildLoafStage(2), buildLoafStage(2)];
    runTestcase(doubleLoafPipeline, [{partialLoaf: {slice0: innerLoaf}}]);
}
// Extension source stage $toast correctly provides input docs to $loaf.
{
    const expectedResults = [
        {
            fullLoaf: {
                slice0: {slice: 0, isBurnt: false},
                slice1: {slice: 1, isBurnt: false},
            },
        },
    ];
    // Check that the command succeeded before reading the cursor, so that a failure reports the
    // server's error rather than a TypeError on an undefined 'cursor' field.
    const response = assert.commandWorked(
        db.runCommand({
            aggregate: "someCollection",
            pipeline: buildPipeline([{$toast: {temp: 300.0, numSlices: 4}}, buildLoafStage(2)]),
            cursor: {},
        }),
    );
    // Results are kept in a block-local constant instead of a variable shared with other cases.
    const toastedLoaf = response.cursor.firstBatch;
    assertArrayEq({actual: toastedLoaf, expected: expectedResults, extraErrorMsg: tojson(toastedLoaf)});
}
// Pipeline: $loaf -> (other server stages) -> $loaf
// The first $loaf gathers every document, the server stages count the slices per breadType, and
// the second $loaf gathers those counts. Only four documents (one per distinct breadType) reach
// the second $loaf although numSlices is 6, so a partial loaf is expected.
{
    const countSlicesPerBreadType = [
        // Convert the loaf object to an array of {k, v} slices.
        {$project: {slices: {$objectToArray: "$fullLoaf"}}},
        // Emit one document per slice.
        {$unwind: "$slices"},
        // Promote the actual slice document to the root.
        {$replaceRoot: {newRoot: "$slices.v"}},
        // Count the slices of each breadType.
        {$group: {_id: "$breadType", count: {$sum: 1}}},
        // Sort by _id (breadType).
        {$sort: {_id: 1}},
    ];
    const expectedCounts = {
        slice0: {count: 1},
        slice1: {count: 1},
        slice2: {count: 3},
        slice3: {count: 1},
    };
    const inputPipeline = [
        buildLoafStage(breadTypeOrderForDocs.length),
        ...countSlicesPerBreadType,
        buildLoafStage(breadTypeOrderForDocs.length),
    ];
    runTestcase(inputPipeline, [{partialLoaf: expectedCounts}]);
}
// TODO SERVER-113930 Remove failure cases and enable success cases for $lookup and $unionWith.

// Runs 'pipeline' as an aggregate command on the test collection and asserts that it fails with
// 'expectedCode'.
const assertAggregateFailsWithCode = (pipeline, expectedCode) => {
    assert.commandFailedWithCode(db.runCommand({aggregate: collName, pipeline, cursor: {}}), expectedCode);
};
const loafSubPipeline = [{$loaf: {numSlices: 2}}];

// Transform stage in $lookup.
assertAggregateFailsWithCode([{$lookup: {from: collName, pipeline: loafSubPipeline, as: "slices"}}], 51047);
// results = coll.aggregate([{$lookup: {from: collName, pipeline: [{$loaf: {numSlices: 2}}], as: "slices"}}]).toArray();
// assert.gt(results.length, 0);

// Transform stage in $unionWith.
assertAggregateFailsWithCode([{$unionWith: {coll: collName, pipeline: loafSubPipeline}}], 31441);
// results = coll.aggregate([{$unionWith: {coll: collName, pipeline: [{$loaf: {numSlices: 2}}]}}]).toArray();
// assert.gt(results.length, 0);

// Transform stage is not allowed in $facet.
assertAggregateFailsWithCode([{$facet: {slices: loafSubPipeline}}], 40600);

View File

@ -62,7 +62,7 @@ public:
// Expands to three stages: // Expands to three stages:
// 1) Host $match // 1) Host $match
// 2) Host $sort // 2) Host $sort
// 3) Host $limit (TODO SERVER-109575 this should be an extension $limit once transform // 3) Host $limit (TODO SERVER-114847 this should be an extension $limit once transform
// stages are implemented) // stages are implemented)
auto* hostServices = sdk::HostServicesHandle::getHostServices(); auto* hostServices = sdk::HostServicesHandle::getHostServices();
out.emplace_back(hostServices->createHostAggStageParseNode(_matchSpec)); out.emplace_back(hostServices->createHostAggStageParseNode(_matchSpec));

View File

@ -28,8 +28,11 @@
*/ */
#include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/db/extension/public/extension_agg_stage_static_properties_gen.h" #include "mongo/db/extension/public/extension_agg_stage_static_properties_gen.h"
#include "mongo/db/extension/sdk/aggregation_stage.h" #include "mongo/db/extension/sdk/aggregation_stage.h"
#include "mongo/db/extension/sdk/distributed_plan_logic.h"
#include "mongo/db/extension/sdk/dpl_array_container.h"
#include "mongo/db/extension/sdk/extension_factory.h" #include "mongo/db/extension/sdk/extension_factory.h"
#include "mongo/db/extension/sdk/tests/transform_test_stages.h" #include "mongo/db/extension/sdk/tests/transform_test_stages.h"
@ -83,8 +86,106 @@ private:
int _currentSlice; int _currentSlice;
}; };
class LoafExecStage : public sdk::TestExecStage {
public:
LoafExecStage(std::string_view stageName, const mongo::BSONObj& arguments)
: sdk::TestExecStage(stageName, arguments),
_numSlices([&] {
if (auto numSlices = arguments["numSlices"]) {
return static_cast<int>(numSlices.Number());
}
return 1;
}()),
_currentSlice(0),
_returnedLoaf(false) {}
// Essentially functions like a $group stage (processes multiple input documents via
// getNext() calls on the predecessor stage and outputs them in a single document).
mongo::extension::ExtensionGetNextResult getNext(
const mongo::extension::sdk::QueryExecutionContextHandle& execCtx,
::MongoExtensionExecAggStage* execStage) override {
// Note that exec::agg::Stage::getNext() calls this getNext() method until it gets an eof.
// So if we've already returned a batch of documents, _returnedLoaf should be true, and we
// can return eof.
if (_returnedLoaf) {
return mongo::extension::ExtensionGetNextResult::eof();
}
mongo::BSONObjBuilder loafBuilder;
while (_currentSlice < _numSlices) {
auto input = _getSource().getNext(execCtx.get());
if (input.code == mongo::extension::GetNextCode::kPauseExecution) {
return mongo::extension::ExtensionGetNextResult::pauseExecution();
}
if (input.code == mongo::extension::GetNextCode::kEOF) {
// Return a partial loaf (this means the number of results returned by the
// predecessor stage was less than the total number of slices (_numSlices) that
// could have been processed).
return _buildLoafResult(loafBuilder, "partialLoaf");
}
_appendSliceToLoaf(loafBuilder, input);
}
return _buildLoafResult(loafBuilder, "fullLoaf");
}
private:
int _numSlices;
int _currentSlice;
bool _returnedLoaf;
void _appendSliceToLoaf(mongo::BSONObjBuilder& loafBuilder,
const mongo::extension::ExtensionGetNextResult& input) {
// If we got here, we must have a document!
sdk_tassert(10957500, "Failed to get an input document!", input.resultDocument.has_value());
auto bsonObj = input.resultDocument->getUnownedBSONObj();
mongo::BSONObjBuilder toastBuilder;
toastBuilder.appendElements(bsonObj);
// If the predecessor stage is $loaf, then we are dealing with directly nested loaves,
// so use "loaf" instead of "slice". This is pretty meaningless, I just wanted to check that
// the getName() logic works as expected.
std::string keyPrefix = (_getSource().getName() == getName()) ? "loaf" : "slice";
loafBuilder.append(keyPrefix + std::to_string(_currentSlice++), toastBuilder.done());
}
mongo::extension::ExtensionGetNextResult _buildLoafResult(mongo::BSONObjBuilder& loafBuilder,
const std::string& loafType) {
// Only return a loaf if at least one slice has been transformed.
if (_currentSlice > 0 && !_returnedLoaf) {
_returnedLoaf = true;
auto returnedDoc = loafBuilder.done();
return mongo::extension::ExtensionGetNextResult::advanced(
mongo::extension::ExtensionBSONObj::makeAsByteBuf(BSON(loafType << returnedDoc)));
}
return mongo::extension::ExtensionGetNextResult::eof();
}
};
DEFAULT_LOGICAL_STAGE(Toast); DEFAULT_LOGICAL_STAGE(Toast);
/**
 * Logical stage for $loaf. Its distributed plan logic places a copy of this stage in the merging
 * pipeline.
 */
class LoafLogicalStage : public sdk::TestLogicalStage<LoafExecStage> {
public:
    LoafLogicalStage(std::string_view stageName, const mongo::BSONObj& arguments)
        : sdk::TestLogicalStage<LoafExecStage>(stageName, arguments) {}

    // Returns a new LoafLogicalStage built from this stage's name and arguments.
    std::unique_ptr<sdk::LogicalAggStage> clone() const {
        return std::make_unique<LoafLogicalStage>(_name, _arguments);
    }

    boost::optional<sdk::DistributedPlanLogic> getDistributedPlanLogic() const override {
        // This stage must run on the merging node, so the merging pipeline consists of a clone of
        // this stage. Only mergingPipeline is set on the returned plan logic.
        std::vector<mongo::extension::VariantDPLHandle> mergingStages;
        mergingStages.emplace_back(mongo::extension::LogicalAggStageHandle{
            new sdk::ExtensionLogicalAggStage(clone())});

        sdk::DistributedPlanLogic planLogic;
        planLogic.mergingPipeline = sdk::DPLArrayContainer(std::move(mergingStages));
        return planLogic;
    }
};
class ToastAstNode : public sdk::TestAstNode<ToastLogicalStage> { class ToastAstNode : public sdk::TestAstNode<ToastLogicalStage> {
public: public:
ToastAstNode(std::string_view stageName, const mongo::BSONObj& arguments) ToastAstNode(std::string_view stageName, const mongo::BSONObj& arguments)
@ -103,7 +204,25 @@ public:
} }
}; };
DEFAULT_AST_NODE(Loaf);
DEFAULT_PARSE_NODE(Toast); DEFAULT_PARSE_NODE(Toast);
DEFAULT_PARSE_NODE(Loaf);
/**
 * $loaf is a transform stage that takes a number of slices, like {$loaf: {numSlices: 5}}.
 * This stage processes N documents at a time and returns them where N <= numSlices.
 *
 * numSlices is optional as far as validation is concerned (LoafExecStage falls back to 1 when it
 * is absent); when present it must be a number >= 0.
 */
class LoafStageDescriptor : public sdk::TestStageDescriptor<"$loaf", LoafParseNode> {
public:
    void validate(const mongo::BSONObj& arguments) const override {
        const auto numSlices = arguments["numSlices"];
        if (!numSlices) {
            // Nothing to check when the argument is not supplied.
            return;
        }
        sdk_uassert(10957501,
                    "numSlices must be >= 0",
                    numSlices.isNumber() && numSlices.Number() >= 0);
    }
};
/** /**
* $toast is a source stage that requires a temperature and number of slices, like {$toast: {temp: * $toast is a source stage that requires a temperature and number of slices, like {$toast: {temp:
@ -148,6 +267,8 @@ public:
// Always register $toast. // Always register $toast.
_registerStage<ToastStageDescriptor>(portal); _registerStage<ToastStageDescriptor>(portal);
// Always register $loaf.
_registerStage<LoafStageDescriptor>(portal);
// Only register $toastBagel if allowBagels is true. // Only register $toastBagel if allowBagels is true.
if (ToasterOptions::allowBagels) { if (ToasterOptions::allowBagels) {