diff --git a/jstests/aggregation/sbe/sbe_sort_extract_field_paths_non_bson_object.js b/jstests/aggregation/sbe/sbe_sort_extract_field_paths_non_bson_object.js new file mode 100644 index 00000000000..395b41a6a59 --- /dev/null +++ b/jstests/aggregation/sbe/sbe_sort_extract_field_paths_non_bson_object.js @@ -0,0 +1,13 @@ +/** + * When run in SBE, extract_field_paths iterates over the Object + * typeTag for `m` to calculate `$m.m1`. Tests we don't tassert. + */ +assert.commandWorked(db.c.createIndex({"m.m1": 1})); +assert.commandWorked( + db.c.insertOne({ + "m": {"m1": NumberInt(0), "m2": NumberInt(0)}, + }), +); +assert.eq(db.c.aggregate([{"$sort": {"m.m1": 1}}, {"$group": {"_id": null, "a": {"$min": "$m.m1"}}}]).toArray(), [ + {"_id": null, "a": NumberInt(0)}, +]); diff --git a/src/mongo/db/exec/sbe/stages/extract_field_paths.cpp b/src/mongo/db/exec/sbe/stages/extract_field_paths.cpp index bfdb5bba213..456ed7945d0 100644 --- a/src/mongo/db/exec/sbe/stages/extract_field_paths.cpp +++ b/src/mongo/db/exec/sbe/stages/extract_field_paths.cpp @@ -132,8 +132,15 @@ PlanState ExtractFieldPathsStage::getNext() { if (_root->inputAccessor) { // Should only be used for unit tests. auto [inputTag, inputVal] = _root->inputAccessor->getViewOfValue(); - value::walkObj( - _root.get(), inputTag, inputVal, value::bitcastTo(inputVal), walk); + + if (value::TypeTags::bsonObject == inputTag) { + value::walkBsonObj( + _root.get(), inputVal, value::bitcastTo(inputVal), walk); + } else if (value::TypeTags::Object == inputTag) { + value::walkObject( + _root.get(), inputVal, walk); + } + } else { // Important this is only for toplevel fields. For nested fields, we would need knowledge of // arrayness. We would also need to check for input accessors during the tree traversal. diff --git a/src/mongo/db/exec/sbe/values/bson_block.cpp b/src/mongo/db/exec/sbe/values/bson_block.cpp index 4e40ec23667..7890ea3575a 100644 --- a/src/mongo/db/exec/sbe/values/bson_block.cpp +++ b/src/mongo/db/exec/sbe/values/bson_block.cpp @@ -142,11 +142,10 @@ std::vector> BSONExtractorImpl::extractFromBsons( rec.newDoc(); } - walkObj(&_root, - TypeTags::bsonObject, - bitcastFrom(obj.objdata()), - obj.objdata(), - visitElementExtractorCallback); + walkBsonObj(&_root, + bitcastFrom(obj.objdata()), + obj.objdata(), + visitElementExtractorCallback); for (auto& rec : _filterPositionInfoRecorders) { rec.endDoc(); @@ -221,11 +220,10 @@ std::vector extractValuePointersFromBson(BSONObj& obj, } }; - walkObj(extractor.getRoot(), - TypeTags::bsonObject, - bitcastFrom(obj.objdata()), - obj.objdata(), - recordValuePointer); + walkBsonObj(extractor.getRoot(), + bitcastFrom(obj.objdata()), + obj.objdata(), + recordValuePointer); return bsonPointers; } } // namespace mongo::sbe::value diff --git a/src/mongo/db/exec/sbe/values/object_walk_node.h b/src/mongo/db/exec/sbe/values/object_walk_node.h index f4f4c887cc8..dc6236587cc 100644 --- a/src/mongo/db/exec/sbe/values/object_walk_node.h +++ b/src/mongo/db/exec/sbe/values/object_walk_node.h @@ -257,27 +257,42 @@ void walkField(ObjectWalkNode* node, template requires std::invocable*, TypeTags, Value, const char*> -void walkObj(ObjectWalkNode* node, - value::TypeTags inputTag, - value::Value inputVal, - const char* bsonPtr, - const Cb& cb) { +void walkBsonObj(ObjectWalkNode* node, + value::Value inputVal, + const char* bsonPtr, + const Cb& cb) { size_t numChildrenWalked = 0; - auto callback = [&](StringData currFieldName, - value::TypeTags tag, - value::Value val, - const char* cur) -> bool { - if (numChildrenWalked >= node->getChildren.size()) { - // Early exit because we've walked every child for this node. - return true; - } - if (auto it = node->getChildren.find(currFieldName); it != node->getChildren.end()) { - walkField(it->second.get(), tag, val, cur, cb); + auto bson = value::getRawPointerView(inputVal); + const auto end = bson::bsonEnd(bson); + + // Skip document length. + const char* be = bson + 4; + while (numChildrenWalked < node->getChildren.size() && be != end - 1) { + auto fieldName = bson::fieldNameAndLength(be); + if (auto it = node->getChildren.find(fieldName); it != node->getChildren.end()) { + auto [eltTag, eltVal] = bson::convertFrom(be, end, fieldName.size()); + walkField(it->second.get(), eltTag, eltVal, be, cb); numChildrenWalked++; } - return false; - }; - value::objectForEach(inputTag, inputVal, callback); + be = bson::advance(be, fieldName.size()); + } +} + +template +void walkObject(ObjectWalkNode* node, value::Value inputVal, const Cb& cb) { + size_t numChildrenWalked = 0; + auto obj = getObjectView(inputVal); + + size_t i = 0; + while (numChildrenWalked < node->getChildren.size() && i < obj->size()) { + if (auto it = node->getChildren.find(obj->field(i)); it != node->getChildren.end()) { + auto [eltTag, eltVal] = obj->getAt(i); + walkField( + it->second.get(), eltTag, eltVal, nullptr /*bsonPtr*/, cb); + numChildrenWalked++; + } + i++; + } } template @@ -287,13 +302,12 @@ void walkField(ObjectWalkNode* node, Value eltVal, const char* bsonPtr, const Cb& cb) { - if (value::isObject(eltTag)) { - walkObj(node, eltTag, eltVal, bsonPtr, cb); - if (node->traverseChild) { - walkField( - node->traverseChild.get(), eltTag, eltVal, bsonPtr, cb); - } - } else if (value::isArray(eltTag)) { + if (value::TypeTags::bsonObject == eltTag) { + walkBsonObj(node, eltVal, bsonPtr, cb); + } else if (value::TypeTags::Object == eltTag) { + walkObject(node, eltVal, cb); + } + if (value::isArray(eltTag)) { if (node->traverseChild) { // The projection traversal semantics are "special" in that the leaf must know // when there is an array higher up in the tree. @@ -322,7 +336,7 @@ void walkField(ObjectWalkNode* node, } } } else if (node->traverseChild) { - // We didn't see an array, so we apply the node below the traverse to this scalar. + // We didn't see an array, so we apply the node below the traverse. walkField(node->traverseChild.get(), eltTag, eltVal, bsonPtr, cb); } // Some callbacks use the raw bson pointer, not just the tag and value.