SERVER-108008: Simplify $lookup compilation in SBE (#41258)

GitOrigin-RevId: 93becbbc1c9ed62ed9ba75b05d3b7ae7744d3817
This commit is contained in:
Alberto Massari 2025-09-24 14:28:06 +02:00 committed by MongoDB Bot
parent ff87ad5f2a
commit 9aff7fb103
44 changed files with 914 additions and 597 deletions

View File

@ -380,7 +380,7 @@ class DatabaseNamePrinter(object):
tenant = data[0] & TENANT_ID_MASK
if tenant:
return f"{extract_tenant_id(data)}_{data[1+OBJECT_ID_WIDTH:].decode()}"
return f"{extract_tenant_id(data)}_{data[1 + OBJECT_ID_WIDTH :].decode()}"
else:
return data[1:].decode()
@ -1027,7 +1027,12 @@ class SbeCodeFragmentPrinter(object):
# Some instructions have extra arguments, embedded into the ops stream.
args = ""
if op_name in ["pushLocalVal", "pushMoveLocalVal", "pushLocalLambda"]:
if op_name in [
"pushLocalVal",
"pushMoveLocalVal",
"pushOneArgLambda",
"pushTwoArgLambda",
]:
args = "arg: " + str(read_as_integer(cur_op, int_size))
cur_op += int_size
elif op_name in ["jmp", "jmpTrue", "jmpFalse", "jmpNothing", "jmpNotNothing"]:
@ -1068,10 +1073,14 @@ class SbeCodeFragmentPrinter(object):
args = "Instruction::Constants: " + str(read_as_integer(cur_op, uint8_size))
cur_op += uint8_size
elif op_name in ["traverseFImm", "traversePImm"]:
position = read_as_integer(cur_op, uint8_size)
cur_op += uint8_size
const_enum = read_as_integer(cur_op, uint8_size)
cur_op += uint8_size
args = (
"Instruction::Constants: "
"providePosition: "
+ str(position)
+ ", Instruction::Constants: "
+ str(const_enum)
+ ", offset: "
+ str(read_as_integer_signed(cur_op, int_size))

View File

@ -286,15 +286,12 @@ class LambdaAbstractionPrinter(FixedArityNodePrinter):
super().__init__(val, 1, "LambdaAbstraction")
def to_string(self):
return "LambdaAbstraction[{}]".format(self.val["_varName"])
class LambdaApplicationPrinter(FixedArityNodePrinter):
"""Pretty-printer for LambdaApplication."""
def __init__(self, val):
"""Initialize LambdaApplicationPrinter."""
super().__init__(val, 2, "LambdaApplication")
res = "LambdaAbstraction[{"
bindings = Vector(self.val["_varNames"])
for name in bindings:
res += str(name) + " "
res += "}]"
return res
class SourcePrinter(FixedArityNodePrinter):
@ -408,7 +405,6 @@ def register_optimizer_printers(pp):
"Let",
"MultiLet",
"LambdaAbstraction",
"LambdaApplication",
"FunctionCall",
"Source",
"Switch",

View File

@ -119,9 +119,9 @@ runTests({
})();
/**
* Other miscelaneous tests for INLJ.
* Other miscellaneous tests for INLJ.
*/
(function runMiscelaneousInljTests() {
(function runMiscellaneousInljTests() {
const testConfig = {
localColl: db.lookup_arrays_semantics_local_inlj,
foreignColl: db.lookup_arrays_semantics_foreign_inlj,

View File

@ -74,12 +74,13 @@ overview of the different EExpression types:
Provides the ability to define multiple variables in a local scope. They are particularly useful
when we want to reference some intermediate value multiple times.
- [ELocalLambda](https://github.com/mongodb/mongo/blob/06a931ffadd7ce62c32288d03e5a38933bd522d3/src/mongo/db/exec/sbe/expressions/expression.h#L487-L507)
Represents an anonymous function which takes a single input parameter. Many `EFunctions` accept
Represents an anonymous function which takes one or two input parameters. Many `EFunctions` accept
these as parameters. A good example of this is the [`traverseF`
function](https://github.com/mongodb/mongo/blob/06a931ffadd7ce62c32288d03e5a38933bd522d3/src/mongo/db/exec/sbe/vm/vm.cpp#L1329-L1357):
it accepts 2 parameters: an input and an `ELocalLambda`. If the input is an array, the
`ELocalLambda` is applied to each element in the array, otherwise, it is applied to the input on
its own.
its own. The second argument of the lambda receives the 0-based position of the element being examined;
when the `ELocalLambda` is being applied to the entire input, the second argument will have a value of -1.
EExpressions cannot be executed directly. Rather, [they are
compiled](https://github.com/mongodb/mongo/blob/06a931ffadd7ce62c32288d03e5a38933bd522d3/src/mongo/db/exec/sbe/expressions/expression.h#L81-L84)

View File

@ -31,7 +31,6 @@
#include "mongo/bson/ordering.h"
#include "mongo/db/exec/sbe/expressions/compile_ctx.h"
#include "mongo/db/exec/sbe/expressions/runtime_environment.h"
#include "mongo/db/exec/sbe/size_estimator.h"
#include "mongo/db/exec/sbe/stages/stages.h"
#include "mongo/db/exec/sbe/util/print_options.h"
@ -41,10 +40,10 @@
#include "mongo/db/exec/sbe/vm/vm_instruction.h"
#include "mongo/db/exec/sbe/vm/vm_types.h"
#include "mongo/db/query/datetime/date_time_support.h"
#include "mongo/stdx/unordered_map.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/str.h"
#include <algorithm>
#include <functional>
#include <sstream>
#include <vector>
@ -1263,6 +1262,7 @@ vm::CodeFragment generateTraverseP(CompileCtx& ctx, const EExpression::Vector& n
code.appendLabel(afterBodyLabel);
code.append(nodes[0]->compileDirect(ctx));
code.appendTraverseP(bodyPosition,
lambda->numArguments(),
tag == value::TypeTags::Nothing ? vm::Instruction::Nothing
: vm::Instruction::Int32One);
return code;
@ -1291,6 +1291,7 @@ vm::CodeFragment generateTraverseF(CompileCtx& ctx, const EExpression::Vector& n
code.appendLabel(afterBodyLabel);
code.append(nodes[0]->compileDirect(ctx));
code.appendTraverseF(bodyPosition,
lambda->numArguments(),
value::bitcastTo<bool>(val) ? vm::Instruction::True
: vm::Instruction::False);
return code;
@ -1707,18 +1708,17 @@ size_t ELocalBind::estimateSize() const {
}
std::unique_ptr<EExpression> ELocalLambda::clone() const {
return std::make_unique<ELocalLambda>(_frameId, _nodes.back()->clone());
return std::make_unique<ELocalLambda>(_frameId, _nodes.back()->clone(), _numArguments);
}
vm::CodeFragment ELocalLambda::compileBodyDirect(CompileCtx& ctx) const {
// Compile the body first so we know its size.
auto inner = _nodes.back()->compileDirect(ctx);
vm::CodeFragment body;
// Declare the frame containing lambda variable.
// The variable is expected to be already on the stack so declare the frame just below the
// The variables are expected to be already on the stack so declare the frame just below the
// current top of the stack.
body.declareFrame(_frameId, -1);
body.declareFrame(_frameId, -(int)_numArguments);
// Make sure the stack is sufficiently large.
body.appendAllocStack(inner.maxStackSize());
@ -1750,7 +1750,7 @@ vm::CodeFragment ELocalLambda::compileDirect(CompileCtx& ctx) const {
// Push the lambda value on the stack
code.appendLabel(afterBodyLabel);
code.appendLocalLambda(bodyPosition);
code.appendLocalLambda(bodyPosition, _numArguments);
return code;
});
@ -1761,7 +1761,9 @@ std::vector<DebugPrinter::Block> ELocalLambda::debugPrint() const {
DebugPrinter::addKeyword(ret, "lambda");
ret.emplace_back("`(`");
DebugPrinter::addIdentifier(ret, _frameId, 0);
for (size_t i = 0; i < _numArguments; i++) {
DebugPrinter::addIdentifier(ret, _frameId, i);
}
ret.emplace_back("`)");
ret.emplace_back("{");
DebugPrinter::addBlocks(ret, _nodes.back()->debugPrint());

View File

@ -599,13 +599,17 @@ private:
*/
class ELocalLambda final : public EExpression {
public:
ELocalLambda(FrameId frameId, std::unique_ptr<EExpression> body) : _frameId(frameId) {
ELocalLambda(FrameId frameId, std::unique_ptr<EExpression> body, size_t numArguments = 1)
: _frameId(frameId), _numArguments(numArguments) {
_nodes.emplace_back(std::move(body));
validateNodes();
}
std::unique_ptr<EExpression> clone() const override;
size_t numArguments() const {
return _numArguments;
}
vm::CodeFragment compileDirect(CompileCtx& ctx) const override;
vm::CodeFragment compileBodyDirect(CompileCtx& ctx) const;
std::vector<DebugPrinter::Block> debugPrint() const override;
@ -614,6 +618,7 @@ public:
private:
FrameId _frameId;
size_t _numArguments;
};
/**

View File

@ -27,9 +27,6 @@
* it in the license file.
*/
#include "mongo/base/string_data.h"
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/db/exec/sbe/expression_test_base.h"
#include "mongo/db/exec/sbe/expressions/expression.h"
#include "mongo/db/exec/sbe/values/slot.h"
@ -37,7 +34,6 @@
#include "mongo/unittest/golden_test.h"
#include "mongo/unittest/unittest.h"
#include <memory>
#include <utility>
namespace mongo::sbe {
@ -69,6 +65,37 @@ TEST_F(SBELambdaTest, TraverseP_AddOneToArray) {
executeAndPrintVariation(os, *compiledExpr);
}
TEST_F(SBELambdaTest, TraverseP_AddOneToFirstArrayItem) {
auto& os = gctx->outStream();
value::ViewOfValueAccessor slotAccessor;
auto argSlot = bindAccessor(&slotAccessor);
FrameId frame = 10;
auto expr = sbe::makeE<sbe::EFunction>(
"traverseP",
sbe::makeEs(makeE<EVariable>(argSlot),
makeE<ELocalLambda>(frame,
makeE<EIf>(makeE<EPrimBinary>(EPrimBinary::Op::eq,
makeE<EVariable>(frame, 1),
makeC(makeInt64(0))),
makeE<EPrimBinary>(EPrimBinary::Op::add,
makeE<EVariable>(frame, 0),
makeC(makeInt32(1))),
makeE<EVariable>(frame, 0)),
2),
makeC(makeNothing())));
printInputExpression(os, *expr);
auto compiledExpr = compileExpression(*expr);
printCompiledExpression(os, *compiledExpr);
auto bsonArr = BSON_ARRAY(1 << 2 << 3);
slotAccessor.reset(value::TypeTags::bsonArray,
value::bitcastFrom<const char*>(bsonArr.objdata()));
executeAndPrintVariation(os, *compiledExpr);
}
TEST_F(SBELambdaTest, TraverseF_OpEq) {
auto& os = gctx->outStream();
@ -95,6 +122,37 @@ TEST_F(SBELambdaTest, TraverseF_OpEq) {
executeAndPrintVariation(os, *compiledExpr);
}
TEST_F(SBELambdaTest, TraverseF_OpEqFirstArrayItem) {
auto& os = gctx->outStream();
value::ViewOfValueAccessor slotAccessor;
auto argSlot = bindAccessor(&slotAccessor);
FrameId frame = 10;
auto expr = sbe::makeE<sbe::EFunction>(
"traverseF",
sbe::makeEs(
makeE<EVariable>(argSlot),
makeE<ELocalLambda>(frame,
makeE<EPrimBinary>(EPrimBinary::Op::logicAnd,
makeE<EPrimBinary>(EPrimBinary::Op::eq,
makeE<EVariable>(frame, 1),
makeC(makeInt64(0))),
makeE<EPrimBinary>(EPrimBinary::Op::eq,
makeE<EVariable>(frame, 0),
makeC(makeInt32(3)))),
2),
makeC(makeNothing())));
printInputExpression(os, *expr);
auto compiledExpr = compileExpression(*expr);
printCompiledExpression(os, *compiledExpr);
auto bsonArr = BSON_ARRAY(1 << 2 << 3 << 4);
slotAccessor.reset(value::TypeTags::bsonArray,
value::bitcastFrom<const char*>(bsonArr.objdata()));
executeAndPrintVariation(os, *compiledExpr);
}
TEST_F(SBELambdaTest, TraverseF_WithLocalBind) {
auto& os = gctx->outStream();

View File

@ -507,9 +507,9 @@ TEST(SBEVM, CodeFragmentPrintStable) {
code.appendFillEmpty(vm::Instruction::Null);
code.appendFillEmpty(vm::Instruction::False);
code.appendFillEmpty(vm::Instruction::True);
code.appendTraverseP(0xAA, vm::Instruction::Nothing);
code.appendTraverseP(0xAA, vm::Instruction::Int32One);
code.appendTraverseF(0xBB, vm::Instruction::True);
code.appendTraverseP(0xAA, 1, vm::Instruction::Nothing);
code.appendTraverseP(0xAA, 1, vm::Instruction::Int32One);
code.appendTraverseF(0xBB, 1, vm::Instruction::True);
code.appendGetField({}, "Hello world!"_sd);
code.appendAdd({}, {});

View File

@ -729,7 +729,8 @@ int getApproximateSize(TypeTags tag, Value val) {
case TypeTags::MinKey:
case TypeTags::MaxKey:
case TypeTags::bsonUndefined:
case TypeTags::LocalLambda:
case TypeTags::LocalOneArgLambda:
case TypeTags::LocalTwoArgLambda:
break;
// There are deep types.
case TypeTags::RecordId:

View File

@ -197,7 +197,8 @@ enum class TypeTags : uint8_t {
bsonCodeWScope,
// Local lambda value
LocalLambda,
LocalOneArgLambda,
LocalTwoArgLambda,
// The index key string.
keyString,

View File

@ -141,8 +141,11 @@ void ValuePrinter<T>::writeTagToStream(TypeTags tag) {
case TypeTags::bsonBinData:
stream << "bsonBinData";
break;
case TypeTags::LocalLambda:
stream << "LocalLambda";
case TypeTags::LocalOneArgLambda:
stream << "LocalOneArgLambda";
break;
case TypeTags::LocalTwoArgLambda:
stream << "LocalTwoArgLambda";
break;
case TypeTags::bsonUndefined:
stream << "bsonUndefined";
@ -534,8 +537,11 @@ void ValuePrinter<T>::writeValueToStream(TypeTags tag, Value val, size_t depth)
case TypeTags::bsonUndefined:
stream << "undefined";
break;
case TypeTags::LocalLambda:
stream << "LocalLambda";
case TypeTags::LocalOneArgLambda:
stream << "LocalOneArgLambda";
break;
case TypeTags::LocalTwoArgLambda:
stream << "LocalTwoArgLambda";
break;
case TypeTags::keyString: {
auto ks = getKeyString(val);

View File

@ -409,7 +409,7 @@ void CodeFragment::appendLocalVal(FrameId frameId, int variable, bool moveFrom)
// Compute the absolute variable stack offset based on the current stack depth
int stackOffset = varToOffset(variable) + _stackSize;
// If frame has stackPositiion defined, then compute the final relative stack offset.
// If frame has stackPosition defined, then compute the final relative stack offset.
// Otherwise, register a fixup to compute the relative stack offset later.
if (frame.stackPosition != FrameInfo::kPositionNotSet) {
stackOffset -= frame.stackPosition;
@ -426,9 +426,10 @@ void CodeFragment::appendLocalVal(FrameId frameId, int variable, bool moveFrom)
adjustStackSimple(i);
}
void CodeFragment::appendLocalLambda(int codePosition) {
void CodeFragment::appendLocalLambda(int codePosition, size_t numArgs) {
invariant(numArgs == 1 || numArgs == 2);
Instruction i;
i.tag = Instruction::pushLocalLambda;
i.tag = numArgs == 1 ? Instruction::pushOneArgLambda : Instruction::pushTwoArgLambda;
auto size = sizeof(Instruction) + sizeof(codePosition);
auto offset = allocateSpace(size);
@ -744,16 +745,19 @@ void CodeFragment::appendTraverseP() {
appendSimpleInstruction(Instruction::traverseP);
}
void CodeFragment::appendTraverseP(int codePosition, Instruction::Constants k) {
void CodeFragment::appendTraverseP(int codePosition, size_t numArgs, Instruction::Constants k) {
Instruction i;
i.tag = Instruction::traversePImm;
auto size = sizeof(Instruction) + sizeof(codePosition) + sizeof(k);
auto size =
sizeof(Instruction) + sizeof(codePosition) + sizeof(Instruction::Constants) + sizeof(k);
auto offset = allocateSpace(size);
int codeOffset = codePosition - _instrs.size();
offset += writeToMemory(offset, i);
offset += writeToMemory(
offset, numArgs == 2 ? Instruction::Constants::True : Instruction::Constants::False);
offset += writeToMemory(offset, k);
offset += writeToMemory(offset, codeOffset);
@ -767,16 +771,19 @@ void CodeFragment::appendTraverseF() {
appendSimpleInstruction(Instruction::traverseF);
}
void CodeFragment::appendTraverseF(int codePosition, Instruction::Constants k) {
void CodeFragment::appendTraverseF(int codePosition, size_t numArgs, Instruction::Constants k) {
Instruction i;
i.tag = Instruction::traverseFImm;
auto size = sizeof(Instruction) + sizeof(codePosition) + sizeof(k);
auto size =
sizeof(Instruction) + sizeof(codePosition) + sizeof(Instruction::Constants) + sizeof(k);
auto offset = allocateSpace(size);
int codeOffset = codePosition - _instrs.size();
offset += writeToMemory(offset, i);
offset += writeToMemory(
offset, numArgs == 2 ? Instruction::Constants::True : Instruction::Constants::False);
offset += writeToMemory(offset, k);
offset += writeToMemory(offset, codeOffset);

View File

@ -70,7 +70,7 @@ public:
void appendAccessVal(value::SlotAccessor* accessor);
void appendMoveVal(value::SlotAccessor* accessor);
void appendLocalVal(FrameId frameId, int variable, bool moveFrom);
void appendLocalLambda(int codePosition);
void appendLocalLambda(int codePosition, size_t numArgs);
void appendPop();
void appendSwap();
void appendMakeOwn(Instruction::Parameter arg);
@ -126,9 +126,9 @@ public:
void appendCollComparisonKey(Instruction::Parameter lhs, Instruction::Parameter rhs);
void appendGetFieldOrElement(Instruction::Parameter lhs, Instruction::Parameter rhs);
void appendTraverseP();
void appendTraverseP(int codePosition, Instruction::Constants k);
void appendTraverseP(int codePosition, size_t numArgs, Instruction::Constants k);
void appendTraverseF();
void appendTraverseF(int codePosition, Instruction::Constants k);
void appendTraverseF(int codePosition, size_t numArgs, Instruction::Constants k);
void appendMagicTraverseF();
void appendSetField();
void appendGetArraySize(Instruction::Parameter input);

View File

@ -220,7 +220,8 @@ void ByteCode::traverseP(const CodeFragment* code) {
popAndReleaseStack();
if ((maxDepthTag != value::TypeTags::Nothing && maxDepthTag != value::TypeTags::NumberInt32) ||
lamTag != value::TypeTags::LocalLambda) {
(lamTag != value::TypeTags::LocalOneArgLambda &&
lamTag != value::TypeTags::LocalTwoArgLambda)) {
popAndReleaseStack();
pushStack(false, value::TypeTags::Nothing, 0);
return;
@ -231,10 +232,13 @@ void ByteCode::traverseP(const CodeFragment* code) {
? value::bitcastTo<int32_t>(maxDepthVal)
: std::numeric_limits<int64_t>::max();
traverseP(code, lamPos, maxDepth);
traverseP(code, lamPos, lamTag == value::TypeTags::LocalTwoArgLambda, maxDepth);
}
void ByteCode::traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth) {
void ByteCode::traverseP(const CodeFragment* code,
int64_t position,
bool providePosition,
int64_t maxDepth) {
auto [own, tag, val] = getFromStack(0);
if (value::isArray(tag) && maxDepth > 0) {
@ -245,9 +249,19 @@ void ByteCode::traverseP(const CodeFragment* code, int64_t position, int64_t max
--maxDepth;
}
traverseP_nested(code, position, tag, val, maxDepth);
traverseP_nested(code, position, tag, val, providePosition, maxDepth, 0);
} else {
if (providePosition) {
// Push a -1 on the stack (to indicate that we are not iterating over an array) before
// the value to be processed
pushStack(false, value::TypeTags::NumberInt64, value::bitcastTo<int64_t>(-1));
}
runLambdaInternal(code, position);
if (providePosition) {
// Remove the position from the stack.
swapStack();
popStack();
}
}
}
@ -255,7 +269,9 @@ void ByteCode::traverseP_nested(const CodeFragment* code,
int64_t position,
value::TypeTags tagInput,
value::Value valInput,
int64_t maxDepth) {
bool providePosition,
int64_t maxDepth,
int64_t curDepth) {
auto decrement = [](int64_t d) {
return d == std::numeric_limits<int64_t>::max() ? d : d - 1;
};
@ -263,14 +279,31 @@ void ByteCode::traverseP_nested(const CodeFragment* code,
auto [tagArrOutput, valArrOutput] = value::makeNewArray();
auto arrOutput = value::getArrayView(valArrOutput);
value::ValueGuard guard{tagArrOutput, valArrOutput};
// The array position we will be sending to the lambda holds the depth in the highest 32 bit,
// and the actual index in the current array in the lowest 32 bit.
size_t arrayPos = curDepth << 32;
value::arrayForEach(tagInput, valInput, [&](value::TypeTags elemTag, value::Value elemVal) {
if (maxDepth > 0 && value::isArray(elemTag)) {
traverseP_nested(code, position, elemTag, elemVal, decrement(maxDepth));
traverseP_nested(code,
position,
elemTag,
elemVal,
providePosition,
decrement(maxDepth),
curDepth + 1);
} else {
pushStack(false, elemTag, elemVal);
if (providePosition) {
pushStack(false, value::TypeTags::NumberInt64, value::bitcastTo<int64_t>(arrayPos));
}
runLambdaInternal(code, position);
if (providePosition) {
// Remove the position from the stack.
swapStack();
popStack();
}
}
arrayPos++;
auto [retOwn, retTag, retVal] = getFromStack(0);
popStack();
@ -377,7 +410,8 @@ void ByteCode::traverseF(const CodeFragment* code) {
auto [lamOwn, lamTag, lamVal] = getFromStack(0);
popAndReleaseStack();
if (lamTag != value::TypeTags::LocalLambda) {
if (lamTag != value::TypeTags::LocalOneArgLambda &&
lamTag != value::TypeTags::LocalTwoArgLambda) {
popAndReleaseStack();
pushStack(false, value::TypeTags::Nothing, 0);
return;
@ -386,16 +420,29 @@ void ByteCode::traverseF(const CodeFragment* code) {
bool compareArray = numberTag == value::TypeTags::Boolean && value::bitcastTo<bool>(numberVal);
traverseF(code, lamPos, compareArray);
traverseF(code, lamPos, lamTag == value::TypeTags::LocalTwoArgLambda, compareArray);
}
void ByteCode::traverseF(const CodeFragment* code, int64_t position, bool compareArray) {
void ByteCode::traverseF(const CodeFragment* code,
int64_t position,
bool providePosition,
bool compareArray) {
auto [ownInput, tagInput, valInput] = getFromStack(0);
if (value::isArray(tagInput)) {
traverseFInArray(code, position, compareArray);
traverseFInArray(code, position, providePosition, compareArray);
} else {
// Push a -1 on the stack (to indicate that we are not iterating over an array) before the
// value to be processed
if (providePosition) {
pushStack(false, value::TypeTags::NumberInt64, value::bitcastTo<int64_t>(-1));
}
runLambdaInternal(code, position);
// Remove the item position from the stack.
if (providePosition) {
swapStack();
popStack();
}
}
}
@ -411,19 +458,33 @@ bool ByteCode::runLambdaPredicate(const CodeFragment* code, int64_t position) {
return isTrue;
}
void ByteCode::traverseFInArray(const CodeFragment* code, int64_t position, bool compareArray) {
void ByteCode::traverseFInArray(const CodeFragment* code,
int64_t position,
bool providePosition,
bool compareArray) {
auto [ownInput, tagInput, valInput] = getFromStack(0);
value::ValueGuard input(ownInput, tagInput, valInput);
popStack();
size_t arrayPos = 0;
const bool passed =
value::arrayAny(tagInput, valInput, [&](value::TypeTags tag, value::Value val) {
pushStack(false, tag, val);
if (providePosition) {
pushStack(
false, value::TypeTags::NumberInt64, value::bitcastTo<int64_t>(arrayPos++));
}
if (runLambdaPredicate(code, position)) {
if (providePosition) {
popStack();
}
pushStack(false, value::TypeTags::Boolean, value::bitcastFrom<bool>(true));
return true;
}
if (providePosition) {
popStack();
}
return false;
});
@ -436,8 +497,15 @@ void ByteCode::traverseFInArray(const CodeFragment* code, int64_t position, bool
if (compareArray) {
// Transfer the ownership to the lambda
pushStack(ownInput, tagInput, valInput);
if (providePosition) {
pushStack(false, value::TypeTags::NumberInt64, value::bitcastTo<int64_t>(-1));
}
input.reset();
runLambdaInternal(code, position);
if (providePosition) {
swapStack();
popStack();
}
return;
}

View File

@ -559,16 +559,27 @@ private:
value::Value fieldValue);
void traverseP(const CodeFragment* code);
void traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth);
void traverseP(const CodeFragment* code,
int64_t position,
bool providePosition,
int64_t maxDepth);
void traverseP_nested(const CodeFragment* code,
int64_t position,
value::TypeTags tag,
value::Value val,
int64_t maxDepth);
bool providePosition,
int64_t maxDepth,
int64_t curDepth);
void traverseF(const CodeFragment* code);
void traverseF(const CodeFragment* code, int64_t position, bool compareArray);
void traverseFInArray(const CodeFragment* code, int64_t position, bool compareArray);
void traverseF(const CodeFragment* code,
int64_t position,
bool providePosition,
bool compareArray);
void traverseFInArray(const CodeFragment* code,
int64_t position,
bool providePosition,
bool compareArray);
void magicTraverseF(const CodeFragment* code);
bool runLambdaPredicate(const CodeFragment* code, int64_t position);

View File

@ -2178,7 +2178,7 @@ void ByteCode::valueBlockApplyLambda(const CodeFragment* code) {
popAndReleaseStack();
value::ValueGuard maskGuard(maskOwn, maskTag, maskVal);
if (lamTag != value::TypeTags::LocalLambda) {
if (lamTag != value::TypeTags::LocalOneArgLambda) {
pushStack(false, value::TypeTags::Nothing, 0);
return;
}

View File

@ -86,7 +86,8 @@ int Instruction::stackOffset[Instruction::Tags::lastInstruction] = {
1, // pushMoveVal
1, // pushLocalVal
1, // pushMoveLocalVal
1, // pushLocalLambda
1, // pushOneArgLambda
1, // pushTwoArgLambda
-1, // pop
0, // swap
0, // makeOwn
@ -322,13 +323,22 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
pushStack(owned, tag, val);
break;
}
case Instruction::pushLocalLambda: {
case Instruction::pushOneArgLambda: {
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
auto newPosition = pcPointer - code->instrs().data() + offset;
pushStack(
false, value::TypeTags::LocalLambda, value::bitcastFrom<int64_t>(newPosition));
pushStack(false,
value::TypeTags::LocalOneArgLambda,
value::bitcastFrom<int64_t>(newPosition));
break;
}
case Instruction::pushTwoArgLambda: {
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
auto newPosition = pcPointer - code->instrs().data() + offset;
pushStack(false,
value::TypeTags::LocalTwoArgLambda,
value::bitcastFrom<int64_t>(newPosition));
break;
}
case Instruction::pop: {
@ -965,6 +975,8 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
break;
}
case Instruction::traversePImm: {
auto providePosition = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(providePosition);
auto k = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(k);
@ -974,6 +986,7 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
traverseP(code,
codePosition,
providePosition == Instruction::True ? true : false,
k == Instruction::Nothing ? std::numeric_limits<int64_t>::max() : 1);
break;
@ -983,6 +996,8 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
break;
}
case Instruction::traverseFImm: {
auto providePosition = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(providePosition);
auto k = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(k);
@ -990,7 +1005,10 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
pcPointer += sizeof(offset);
auto codePosition = pcPointer - code->instrs().data() + offset;
traverseF(code, codePosition, k == Instruction::True ? true : false);
traverseF(code,
codePosition,
providePosition == Instruction::True ? true : false,
k == Instruction::True ? true : false);
break;
}
@ -1474,8 +1492,10 @@ const char* Instruction::toString() const {
return "pushLocalVal";
case pushMoveLocalVal:
return "pushMoveLocalVal";
case pushLocalLambda:
return "pushLocalLambda";
case pushOneArgLambda:
return "pushOneArgLambda";
case pushTwoArgLambda:
return "pushTwoArgLambda";
case pop:
return "pop";
case swap:

View File

@ -57,7 +57,8 @@ struct Instruction {
pushMoveVal,
pushLocalVal,
pushMoveLocalVal,
pushLocalLambda,
pushOneArgLambda,
pushTwoArgLambda,
pop,
swap,
makeOwn,

View File

@ -356,7 +356,9 @@ private:
size_t argIdx = lambdaArg.argIdx;
auto [_, lamTag, lamVal] = getArg(argIdx);
tassert(7103506, "Expected arg to be LocalLambda", lamTag == value::TypeTags::LocalLambda);
tassert(7103506,
"Expected arg to be LocalLambda",
lamTag == value::TypeTags::LocalOneArgLambda);
int64_t lamPos = value::bitcastTo<int64_t>(lamVal);
auto [outputOwned, outputTag, outputVal] = invokeLambda(lamPos, tag, val);

View File

@ -254,7 +254,8 @@ public:
break;
}
// Instructions with a single integer argument.
case Instruction::pushLocalLambda: {
case Instruction::pushOneArgLambda:
case Instruction::pushTwoArgLambda: {
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
os << "target: " << _formatter.pcPointer(pcPointer + offset);
@ -280,11 +281,14 @@ public:
// Instructions with other kinds of arguments.
case Instruction::traversePImm:
case Instruction::traverseFImm: {
auto providePosition = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(providePosition);
auto k = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(k);
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
os << "k: " << Instruction::toStringConstants(k)
os << "providePosition: " << Instruction::toStringConstants(providePosition)
<< ", k: " << Instruction::toStringConstants(k)
<< ", target: " << _formatter.pcPointer(pcPointer + offset);
break;
}

View File

@ -912,31 +912,24 @@ public:
const LambdaAbstraction& expr,
ExplainPrinter inResult) {
ExplainPrinter printer("LambdaAbstraction");
printer.separator(" [")
.fieldName("variable", ExplainVersion::V3)
.print(expr.varName())
.separator("]")
printer.separator(" [");
auto numVars = expr.varNames().size();
for (size_t idx = 0; idx < numVars; ++idx) {
std::stringstream ss;
ss << "variable" << idx;
printer.fieldName(ss.str(), ExplainVersion::V3).print(expr.varNames()[idx]);
if (idx < numVars - 1) {
printer.separator(", ");
}
}
printer.separator("]")
.setChildCount(1)
.fieldName("input", ExplainVersion::V3)
.print(inResult);
return printer;
}
ExplainPrinter transport(const ABT::reference_type /*n*/,
const LambdaApplication& expr,
ExplainPrinter lambdaResult,
ExplainPrinter argumentResult) {
ExplainPrinter printer("LambdaApplication");
printer.separator(" []")
.setChildCount(2)
.maybeReverse()
.fieldName("lambda", ExplainVersion::V3)
.print(lambdaResult)
.fieldName("argument", ExplainVersion::V3)
.print(argumentResult);
return printer;
}
ExplainPrinter transport(const ABT::reference_type /*n*/,
const FunctionCall& expr,
std::vector<ExplainPrinter> argResults) {

View File

@ -130,7 +130,9 @@ public:
}
void transport(const LambdaAbstraction& op, const ABT& /*bind*/) {
_variableDefinitionCallback(op.varName());
for (auto& var : op.varNames()) {
_variableDefinitionCallback(var);
}
}
void transport(const Let& op, const ABT& /*bind*/, const ABT& /*expr*/) {
@ -208,7 +210,9 @@ struct Collector {
CollectedInfo result{collectorState};
// resolve any free variables manually.
inResult.resolveFreeVars(lam.varName(), Definition{n.ref(), ABT::reference_type{}});
for (auto& var : lam.varNames()) {
inResult.resolveFreeVars(var, Definition{n.ref(), ABT::reference_type{}});
}
result.merge(std::move(inResult));
return result;
@ -378,7 +382,9 @@ struct LastRefsTransporter {
Result transport(const ABT& n, const LambdaAbstraction& lam, Result inResult) {
// As in the Let case, we can finalize the last ref for the local variable.
finalizeLastRefs(inResult, lam.varName());
for (auto& var : lam.varNames()) {
finalizeLastRefs(inResult, var);
}
return inResult;
}

View File

@ -495,27 +495,34 @@ public:
};
/**
* Represents a single argument lambda. Defines a local variable (representing the argument) which
* can be referenced within the lambda. The variable takes on the value to which LambdaAbstraction
* is applied by its parent.
* Represents a lambda with either one or two input arguments. Defines local variables (representing
* the arguments) which can be referenced within the lambda. The variables take on the values to
* which LambdaAbstraction is applied by its parent.
*/
class LambdaAbstraction final : public ABTOpFixedArity<1>, public ExpressionSyntaxSort {
using Base = ABTOpFixedArity<1>;
ProjectionName _varName;
ProjectionNameVector _varNames;
public:
LambdaAbstraction(ProjectionName var, ABT inBody)
: Base(std::move(inBody)), _varName(std::move(var)) {
LambdaAbstraction(ProjectionName var, ABT inBody) : Base(std::move(inBody)) {
_varNames.emplace_back(std::move(var));
assertExprSort(getBody());
}
LambdaAbstraction(ProjectionName var1, ProjectionName var2, ABT inBody)
: Base(std::move(inBody)) {
_varNames.emplace_back(std::move(var1));
_varNames.emplace_back(std::move(var2));
assertExprSort(getBody());
}
bool operator==(const LambdaAbstraction& other) const {
return _varName == other._varName && getBody() == other.getBody();
return _varNames == other._varNames && getBody() == other.getBody();
}
auto& varName() const {
return _varName;
auto& varNames() const {
return _varNames;
}
const ABT& getBody() const {
@ -527,41 +534,6 @@ public:
}
};
/**
* Evaluates an expression representing a function over an expression representing the argument to
* the function.
*/
class LambdaApplication final : public ABTOpFixedArity<2>, public ExpressionSyntaxSort {
using Base = ABTOpFixedArity<2>;
public:
LambdaApplication(ABT inLambda, ABT inArgument)
: Base(std::move(inLambda), std::move(inArgument)) {
assertExprSort(getLambda());
assertExprSort(getArgument());
}
bool operator==(const LambdaApplication& other) const {
return getLambda() == other.getLambda() && getArgument() == other.getArgument();
}
const ABT& getLambda() const {
return get<0>();
}
ABT& getLambda() {
return get<0>();
}
const ABT& getArgument() const {
return get<1>();
}
ABT& getArgument() {
return get<1>();
}
};
/**
* Dynamic arity operator which passes its children as arguments to a function specified by SBE
* function expression name.

View File

@ -70,7 +70,6 @@ using ABT = algebra::PolyValue<Blackhole,
Let,
MultiLet,
LambdaAbstraction,
LambdaApplication,
FunctionCall,
Source,
Switch,

View File

@ -44,7 +44,6 @@ class If;
class Let;
class MultiLet;
class LambdaAbstraction;
class LambdaApplication;
class FunctionCall;
class Source;
class Switch;

View File

@ -153,16 +153,7 @@ std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
auto frameId = it->second;
_lambdaMap.erase(it);
return sbe::makeE<sbe::ELocalLambda>(frameId, std::move(body));
}
std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
const LambdaApplication&,
std::unique_ptr<sbe::EExpression> lam,
std::unique_ptr<sbe::EExpression> arg) {
// lambda applications are not directly supported by SBE (yet) and must not be present.
tasserted(6624208, "lambda application is not implemented");
return nullptr;
return sbe::makeE<sbe::ELocalLambda>(frameId, std::move(body), lam.varNames().size());
}
std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Variable& var) {
@ -193,8 +184,13 @@ std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(const Variabl
// than a slot.
auto it = _lambdaMap.find(lam);
tassert(6624204, "incorrect lambda map", it != _lambdaMap.end());
const ProjectionNameVector& varNames = lam->varNames();
auto itVar = std::find(varNames.begin(), varNames.end(), var.name());
tassert(
10800801, "variable is not defined in associated lambda", itVar != varNames.end());
return sbe::makeE<sbe::EVariable>(it->second, 0, _env.isLastRef(var));
return sbe::makeE<sbe::EVariable>(
it->second, std::distance(varNames.begin(), itVar), _env.isLastRef(var));
}
}

View File

@ -111,9 +111,6 @@ public:
void prepare(const LambdaAbstraction& lam);
std::unique_ptr<sbe::EExpression> transport(const LambdaAbstraction& lam,
std::unique_ptr<sbe::EExpression> body);
std::unique_ptr<sbe::EExpression> transport(const LambdaApplication&,
std::unique_ptr<sbe::EExpression> lam,
std::unique_ptr<sbe::EExpression> arg);
std::unique_ptr<sbe::EExpression> transport(
const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args);

View File

@ -213,23 +213,6 @@ void ExpressionConstEval::transport(abt::ABT& n,
}
}
void ExpressionConstEval::transport(abt::ABT& n,
const abt::LambdaApplication& app,
abt::ABT& lam,
abt::ABT& arg) {
// If the 'lam' expression is abt::LambdaAbstraction then we can do the inplace beta
// reduction.
// TODO - missing alpha conversion so for now assume globally unique names.
if (auto lambda = lam.cast<abt::LambdaAbstraction>(); lambda) {
auto result =
abt::make<abt::Let>(lambda->varName(),
std::exchange(arg, abt::make<abt::Blackhole>()),
std::exchange(lambda->getBody(), abt::make<abt::Blackhole>()));
swapAndUpdate(n, std::move(result));
}
}
void ExpressionConstEval::transport(abt::ABT& n, const abt::UnaryOp& op, abt::ABT& child) {
switch (op.op()) {
case abt::Operations::Not: {
@ -741,13 +724,16 @@ void ExpressionConstEval::transport(abt::ABT& n,
void ExpressionConstEval::prepare(abt::ABT& n, const abt::LambdaAbstraction& lam) {
++_inCostlyCtx;
_variableDefinitions.emplace(lam.varName(),
abt::Definition{n.ref(), abt::ABT::reference_type{}});
for (auto& var : lam.varNames()) {
_variableDefinitions.emplace(var, abt::Definition{n.ref(), abt::ABT::reference_type{}});
}
}
void ExpressionConstEval::transport(abt::ABT&, const abt::LambdaAbstraction& lam, abt::ABT&) {
--_inCostlyCtx;
_variableDefinitions.erase(lam.varName());
for (auto& var : lam.varNames()) {
_variableDefinitions.erase(var);
}
}
void ExpressionConstEval::prepare(abt::ABT&, const abt::References& refs) {

View File

@ -64,7 +64,6 @@ public:
void transport(abt::ABT& n, const abt::Let& let, abt::ABT&, abt::ABT& in);
void prepare(abt::ABT&, const abt::MultiLet& multiLet);
void transport(abt::ABT& n, const abt::MultiLet& multiLet, std::vector<abt::ABT>& args);
void transport(abt::ABT& n, const abt::LambdaApplication& app, abt::ABT& lam, abt::ABT& arg);
void prepare(abt::ABT&, const abt::LambdaAbstraction&);
void transport(abt::ABT&, const abt::LambdaAbstraction&, abt::ABT&);

View File

@ -29,11 +29,9 @@
#include "mongo/base/string_data.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsontypes.h"
#include "mongo/bson/ordering.h"
#include "mongo/db/curop.h"
#include "mongo/db/exec/sbe/expressions/expression.h"
#include "mongo/db/exec/sbe/stages/stages.h"
#include "mongo/db/exec/sbe/values/slot.h"
#include "mongo/db/exec/sbe/values/value.h"
@ -50,7 +48,6 @@
#include "mongo/db/query/compiler/metadata/index_entry.h"
#include "mongo/db/query/compiler/physical_model/query_solution/query_solution.h"
#include "mongo/db/query/compiler/physical_model/query_solution/stage_types.h"
#include "mongo/db/query/index_hint.h"
#include "mongo/db/query/multiple_collection_accessor.h"
#include "mongo/db/query/query_knobs_gen.h"
#include "mongo/db/query/stage_builder/sbe/builder.h"
@ -61,15 +58,11 @@
#include "mongo/db/query/util/make_data_structure.h"
#include "mongo/db/storage/key_string/key_string.h"
#include "mongo/db/storage/sorted_data_interface.h"
#include "mongo/platform/atomic_word.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/str.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
@ -158,13 +151,13 @@ namespace {
* {a: [{b: [1, 2]}, 3]} // a.0.b.0._, a.0.b.1._ and a.1.b._ end in scalar values inside arrays
*/
// Creates an expression for traversing path 'fp' in the record from 'inputSlot' that implement MQL
// Creates an expression for traversing path 'fp' in the record from 'inputExpr' that implement MQL
// semantics for local collections. The semantics never treat terminal arrays as whole values and
// match to null per "Matching local records to null" above. Returns all the key values in a single
// array. For example, if the record in the 'inputSlot' is:
// array. For example, if the record in the 'inputExpr' is:
// {a: [{b:[1,[2,3]]}, {b:4}, {b:1}, {b:2}]},
// the returned values for path "a.b" will be packed as: [1, [2,3], 4, 1, 2].
// Empty arrays and missing are skipped, that is, if the record in the 'inputSlot' is:
// Empty arrays and missing are skipped, that is, if the record in the 'inputExpr' is:
// {a: [{b:1}, {b:[]}, {no_b:42}, {b:2}]},
// the returned values for path "a.b" will be packed as: [1, 2].
SbExpr generateLocalKeyStream(SbExpr inputExpr,
@ -182,13 +175,14 @@ SbExpr generateLocalKeyStream(SbExpr inputExpr,
!inputExpr.isNull() || topLevelFieldSlot.has_value());
// Generate an expression to read a sub-field at the current nested level.
SbExpr fieldName = b.makeStrConstant(fp.getFieldName(level));
SbExpr fieldExpr = topLevelFieldSlot
? SbExpr{*topLevelFieldSlot}
: b.makeFunction("getField"_sd, std::move(inputExpr), std::move(fieldName));
: b.makeFunction(
"getField"_sd, std::move(inputExpr), b.makeStrConstant(fp.getFieldName(level)));
if (level == fp.getPathLength() - 1) {
// The last level doesn't expand leaf arrays.
// In the generation of the local keys, the last level doesn't
// expand leaf arrays.
return fieldExpr;
}
@ -226,7 +220,7 @@ SbExpr generateLocalKeyStream(SbExpr inputExpr,
SbLocalVar{traverseFrameId, 1}));
}
// Creates stages for traversing path 'fp' in the record from 'inputSlot' that implement MQL
// Creates an expression for traversing path 'fp' in the record from 'inputSlot' that implement MQL
// semantics for foreign collections. Returns one key value at a time, including terminal arrays as
// a whole value. For example,
// if the record in the 'inputSlot' is:
@ -238,127 +232,95 @@ SbExpr generateLocalKeyStream(SbExpr inputExpr,
// Replaces other missing terminals with 'null', that is, if the record in the 'inputSlot' is:
// {a: [{b:1}, {b:[]}, {no_b:42}, {b:2}]},
// the returned values for path "a.b" will be streamed as: 1, [], null, 2.
std::pair<SbSlot /* keyValueSlot */, SbStage> buildForeignKeysStream(SbSlot inputSlot,
const FieldPath& fp,
const PlanNodeId nodeId,
StageBuilderState& state) {
SbBuilder b(state, nodeId);
SbExpr generateForeignKeyStream(SbVar inputSlot,
boost::optional<SbVar> arrayPosSlot,
const FieldPath& fp,
size_t level,
StageBuilderState& state,
boost::optional<SbSlot> topLevelFieldSlot = boost::none) {
using namespace std::literals;
const FieldIndex numParts = fp.getPathLength();
SbExprBuilder b(state);
invariant(level < fp.getPathLength());
SbSlot keyValueSlot = inputSlot;
SbSlot prevKeyValueSlot = inputSlot;
SbStage currentStage = b.makeLimitOneCoScanTree();
for (size_t i = 0; i < numParts; i++) {
const StringData fieldName = fp.getFieldName(i);
SbExpr getFieldFromObject;
if (i == 0) {
// 'inputSlot' must contain a document and, by definition, it's not inside an array, so
// can get field unconditionally.
getFieldFromObject = b.makeFillEmptyNull(
b.makeFunction("getField"_sd, keyValueSlot, b.makeStrConstant(fieldName)));
// Generate an expression to read a sub-field at the current nested level.
SbExpr getFieldFromObject;
if (level == 0) {
if (topLevelFieldSlot) {
// the first navigated path is already available in the 'topLevelFieldSlot'.
getFieldFromObject = b.makeFillEmptyNull(topLevelFieldSlot);
} else {
// Don't get field from scalars inside arrays (it would fail but we also don't want to
// fill with "null" in this case to match the MQL semantics described above.)
SbExpr shouldGetField =
b.makeBooleanOpTree(abt::Operations::Or,
b.makeFunction("isObject", keyValueSlot),
b.makeNot(b.makeFunction("isArray", prevKeyValueSlot)));
getFieldFromObject =
b.makeIf(std::move(shouldGetField),
b.makeFillEmptyNull(b.makeFunction(
"getField"_sd, keyValueSlot, b.makeStrConstant(fieldName))),
b.makeNothingConstant());
getFieldFromObject = b.makeFillEmptyNull(b.makeFunction(
"getField"_sd, inputSlot, b.makeStrConstant(fp.getFieldName(level))));
}
} else {
tassert(10800800, "arrayPosSlot must be provided", arrayPosSlot);
// Don't get field from scalars inside arrays (it would fail but we also don't want to
// fill with "null" in this case to match the MQL semantics described above): this is
// achieved by checking that the position in the parent array exposed in the arrayPosSlot
// variable is set to -1, i.e. we are not iterating over an array at all.
SbExpr shouldGetField = b.makeBooleanOpTree(
abt::Operations::Or,
b.makeBinaryOp(abt::Operations::Eq, *arrayPosSlot, b.makeInt64Constant(-1)),
b.makeFunction("isObject"_sd, inputSlot));
auto [outStage, outSlots] =
b.makeProject(std::move(currentStage), std::move(getFieldFromObject));
currentStage = std::move(outStage);
SbSlot getFieldSlot = outSlots[0];
keyValueSlot = getFieldSlot;
// For the terminal array we will do the extra work of adding the array itself to the stream
// (see below) but for the non-terminal path components we only need to unwind array
// elements.
if (i + 1 < numParts) {
constexpr bool preserveNullAndEmptyArrays = true;
auto [outStage, unwindOutputSlot, _] =
b.makeUnwind(std::move(currentStage), keyValueSlot, preserveNullAndEmptyArrays);
currentStage = std::move(outStage);
prevKeyValueSlot = keyValueSlot;
keyValueSlot = unwindOutputSlot;
}
getFieldFromObject =
b.makeIf(std::move(shouldGetField),
b.makeFillEmptyNull(b.makeFunction(
"getField"_sd, inputSlot, b.makeStrConstant(fp.getFieldName(level)))),
b.makeNothingConstant());
}
// For the terminal field part, both the array elements and the array itself are considered as
// keys. To implement this, we use a "union" stage, where the first branch produces array
// elements and the second branch produces the array itself. To avoid re-traversing the path, we
// pass the already traversed path to the "union" via "nlj" stage. However, for scalars 'unwind'
// produces the scalar itself and we don't want to add it to the stream twice -- this is handled
// by the 'branch' stage.
// For example, for foreignField = "a.b" this part of the tree would look like:
// [2] nlj [] [s17]
// left
// # Get the terminal value on the path, it will be placed in s17, it might be a scalar
// # or it might be an array.
// [2] project [s17 = if (
// isObject (s15) || ! isArray (s14), fillEmpty (getField (s15, "b"), null),
// Nothing)]
// [2] unwind s15 s16 s14 true
// [2] project [s14 = fillEmpty (getField (s7 = inputSlot, "a"), null)]
// [2] limit 1
// [2] coscan
// right
// # Process the terminal value depending on whether it's an array or a scalar/object.
// [2] branch {isArray (s17)} [s21]
// # If s17 is an array, unwind it and union with the value of the array itself.
// [s20] [2] union [s20] [
// [s18] [2] unwind s18 s19 s17 true
// [2] limit 1
// [2] coscan ,
// [s17] [2] limit 1
// [2] coscan
// ]
// # If s17 isn't an array, don't need to do anything and simply return s17.
// [s17] [2] limit 1
// [2] coscan
constexpr bool preserveNullAndEmptyArrays = true;
if (level == fp.getPathLength() - 1) {
// For the terminal field part, both the array elements and the array itself are considered
// as keys.
sbe::FrameId traverseFrameId = state.frameId();
return b.makeLet(
traverseFrameId,
SbExpr::makeSeq(std::move(getFieldFromObject)),
b.makeIf(b.makeFunction("isArray"_sd, SbLocalVar{traverseFrameId, 0}),
b.makeFunction("concatArrays"_sd,
SbLocalVar{traverseFrameId, 0},
b.makeFunction("newArray"_sd, SbLocalVar{traverseFrameId, 0})),
b.makeFunction("newArray"_sd, SbLocalVar{traverseFrameId, 0})));
}
auto [terminalUnwind, terminalUnwindOutputSlot, _] =
b.makeUnwind(b.makeLimitOneCoScanTree(), keyValueSlot, preserveNullAndEmptyArrays);
// Generate nested traversal.
sbe::FrameId lambdaForArrayFrameId = state.frameId();
sbe::PlanStage::Vector terminalStagesToUnion;
terminalStagesToUnion.push_back(std::move(terminalUnwind));
terminalStagesToUnion.emplace_back(b.makeLimitOneCoScanTree());
SbExpr lambdaForArrayExpr =
b.makeLocalLambda2(lambdaForArrayFrameId,
generateForeignKeyStream(SbLocalVar{lambdaForArrayFrameId, 0},
SbLocalVar{lambdaForArrayFrameId, 1}.toVar(),
fp,
level + 1,
state));
auto [unionStage, unionOutputSlots] = b.makeUnion(
std::move(terminalStagesToUnion),
std::vector{SbExpr::makeSV(terminalUnwindOutputSlot), SbExpr::makeSV(keyValueSlot)});
auto unionOutputSlot = unionOutputSlots[0];
auto [outStage, outSlots] = b.makeBranch(std::move(unionStage),
b.makeLimitOneCoScanTree(),
b.makeFunction("isArray", keyValueSlot),
SbExpr::makeSV(unionOutputSlot),
SbExpr::makeSV(keyValueSlot));
unionStage = std::move(outStage);
auto maybeUnionOutputSlot = outSlots[0];
currentStage = b.makeLoopJoin(std::move(currentStage),
std::move(unionStage),
{} /* outerProjects */,
SbExpr::makeSV(keyValueSlot) /* outerCorrelated */);
keyValueSlot = maybeUnionOutputSlot;
return {keyValueSlot, std::move(currentStage)};
// Generate the traverse stage for the current nested level. If the traversed field is an array,
// we know that traverseP will wrap the result into an array, that we need to remove by using
// unwindArray.
// For example, when processing
// {a: [{b:[1,[2,3]]}, {b:4}, {b:1}, {b:2}]}
// the result of getField("a") is an array, and traverseP will return an array
// [ [1,[2,3]], 4, 1, 2 ]
// holding the results of the lambda for each item; in order to obtain the list of leaf nodes we
// have to extract the content of the first item into the containing array, e.g.
// [ 1, [2,3], 4, 1, 2 ]
// When traverseP processes a non-array, the result could still be an array, but it would be the
// result of running the lambda on a non-array value, e.g.
// {a: {b:[1, [2]]} }
// The result would be [1, [2]] that is already in the correct form and should not be processed
// with unwindArray, or the result would be an incorrect [1, 2].
sbe::FrameId getFieldFrameId = state.frameId();
return b.makeLet(getFieldFrameId,
SbExpr::makeSeq(std::move(getFieldFromObject),
b.makeFunction("traverseP"_sd,
SbLocalVar{getFieldFrameId, 0},
std::move(lambdaForArrayExpr),
b.makeInt32Constant(1))),
b.makeIf(b.makeFunction("isArray"_sd, SbLocalVar{getFieldFrameId, 0}),
b.makeFunction("unwindArray"_sd, SbLocalVar{getFieldFrameId, 1}),
SbLocalVar{getFieldFrameId, 1}));
}
// Returns the vector of local slots to be used in lookup join, including the record slot and
@ -378,6 +340,76 @@ SbSlotVector buildLocalSlots(StageBuilderState& state, SbSlot localRecordSlot) {
return slots;
}
// We need to lookup the values in 'localKeyValueSet' in the index defined on the foreign
// collection. To do this, we need to generate set of point intervals corresponding to this
// value. Single value can correspond to multiple point intervals:
// - Array values:
// a. If array is empty, [Undefined, Undefined]
// b. If array is NOT empty, [array[0], array[0]] (point interval composed from the first
// array element). This is needed to match {_id: 0, a: [[1, 2]]} to {_id: 0, b: [1, 2]}.
// - All other types, including array itself as a value, single point interval [value, value].
// This is needed for arrays to match {_id: 1, a: [[1, 2]]} to {_id: 0, b: [[1, 2], 42]}.
//
// To implement these rules, we extract the first element of every array found in the set of
// local keys, append them to the set of local keys, then remove duplicates by converting it
// again to a set (the extracted value could be identical to another value already in the
// localKeyValueSet - SERVER-66119):
// localKeyValueSet = if(
// traverseF(localKeyValueSet, lambda(value){
// isArray(value)
// }, false),
// arrayToSet(concatArrays(localKeyValueSet,
// traverseP(localKeyValueSet, lambda(rawValue){
// if (isArray(rawValue)) then
// fillEmpty(
// getElement(rawValue, 0),
// Undefined
// )
// else
// rawValue
// }, 1)),
// localKeyValueSet
// )
//
// In case of non-array we add the same value, relying on the fact it will be removed by the
// deduplication
std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForIndexScan(StageBuilderState& state,
SbStage inputStage,
SbSlot localKeysSetSlot,
const PlanNodeId nodeId) {
SbBuilder b(state, nodeId);
auto lambdaIsArrayFrameId = state.frameId();
SbLocalVar lambdaIsArrayVar(lambdaIsArrayFrameId, 0);
auto lambdaIsArrayExpr =
b.makeLocalLambda(lambdaIsArrayFrameId, b.makeFunction("isArray"_sd, lambdaIsArrayVar));
auto lambdaFrameId = state.frameId();
SbLocalVar lambdaVar(lambdaFrameId, 0);
auto lambdaExpr =
b.makeLocalLambda(lambdaFrameId,
b.makeIf(b.makeFunction("isArray"_sd, lambdaVar),
b.makeFillEmptyUndefined(b.makeFunction(
"getElement"_sd, lambdaVar, b.makeInt32Constant(0))),
lambdaVar));
SbExpr expr = b.makeIf(b.makeFunction("traverseF"_sd,
localKeysSetSlot,
std::move(lambdaIsArrayExpr),
b.makeBoolConstant(false)),
b.makeFunction("arrayToSet"_sd,
b.makeFunction("concatArrays"_sd,
localKeysSetSlot,
b.makeFunction("traverseP"_sd,
localKeysSetSlot,
std::move(lambdaExpr),
b.makeInt32Constant(1)))),
localKeysSetSlot);
auto [outStage, outSlots] = b.makeProject(std::move(inputStage), std::move(expr));
return {outSlots[0], std::move(outStage)};
}
// Creates stages for traversing path 'fp' in the record from 'inputSlot'. Puts the set of key
// values into 'keyValuesSetSlot'. For example, if the record in the 'inputSlot' is:
// {a: [{b:[1,[2,3]]}, {b:4}, {b:1}, {b:2}]},
@ -388,7 +420,7 @@ SbSlotVector buildLocalSlots(StageBuilderState& state, SbSlot localRecordSlot) {
std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForLocal(
StageBuilderState& state,
SbStage inputStage,
SbSlot inputSlot,
const PlanStageSlots& slots,
const FieldPath& fp,
boost::optional<sbe::value::SlotId> collatorSlot,
const PlanNodeId nodeId) {
@ -406,7 +438,12 @@ std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForLocal(
// value.
SbExpr expr = b.makeLet(
frameId,
SbExpr::makeSeq(generateLocalKeyStream(SbExpr{inputSlot}, fp, 0, state)),
SbExpr::makeSeq(generateLocalKeyStream(
slots.getResultObj(),
fp,
0,
state,
slots.getIfExists(std::make_pair(PlanStageSlots::kField, fp.getFieldName(0))))),
b.makeIf(b.makeFillEmptyFalse(b.makeFunction("isArray"_sd, SbLocalVar{frameId, 0})),
b.makeIf(b.makeFunction("isArrayEmpty"_sd, SbLocalVar{frameId, 0}),
b.makeConstant(arrayWithNullTag, arrayWithNullVal),
@ -420,7 +457,8 @@ std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForLocal(
expr = b.makeFunction("arrayToSet"_sd, std::move(expr));
}
auto [outStage, outSlots] = b.makeProject(std::move(inputStage), std::move(expr));
auto [outStage, outSlots] =
b.makeProject(buildVariableTypes(slots), std::move(inputStage), std::move(expr));
inputStage = std::move(outStage);
auto keyValuesSetSlot = outSlots[0];
@ -430,61 +468,31 @@ std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForLocal(
/**
* Traverses path 'fp' in the 'inputSlot' and puts the set of key values into 'keyValuesSetSlot'.
* Puts a stage that joins the original record with its set of keys into 'nljLocalWithKeyValuesSet'
*/
std::pair<SbSlot /* keyValuesSetSlot */, SbStage> buildKeySetForForeign(
StageBuilderState& state,
SbStage inputStage,
SbSlot inputSlot,
const FieldPath& fp,
SbSlot topLevelFieldSlot,
boost::optional<sbe::value::SlotId> collatorSlot,
const PlanNodeId nodeId) {
SbBuilder b(state, nodeId);
// Create the branch to stream individual key values from every terminal of the path.
auto [keyValueSlot, keyValuesStage] = buildForeignKeysStream(inputSlot, fp, nodeId, state);
SbExpr expr =
generateForeignKeyStream(inputSlot.toVar(), boost::none, fp, 0, state, topLevelFieldSlot);
// Convert the array into an ArraySet that has no duplicate keys.
if (collatorSlot) {
expr = b.makeFunction("collArrayToSet"_sd, SbSlot{*collatorSlot}, std::move(expr));
} else {
expr = b.makeFunction("arrayToSet"_sd, std::move(expr));
}
// Re-pack the individual key values into a set. We don't cap "addToSet" here because its
// size is bounded by the size of the record.
auto spillSlot = SbSlot{state.slotId()};
auto [outStage, outSlots] = b.makeProject(
buildVariableTypes(topLevelFieldSlot), std::move(inputStage), std::move(expr));
auto keyValuesSetSlot = outSlots[0];
auto addToSetExpr = collatorSlot
? b.makeFunction("collAddToSet"_sd, SbSlot{*collatorSlot}, keyValueSlot)
: b.makeFunction("addToSet"_sd, keyValueSlot);
auto aggSetUnionExpr = collatorSlot
? b.makeFunction("aggCollSetUnion"_sd, SbSlot{*collatorSlot}, spillSlot)
: b.makeFunction("aggSetUnion"_sd, spillSlot);
SbHashAggAccumulatorVector accumulatorList;
accumulatorList.emplace_back(SbHashAggAccumulator{
.outSlot = boost::none, // A slot will be assigned when creating the final HashAgg stage.
.spillSlot = spillSlot,
.implementation =
SbHashAggCompiledAccumulator{
.init = SbExpr{},
.agg = std::move(addToSetExpr),
.merge = std::move(aggSetUnionExpr),
},
});
auto [packedKeyValuesStage, _, aggOutSlots] = b.makeHashAgg(
VariableTypes{},
std::move(keyValuesStage),
{}, /* groupBy slots - an empty vector means creating a single group */
accumulatorList,
{} /* We group _all_ key values to a single set so we can ignore collation */);
SbSlot keyValuesSetSlot = aggOutSlots[0];
// Attach the set of key values to the original local record.
auto nljLocalWithKeyValuesSet =
b.makeLoopJoin(std::move(inputStage),
std::move(packedKeyValuesStage), // NOLINT(bugprone-use-after-move)
SbExpr::makeSV(inputSlot),
SbExpr::makeSV(inputSlot) /* outerCorrelated */);
return {keyValuesSetSlot, std::move(nljLocalWithKeyValuesSet)};
return {keyValuesSetSlot, std::move(outStage)};
}
// Creates stages for grouping matched foreign records into an array. If there's no match, the
@ -591,7 +599,7 @@ std::pair<SbSlot /* resultSlot */, SbStage> buildForeignMatchedArray(SbStage inn
* group [] [groupSlot = addToArrayCapped(foreignRecordSlot, 104857600)]
* filter {traverseF (
* let [
* l11.0 = fillEmpty (getField (foreignRecordSlot, "a"), null)
* l11.0 = fillEmpty (topLevelFieldSlot, null)
* ]
* in
* if typeMatch (l11.0, 24)
@ -604,14 +612,19 @@ std::pair<SbSlot /* resultSlot */, SbStage> buildForeignMatchedArray(SbStage inn
* true),
* else false
* }, false)}
* scan foreignRecordSlot recordIdSlot none none none none [] @uuid true false
* branch1 [emptySlot] project [emptySlot = []] limit 1 coscan
* scan foreignRecordSlot recordIdSlot none none none none [topLevelFieldSlot = "a"] @uuid true
* false
* branch1 [emptySlot]
* project [emptySlot = []]
* limit 1
* coscan
* ]
*/
std::pair<SbSlot /* matched docs */, SbStage> buildForeignMatches(SbSlot localKeySlot,
SbStage foreignStage,
SbSlot foreignRecordSlot,
const FieldPath& foreignFieldName,
SbSlot topLevelFieldSlot,
const PlanNodeId nodeId,
StageBuilderState& state,
bool hasUnwindSrc) {
@ -634,8 +647,11 @@ std::pair<SbSlot /* matched docs */, SbStage> buildForeignMatches(SbSlot localKe
frameId = state.frameId();
lambdaArg = i == 0 ? SbExpr{foreignRecordSlot} : SbExpr{SbVar{frameId, 0}};
auto getFieldOrNull = b.makeFillEmptyNull(b.makeFunction(
"getField"_sd, lambdaArg.clone(), b.makeStrConstant(foreignFieldName.getFieldName(i))));
auto getFieldOrNull = b.makeFillEmptyNull(
i == 0 ? topLevelFieldSlot
: b.makeFunction("getField"_sd,
lambdaArg.clone(),
b.makeStrConstant(foreignFieldName.getFieldName(i))));
// Non object/array field will be converted into Nothing, passing along recursive traverseF
// and will be treated as null to compared against local key set.
@ -669,7 +685,8 @@ std::pair<SbSlot /* matched docs */, SbStage> buildForeignMatches(SbSlot localKe
}
}
SbStage foreignOutputStage = b.makeFilter(std::move(foreignStage), std::move(filter));
SbStage foreignOutputStage = b.makeFilter(
buildVariableTypes(topLevelFieldSlot), std::move(foreignStage), std::move(filter));
if (hasUnwindSrc) {
// $LU [$lookup, $unwind] pattern: The query immediately unwinds the lookup result array. We
// implement this efficiently by returning a record for each individual foreign match one by
@ -687,7 +704,7 @@ std::pair<SbSlot /* matched docs */, SbStage> buildForeignMatches(SbSlot localKe
std::pair<SbSlot /* matched docs */, SbStage> buildNljLookupStage(
StageBuilderState& state,
SbStage localStage,
SbSlot localRecordSlot,
const PlanStageSlots& slots,
const FieldPath& localFieldName,
const CollectionPtr& foreignColl,
const FieldPath& foreignFieldName,
@ -701,10 +718,14 @@ std::pair<SbSlot /* matched docs */, SbStage> buildNljLookupStage(
// Build the outer branch that produces the set of local key values.
auto [localKeySlot, outerRootStage] = buildKeySetForLocal(
state, std::move(localStage), localRecordSlot, localFieldName, collatorSlot, nodeId);
state, std::move(localStage), slots, localFieldName, collatorSlot, nodeId);
auto [foreignStage, foreignRecordSlot, _, __] =
b.makeScan(foreignColl->uuid(), foreignColl->ns().dbName(), forwardScanDirection);
auto [foreignStage, foreignRecordSlot, _, scanFieldSlots] =
b.makeScan(foreignColl->uuid(),
foreignColl->ns().dbName(),
forwardScanDirection,
boost::none /* seekSlot */,
std::vector<std::string>{std::string(foreignFieldName.front())});
// Build the inner branch that will get the foreign key values, compare them to the local key
// values and accumulate all matching foreign records into an array that is placed into
@ -713,6 +734,7 @@ std::pair<SbSlot /* matched docs */, SbStage> buildNljLookupStage(
std::move(foreignStage),
foreignRecordSlot,
foreignFieldName,
scanFieldSlots[0],
nodeId,
state,
hasUnwindSrc);
@ -721,6 +743,7 @@ std::pair<SbSlot /* matched docs */, SbStage> buildNljLookupStage(
// it performs should not influence planning decisions made for 'outerRootStage'.
innerRootStage->disableTrialRunTracking();
SbSlot localRecordSlot = slots.getResultObj();
// Connect the two branches with a nested loop join. For each outer record with a corresponding
// value in the 'localKeySlot', the inner branch will be executed and will place the result into
// 'matchedRecordsSlot'.
@ -760,84 +783,28 @@ std::tuple<SbStage, SbSlot, SbSlot, SbSlotVector> buildIndexJoinLookupForeignSid
const auto indexVersion = indexAccessMethod->getSortedDataInterface()->getKeyStringVersion();
const auto indexOrdering = indexAccessMethod->getSortedDataInterface()->getOrdering();
// Unwind local keys one by one into 'singleLocalValueSlot'.
constexpr bool preserveNullAndEmptyArrays = true;
// Modify the set of values to lookup to include the first item of any array.
auto [localKeysIndexSetSlot, localKeysSetStage] =
buildKeySetForIndexScan(state, b.makeLimitOneCoScanTree(), localKeysSetSlot, nodeId);
auto [unwindLocalKeysStage, singleLocalValueSlot, _] =
b.makeUnwind(b.makeLimitOneCoScanTree(), localKeysSetSlot, preserveNullAndEmptyArrays);
// We need to lookup value in 'singleLocalValueSlot' in the index defined on the foreign
// collection. To do this, we need to generate set of point intervals corresponding to this
// value. Single value can correspond to multiple point intervals:
// - Array values:
// a. If array is empty, [Undefined, Undefined]
// b. If array is NOT empty, [array[0], array[0]] (point interval composed from the first
// array element). This is needed to match {_id: 0, a: [[1, 2]]} to {_id: 0, b: [1, 2]}.
// - All other types, including array itself as a value, single point interval [value, value].
// This is needed for arrays to match {_id: 1, a: [[1, 2]]} to {_id: 0, b: [[1, 2], 42]}.
//
// To implement these rules, we use the union stage:
// union pointValue [
// // Branch 1
// filter isArray(rawValue) && !isMember(pointValue, localKeyValueSet)
// project pointValue = fillEmpty(
// getElement(rawValue, 0),
// Undefined
// )
// limit 1
// coscan
// ,
// // Branch 2
// project pointValue = rawValue
// limit 1
// coscan
// ]
//
// For array values, branches (1) and (2) both produce values. For all other types, only (2)
// produces a value.
auto [arrayBranch, arrayBranchOutSlots] =
b.makeProject(b.makeLimitOneCoScanTree(),
b.makeFillEmptyUndefined(b.makeFunction(
"getElement", singleLocalValueSlot, b.makeInt32Constant(0))));
SbSlot arrayBranchOutput = arrayBranchOutSlots[0];
auto shouldProduceSeekForArray = b.makeBooleanOpTree(
abt::Operations::And,
b.makeFunction("isArray", singleLocalValueSlot),
b.makeNot(b.makeFunction("isMember", arrayBranchOutput, localKeysSetSlot)));
arrayBranch = b.makeFilter(std::move(arrayBranch), std::move(shouldProduceSeekForArray));
auto [valueBranch, valueBranchOutSlots] = b.makeProject(
b.makeLimitOneCoScanTree(), std::pair(SbExpr{singleLocalValueSlot}, state.slotId()));
SbSlot valueBranchOutput = valueBranchOutSlots[0];
auto unionInputs =
makeVector(SbExpr::makeSV(arrayBranchOutput), SbExpr::makeSV(valueBranchOutput));
auto [valueGeneratorStage, unionOutputSlots] = b.makeUnion(
sbe::makeSs(std::move(arrayBranch), std::move(valueBranch)), std::move(unionInputs));
auto valueForIndexBounds = unionOutputSlots[0];
// Unwind local keys one by one into 'valueForIndexBounds'.
auto [valueGeneratorStage, valueForIndexBounds, _] = b.makeUnwind(
std::move(localKeysSetStage), localKeysIndexSetSlot, true /*preserveNullAndEmptyArrays*/);
if (index.type == INDEX_HASHED) {
// For hashed indexes, we need to hash the value before computing keystrings iff the
// lookup's "foreignField" is the hashed field in this index.
const BSONElement elt = index.keyPattern.getField(foreignFieldName.fullPath());
if (elt.valueStringDataSafe() == IndexNames::HASHED) {
SbSlot rawValueSlot = valueForIndexBounds;
SbSlot indexValueSlot = rawValueSlot;
if (collatorSlot) {
// For collated hashed indexes, apply collation before hashing.
auto [outStage, outSlots] = b.makeProject(
std::move(valueGeneratorStage),
b.makeFunction("collComparisonKey", rawValueSlot, SbSlot{*collatorSlot}));
valueGeneratorStage = std::move(outStage);
indexValueSlot = outSlots[0];
}
auto [outStage, outSlots] = b.makeProject(std::move(valueGeneratorStage),
b.makeFunction("shardHash", indexValueSlot));
// For collated hashed indexes, apply collation before hashing.
auto [outStage, outSlots] =
b.makeProject(std::move(valueGeneratorStage),
b.makeFunction("shardHash"_sd,
collatorSlot ? b.makeFunction("collComparisonKey",
valueForIndexBounds,
SbSlot{*collatorSlot})
: valueForIndexBounds));
valueGeneratorStage = std::move(outStage);
valueForIndexBounds = outSlots[0];
}
@ -870,14 +837,6 @@ std::tuple<SbStage, SbSlot, SbSlot, SbSlotVector> buildIndexJoinLookupForeignSid
SbSlot lowKeySlot = outSlots[0];
SbSlot highKeySlot = outSlots[1];
// To ensure that we compute index bounds for all local values, introduce loop join, where
// unwinding of local values happens on the right side and index generation happens on the left
// side.
indexBoundKeyStage = b.makeLoopJoin(std::move(unwindLocalKeysStage),
std::move(indexBoundKeyStage),
{} /* outerProjects */,
SbExpr::makeSV(singleLocalValueSlot) /* outerCorrelated */);
auto indexInfoTypeMask = SbIndexInfoType::kIndexIdent | SbIndexInfoType::kIndexKey |
SbIndexInfoType::kIndexKeyPattern | SbIndexInfoType::kSnapshotId;
@ -930,7 +889,7 @@ std::tuple<SbStage, SbSlot, SbSlot, SbSlotVector> buildIndexJoinLookupForeignSid
// 'indexKeySlot' and 'indexKeyPatternSlot' to perform index consistency check during the seek.
auto [scanNljStage, scanNljValueSlot, scanNljRecordIdSlot, scanNljFieldSlots] =
makeLoopJoinForFetch(std::move(ixScanNljStage),
std::vector<std::string>{},
std::vector<std::string>{std::string(foreignFieldName.front())},
foreignRecordIdSlot,
snapshotIdSlot,
indexIdentSlot,
@ -954,45 +913,26 @@ std::tuple<SbStage, SbSlot, SbSlot, SbSlotVector> buildIndexJoinLookupForeignSid
* collection. Note that parts reading the local values and constructing the resulting document are
* omitted.
*
* nlj [foreignDocument] [foreignDocument]
* left
* filter {isMember (foreignValue, localValueSet)}
* nlj
* left
* nlj [lowKey, highKey]
* left
* nlj
* left
* unwind localKeySet localValue
* limit 1
* coscan
* right
* project lowKey = ks (1, 0, valueForIndexBounds, 1),
* highKey = ks (1, 0, valueForIndexBounds, 2)
* union [valueForIndexBounds] [
* cfilter {isArray (localValue)}
* project [valueForIndexBounds = fillEmpty (getElement (localValue, 0), undefined)]
* limit 1
* coscan
* ,
* project [valueForIndexBounds = localValue]
* limit 1
* coscan
* ]
* project lowKey = ks (1, 0, valueForIndexBounds, 1),
* highKey = ks (1, 0, valueForIndexBounds, 2)
* unwind localKeySet localValue
* limit 1
* coscan
* right
* ixseek lowKey highKey recordId @"b_1"
* right
* limit 1
* seek s21 foreignDocument recordId @"foreign collection"
* right
* limit 1
* filter {isMember (foreignValue, localValueSet)}
* // Below is the tree performing path traversal on the 'foreignDocument' and producing value
* // into 'foreignValue'.
* seek s21 foreignDocument recordId [foreignValue = "b"] @"foreign collection"
*/
std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
StageBuilderState& state,
SbStage localStage,
SbSlot localRecordSlot,
const PlanStageSlots& slots,
const FieldPath& localFieldName,
const FieldPath& foreignFieldName,
const CollectionPtr& foreignColl,
@ -1006,10 +946,10 @@ std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
// Build the outer branch that produces the correlated local key slot.
auto [localKeysSetSlot, localKeysSetStage] = buildKeySetForLocal(
state, std::move(localStage), localRecordSlot, localFieldName, collatorSlot, nodeId);
state, std::move(localStage), slots, localFieldName, collatorSlot, nodeId);
// Build the inner branch that produces the correlated foreign key slot.
auto [scanNljStage, foreignRecordSlot, _, __] =
auto [scanNljStage, foreignRecordSlot, _, scanFieldSlots] =
buildIndexJoinLookupForeignSideStage(state,
localKeysSetSlot,
localFieldName,
@ -1028,6 +968,7 @@ std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
std::move(scanNljStage),
foreignRecordSlot,
foreignFieldName,
scanFieldSlots[0],
nodeId,
state,
hasUnwindSrc);
@ -1036,6 +977,7 @@ std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
// that it performs should not influence planning decisions for 'localKeysSetStage'.
foreignGroupStage->disableTrialRunTracking();
SbSlot localRecordSlot = slots.getResultObj();
// The top level loop join stage that joins each local field with the matched foreign
// documents.
auto nljStage = b.makeLoopJoin(std::move(localKeysSetStage),
@ -1063,23 +1005,11 @@ std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
* left
* nlj [lowKey, highKey]
* left
* nlj
* left
* unwind localKeySet localValue
* limit 1
* coscan
* right
* project lowKey = ks (1, 0, valueForIndexBounds, 1),
* highKey = ks (1, 0, valueForIndexBounds, 2)
* union [valueForIndexBounds] [
* cfilter {isArray (localValue)}
* project [valueForIndexBounds = fillEmpty (getElement (localValue, 0),
* undefined)] limit 1 coscan
* ,
* project [valueForIndexBounds = localValue]
* limit 1
* coscan
* ]
* project lowKey = ks (1, 0, valueForIndexBounds, 1),
* highKey = ks (1, 0, valueForIndexBounds, 2)
* unwind localKeySet localValue
* limit 1
* coscan
* right
* ixseek lowKey highKey recordId @"b_1"
* right
@ -1091,7 +1021,7 @@ std::pair<SbSlot, SbStage> buildIndexJoinLookupStage(
std::pair<SbSlot, SbStage> buildDynamicIndexedLoopJoinLookupStage(
StageBuilderState& state,
SbStage localStage,
SbSlot localRecordSlot,
const PlanStageSlots& slots,
const FieldPath& localFieldName,
const CollectionPtr& foreignColl,
const FieldPath& foreignFieldName,
@ -1107,21 +1037,31 @@ std::pair<SbSlot, SbStage> buildDynamicIndexedLoopJoinLookupStage(
// Build the index Lookup branch
auto [localKeysSetSlot, localKeysSetStage] = buildKeySetForLocal(
state, std::move(localStage), localRecordSlot, localFieldName, collatorSlot, nodeId);
auto [indexLookupBranchStage, indexLookupBranchResultSlot, indexLookupBranchRecordIdSlot, _] =
buildIndexJoinLookupForeignSideStage(state,
localKeysSetSlot,
localFieldName,
foreignFieldName,
foreignColl,
index,
collatorSlot,
nodeId,
hasUnwindSrc);
state, std::move(localStage), slots, localFieldName, collatorSlot, nodeId);
auto [indexLookupBranchStage,
indexLookupBranchResultSlot,
indexLookupBranchRecordIdSlot,
indexLookupBranchScanSlots] = buildIndexJoinLookupForeignSideStage(state,
localKeysSetSlot,
localFieldName,
foreignFieldName,
foreignColl,
index,
collatorSlot,
nodeId,
hasUnwindSrc);
// Build the nested loop branch.
auto [nestedLoopBranchStage, nestedLoopBranchResultSlot, nestedLoopBranchRecordIdSlot, __] =
b.makeScan(foreignColl->uuid(), foreignColl->ns().dbName(), forwardScanDirection);
auto [nestedLoopBranchStage,
nestedLoopBranchResultSlot,
nestedLoopBranchRecordIdSlot,
nestedLoopBranchScanSlots] =
b.makeScan(foreignColl->uuid(),
foreignColl->ns().dbName(),
forwardScanDirection,
boost::none /* seekSlot */,
std::vector<std::string>{std::string(foreignFieldName.front())});
// Build the typeMatch filter expression
sbe::FrameId frameId = state.frameId();
@ -1138,18 +1078,22 @@ std::pair<SbSlot, SbStage> buildDynamicIndexedLoopJoinLookupStage(
b.makeBoolConstant(false) /*compareArray*/));
// Create a branch stage
auto [branchStage, branchSlots] =
b.makeBranch(std::move(indexLookupBranchStage),
std::move(nestedLoopBranchStage),
std::move(filter),
SbExpr::makeSV(indexLookupBranchResultSlot, indexLookupBranchRecordIdSlot),
SbExpr::makeSV(nestedLoopBranchResultSlot, nestedLoopBranchRecordIdSlot));
auto [branchStage, branchSlots] = b.makeBranch(std::move(indexLookupBranchStage),
std::move(nestedLoopBranchStage),
std::move(filter),
SbExpr::makeSV(indexLookupBranchResultSlot,
indexLookupBranchRecordIdSlot,
indexLookupBranchScanSlots[0]),
SbExpr::makeSV(nestedLoopBranchResultSlot,
nestedLoopBranchRecordIdSlot,
nestedLoopBranchScanSlots[0]));
SbSlot resultSlot = branchSlots[0];
auto [finalForeignSlot, finalForeignStage] = buildForeignMatches(localKeysSetSlot,
std::move(branchStage),
resultSlot,
foreignFieldName,
branchSlots[2],
nodeId,
state,
hasUnwindSrc);
@ -1158,6 +1102,7 @@ std::pair<SbSlot, SbStage> buildDynamicIndexedLoopJoinLookupStage(
// reads that it performs should not influence planning decisions for 'outerRootStage'.
finalForeignStage->disableTrialRunTracking();
SbSlot localRecordSlot = slots.getResultObj();
// Connect the local (left) and foreign (right) sides with a nested loop join. For each left
// record with a corresponding value in the 'localKeySlot', the right branch will be executed
// and will place the result into 'matchedRecordsSlot'.
@ -1174,7 +1119,7 @@ std::pair<SbSlot, SbStage> buildDynamicIndexedLoopJoinLookupStage(
std::pair<SbSlot /*matched docs*/, SbStage> buildHashJoinLookupStage(
StageBuilderState& state,
SbStage localStage,
SbSlot localRecordSlot,
const PlanStageSlots& slots,
const FieldPath& localFieldName,
const CollectionPtr& foreignColl,
const FieldPath& foreignFieldName,
@ -1188,15 +1133,24 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildHashJoinLookupStage(
// Build the outer branch that produces the correlated local key slot.
auto [localKeysSetSlot, localKeysSetStage] = buildKeySetForLocal(
state, std::move(localStage), localRecordSlot, localFieldName, collatorSlot, nodeId);
state, std::move(localStage), slots, localFieldName, collatorSlot, nodeId);
// Build the inner branch that produces the set of foreign key values.
auto [foreignStage, foreignRecordSlot, foreignRecordIdSlot, _] =
b.makeScan(foreignColl->uuid(), foreignColl->ns().dbName(), forwardScanDirection);
auto [foreignStage, foreignRecordSlot, foreignRecordIdSlot, scanFieldSlots] =
b.makeScan(foreignColl->uuid(),
foreignColl->ns().dbName(),
forwardScanDirection,
boost::none /* seekSlot */,
std::vector<std::string>{std::string(foreignFieldName.front())});
auto [foreignKeySlot, foreignKeyStage] = buildKeySetForForeign(
state, std::move(foreignStage), foreignRecordSlot, foreignFieldName, collatorSlot, nodeId);
auto [foreignKeySlot, foreignKeyStage] = buildKeySetForForeign(state,
std::move(foreignStage),
foreignRecordSlot,
foreignFieldName,
scanFieldSlots[0],
collatorSlot,
nodeId);
// 'foreignKeyStage' should not participate in trial run tracking as the number of
// reads that it performs should not influence planning decisions for 'outerRootStage'.
@ -1263,7 +1217,7 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildLookupStage(
StageBuilderState& state,
EqLookupNode::LookupStrategy lookupStrategy,
SbStage localStage,
SbSlot localRecordSlot,
const PlanStageSlots& slots,
const FieldPath& localFieldName,
const FieldPath& foreignFieldName,
const CollectionPtr& foreignColl,
@ -1291,7 +1245,7 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildLookupStage(
return buildIndexJoinLookupStage(state,
std::move(localStage),
localRecordSlot,
slots,
localFieldName,
foreignFieldName,
foreignColl,
@ -1308,7 +1262,7 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildLookupStage(
return buildDynamicIndexedLoopJoinLookupStage(state,
std::move(localStage),
localRecordSlot,
slots,
localFieldName,
foreignColl,
foreignFieldName,
@ -1323,7 +1277,7 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildLookupStage(
return buildNljLookupStage(state,
std::move(localStage),
localRecordSlot,
slots,
localFieldName,
foreignColl,
foreignFieldName,
@ -1337,7 +1291,7 @@ std::pair<SbSlot /*matched docs*/, SbStage> buildLookupStage(
return buildHashJoinLookupStage(state,
std::move(localStage),
localRecordSlot,
slots,
localFieldName,
foreignColl,
foreignFieldName,
@ -1400,6 +1354,9 @@ std::pair<SbStage, PlanStageSlots> SlotBasedStageBuilder::buildEqLookup(
}
PlanStageReqs childReqs = reqs.copyForChild().setResultObj();
// Try to get the beginning of the localField path into a slot.
childReqs.setFields(
std::vector<std::string>{std::string(eqLookupNode->joinFieldLocal.front())});
auto [localStage, localOutputs] = build(eqLookupNode->children[0].get(), childReqs);
SbSlot localRecordSlot = localOutputs.getResultObj();
@ -1418,7 +1375,7 @@ std::pair<SbStage, PlanStageSlots> SlotBasedStageBuilder::buildEqLookup(
buildLookupStage(_state,
eqLookupNode->lookupStrategy,
std::move(localStage),
localRecordSlot,
localOutputs,
eqLookupNode->joinFieldLocal,
eqLookupNode->joinFieldForeign,
foreignColl,
@ -1477,7 +1434,7 @@ std::pair<SbStage, PlanStageSlots> SlotBasedStageBuilder::buildEqLookupUnwind(
buildLookupStage(_state,
eqLookupUnwindNode->lookupStrategy,
std::move(localStage),
localRecordSlot,
localOutputs,
eqLookupUnwindNode->joinFieldLocal,
eqLookupUnwindNode->joinFieldForeign,
foreignColl,

View File

@ -268,6 +268,12 @@ SbExpr SbExprBuilder::makeLocalLambda(sbe::FrameId frameId, SbExpr expr) {
extractABT(expr));
}
SbExpr SbExprBuilder::makeLocalLambda2(sbe::FrameId frameId, SbExpr expr) {
return abt::make<abt::LambdaAbstraction>(SbVar(frameId, 0).toProjectionName(),
SbVar(frameId, 1).toProjectionName(),
extractABT(expr));
}
SbExpr SbExprBuilder::makeNumericConvert(SbExpr expr, sbe::value::TypeTags tag) {
return makeFunction(
"convert"_sd, std::move(expr), makeInt32Constant(static_cast<int32_t>(tag)));

View File

@ -179,6 +179,7 @@ public:
SbExpr makeLet(sbe::FrameId frameId, SbExpr::Vector binds, SbExpr expr);
SbExpr makeLocalLambda(sbe::FrameId frameId, SbExpr expr);
SbExpr makeLocalLambda2(sbe::FrameId frameId, SbExpr expr);
SbExpr makeNumericConvert(SbExpr expr, sbe::value::TypeTags tag);

View File

@ -257,10 +257,6 @@ inline auto _lambda(StringData pn, ExprHolder body) {
return ExprHolder{make<LambdaAbstraction>(ProjectionName{pn}, std::move(body._n))};
}
inline auto _lambdaApp(ExprHolder lambda, ExprHolder arg) {
return ExprHolder{make<LambdaApplication>(std::move(lambda._n), std::move(arg._n))};
}
template <typename... Ts>
inline auto _fn(StringData name, Ts&&... pack) {
std::vector<ExprHolder> v;

View File

@ -165,6 +165,138 @@ TEST_F(AbtToSbeExpression, Lower4) {
sbe::value::ValueGuard guard(resultTag, resultVal);
ASSERT_EQ(sbe::value::TypeTags::Array, resultTag);
auto arrResult = sbe::value::getArrayView(resultVal);
ASSERT_EQ(4, arrResult->size());
auto arrResult0 = arrResult->values()[0];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult0.first);
ASSERT_EQ(11, sbe::value::bitcastTo<int64_t>(arrResult0.second));
auto arrResult1 = arrResult->values()[1];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult1.first);
ASSERT_EQ(12, sbe::value::bitcastTo<int64_t>(arrResult1.second));
auto arrResult2 = arrResult->values()[2];
ASSERT_EQ(sbe::value::TypeTags::Array, arrResult2.first);
auto arrResult2v = sbe::value::getArrayView(arrResult2.second);
ASSERT_EQ(2, arrResult2v->size());
auto arrResult2v0 = arrResult2v->values()[0];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult2v0.first);
ASSERT_EQ(31, sbe::value::bitcastTo<int64_t>(arrResult2v0.second));
auto arrResult2v1 = arrResult2v->values()[1];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult2v1.first);
ASSERT_EQ(32, sbe::value::bitcastTo<int64_t>(arrResult2v1.second));
auto arrResult3 = arrResult->values()[3];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult3.first);
ASSERT_EQ(13, sbe::value::bitcastTo<int64_t>(arrResult3.second));
}
TEST_F(AbtToSbeExpression, Lower4TwoArgsOneLevel) {
auto [tagArr, valArr] = sbe::value::makeNewArray();
auto arr = sbe::value::getArrayView(valArr);
arr->push_back(sbe::value::TypeTags::NumberInt64, 1);
arr->push_back(sbe::value::TypeTags::NumberInt64, 2);
auto [tagArrNest, valArrNest] = sbe::value::makeNewArray();
auto arrNest = sbe::value::getArrayView(valArrNest);
arrNest->push_back(sbe::value::TypeTags::NumberInt64, 21);
arrNest->push_back(sbe::value::TypeTags::NumberInt64, 22);
arr->push_back(tagArrNest, valArrNest);
arr->push_back(sbe::value::TypeTags::NumberInt64, 3);
auto tree = make<FunctionCall>(
"traverseP",
makeSeq(
make<Constant>(tagArr, valArr),
make<LambdaAbstraction>(
"value",
"index",
make<If>(
// Comparing the index with a full int64 implies we only process the first
// level.
make<BinaryOp>(Operations::Eq, make<Variable>("index"), Constant::int64(1)),
make<BinaryOp>(Operations::Add, make<Variable>("value"), Constant::int64(10)),
Constant::nothing())),
Constant::nothing()));
auto env = VariableEnvironment::build(tree);
SlotVarMap map;
sbe::InputParamToSlotMap inputParamToSlotMap;
auto expr =
SBEExpressionLowering{env, map, *runtimeEnv(), slotIdGenerator(), inputParamToSlotMap}
.optimize(tree);
ASSERT(expr);
auto compiledExpr = compileExpression(*expr);
auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
sbe::value::ValueGuard guard(resultTag, resultVal);
ASSERT_EQ(sbe::value::TypeTags::Array, resultTag);
auto arrResult = sbe::value::getArrayView(resultVal);
ASSERT_EQ(2, arrResult->size());
auto arrResult0 = arrResult->values()[0];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult0.first);
ASSERT_EQ(12, sbe::value::bitcastTo<int64_t>(arrResult0.second));
auto arrResult1 = arrResult->values()[1];
ASSERT_EQ(sbe::value::TypeTags::Array, arrResult1.first);
ASSERT_EQ(0, sbe::value::getArrayView(arrResult1.second)->size());
}
TEST_F(AbtToSbeExpression, Lower4TwoArgsAnyLevel) {
auto [tagArr, valArr] = sbe::value::makeNewArray();
auto arr = sbe::value::getArrayView(valArr);
arr->push_back(sbe::value::TypeTags::NumberInt64, 1);
arr->push_back(sbe::value::TypeTags::NumberInt64, 2);
auto [tagArrNest, valArrNest] = sbe::value::makeNewArray();
auto arrNest = sbe::value::getArrayView(valArrNest);
arrNest->push_back(sbe::value::TypeTags::NumberInt64, 21);
arrNest->push_back(sbe::value::TypeTags::NumberInt64, 22);
arr->push_back(tagArrNest, valArrNest);
arr->push_back(sbe::value::TypeTags::NumberInt64, 3);
auto tree = make<FunctionCall>(
"traverseP",
makeSeq(
make<Constant>(tagArr, valArr),
make<LambdaAbstraction>(
"value",
"index",
make<If>(
// Trimming the index to the lowest 32 bits makes the test work on any level.
make<BinaryOp>(
Operations::Eq,
make<FunctionCall>(
"mod", makeSeq(make<Variable>("index"), Constant::int64(1LL << 32))),
Constant::int64(1)),
make<BinaryOp>(Operations::Add, make<Variable>("value"), Constant::int64(10)),
Constant::nothing())),
Constant::nothing()));
auto env = VariableEnvironment::build(tree);
SlotVarMap map;
sbe::InputParamToSlotMap inputParamToSlotMap;
auto expr =
SBEExpressionLowering{env, map, *runtimeEnv(), slotIdGenerator(), inputParamToSlotMap}
.optimize(tree);
ASSERT(expr);
auto compiledExpr = compileExpression(*expr);
auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
sbe::value::ValueGuard guard(resultTag, resultVal);
ASSERT_EQ(sbe::value::TypeTags::Array, resultTag);
auto arrResult = sbe::value::getArrayView(resultVal);
ASSERT_EQ(2, arrResult->size());
auto arrResult0 = arrResult->values()[0];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult0.first);
ASSERT_EQ(12, sbe::value::bitcastTo<int64_t>(arrResult0.second));
auto arrResult1 = arrResult->values()[1];
ASSERT_EQ(sbe::value::TypeTags::Array, arrResult1.first);
auto arrResult1v = sbe::value::getArrayView(arrResult1.second);
ASSERT_EQ(1, arrResult1v->size());
auto arrResult1v0 = arrResult1v->values()[0];
ASSERT_EQ(sbe::value::TypeTags::NumberInt64, arrResult1v0.first);
ASSERT_EQ(32, sbe::value::bitcastTo<int64_t>(arrResult1v0.second));
}
TEST_F(AbtToSbeExpression, Lower5) {

View File

@ -82,7 +82,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnGlobal) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"y\", \n"
" variable0: \"y\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -107,7 +107,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnGlobal) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"x\", \n"
" variable0: \"x\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -181,7 +181,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnLocal) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"y\", \n"
" variable0: \"y\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -206,7 +206,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnLocal) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"x\", \n"
" variable0: \"x\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -306,7 +306,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnLocalVariable) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"y\", \n"
" variable0: \"y\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -331,7 +331,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnLocalVariable) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"x\", \n"
" variable0: \"x\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"
@ -442,7 +442,7 @@ TEST(ValueLifetimeTest, ProcessTraverseOnFillEmpty) {
" }, \n"
" {\n"
" nodeType: \"LambdaAbstraction\", \n"
" variable: \"x\", \n"
" variable0: \"x\", \n"
" input: {\n"
" nodeType: \"FunctionCall\", \n"
" name: \"getField\", \n"

View File

@ -420,32 +420,56 @@ TypeSignature TypeChecker::operator()(abt::ABT& n, abt::FunctionCall& op, bool s
TypeSignature argType = op.nodes()[0].visit(*this, false);
auto lambda = op.nodes()[1].cast<abt::LambdaAbstraction>();
// A traverseF/traverseP invoked with the first argument that is not an array will just
// invoke the lambda expression on it, so we can remove it if we are assured that it
// cannot possibly contain an array.
if (!argType.containsAny(TypeSignature::kArrayType)) {
auto lambda = op.nodes()[1].cast<abt::LambdaAbstraction>();
// Define the lambda variable with the type of the 'bind' expression type.
bind(lambda->varName(), argType);
bind(lambda->varNames()[0], argType);
if (lambda->varNames().size() > 1) {
bind(lambda->varNames()[1], TypeSignature::kNumericType);
}
// Process the lambda knowing that its argument will be exactly the type we got from
// processing the first argument.
TypeSignature lambdaType = op.nodes()[1].visit(*this, false);
// The current binding must be the one where we defined the variable.
invariant(_bindings.back().contains(lambda->varName()));
_bindings.back().erase(lambda->varName());
invariant(_bindings.back().contains(lambda->varNames()[0]));
_bindings.back().erase(lambda->varNames()[0]);
if (lambda->varNames().size() == 1) {
swapAndUpdate(n,
abt::make<abt::Let>(
lambda->varNames()[0],
std::exchange(op.nodes()[0], abt::make<abt::Blackhole>()),
std::exchange(lambda->getBody(), abt::make<abt::Blackhole>())));
} else {
_bindings.back().erase(lambda->varNames()[1]);
swapAndUpdate(
n,
abt::make<abt::Let>(
lambda->varNames()[1],
abt::Constant::int64(-1),
abt::make<abt::Let>(
lambda->varNames()[0],
std::exchange(op.nodes()[0], abt::make<abt::Blackhole>()),
std::exchange(lambda->getBody(), abt::make<abt::Blackhole>()))));
}
swapAndUpdate(
n,
abt::make<abt::Let>(lambda->varName(),
std::exchange(op.nodes()[0], abt::make<abt::Blackhole>()),
std::exchange(lambda->getBody(), abt::make<abt::Blackhole>())));
return lambdaType.include(argType.intersect(TypeSignature::kNothingType));
}
// The first argument could be an array, so the lambda will be invoked on multiple array
// items of unknown type.
if (lambda->varNames().size() > 1) {
bind(lambda->varNames()[1], TypeSignature::kNumericType);
}
op.nodes()[1].visit(*this, false);
if (lambda->varNames().size() > 1) {
_bindings.back().erase(lambda->varNames()[1]);
}
// Nothing can be inferred about the return type of traverseF()/traverseP() in this case.
return TypeSignature::kAnyScalarType;
}

View File

@ -143,12 +143,18 @@ ValueLifetime::ValueType ValueLifetime::operator()(abt::ABT& n, abt::FunctionCal
auto lambda = op.nodes()[1].cast<abt::LambdaAbstraction>();
// Define the lambda variable with the type of the 'bind' expression type.
_bindings[lambda->varName()] = argType;
_bindings[lambda->varNames()[0]] = argType;
if (lambda->varNames().size() > 1) {
_bindings[lambda->varNames()[1]] = ValueType::GlobalValue;
}
// Process the lambda knowing that its argument will be exactly the type we got from
// processing the first argument.
ValueType lambdaType = op.nodes()[1].visit(*this);
_bindings.erase(lambda->varName());
_bindings.erase(lambda->varNames()[0]);
if (lambda->varNames().size() > 1) {
_bindings.erase(lambda->varNames()[1]);
}
// If the first argument is an array, the result is always a local value (array of cloned
// results for traverseP, a boolean value for traverseF). If it is not an array, then the

View File

@ -585,38 +585,40 @@ Vectorizer::Tree Vectorizer::operator()(const abt::ABT& n, const abt::FunctionCa
if (TypeSignature::kBlockType.isSubset(argument.typeSignature) &&
argument.sourceCell.has_value()) {
const abt::LambdaAbstraction* lambda = op.nodes()[1].cast<abt::LambdaAbstraction>();
// Reuse the variable name of the lambda so that we don't have to manipulate the code
// inside the lambda (and to avoid problems if the expression we are going to iterate
// over has side effects and the lambda references it multiple times, as replacing it
// directly in code would imply executing more than once). Don't propagate the reference
// to the cell slot, as we are going to fold the result and we don't want that the
// lambda does it too.
_variableTypes.insert_or_assign(lambda->varName(),
std::make_pair(argument.typeSignature, boost::none));
auto lambdaArg = lambda->getBody().visit(*this);
_variableTypes.erase(lambda->varName());
if (!lambdaArg.expr.has_value()) {
return lambdaArg;
if (lambda->varNames().size() == 1) {
// Reuse the variable name of the lambda so that we don't have to manipulate the
// code inside the lambda (and to avoid problems if the expression we are going to
// iterate over has side effects and the lambda references it multiple times, as
// replacing it directly in code would imply executing more than once). Don't
// propagate the reference to the cell slot, as we are going to fold the result and
// we don't want that the lambda does it too.
_variableTypes.insert_or_assign(
lambda->varNames()[0], std::make_pair(argument.typeSignature, boost::none));
auto lambdaArg = lambda->getBody().visit(*this);
_variableTypes.erase(lambda->varNames()[0]);
if (!lambdaArg.expr.has_value()) {
return lambdaArg;
}
// If the body of the lambda is just a scalar constant, create a block
// of the same size of the block argument filled with that value.
if (!TypeSignature::kBlockType.isSubset(lambdaArg.typeSignature)) {
lambdaArg.expr = makeABTFunction(
"valueBlockNewFill"_sd,
std::move(*lambdaArg.expr),
makeABTFunction("valueBlockSize"_sd, makeVariable(lambda->varNames()[0])));
lambdaArg.typeSignature =
TypeSignature::kBlockType.include(lambdaArg.typeSignature);
lambdaArg.sourceCell = boost::none;
}
return {makeLet(lambda->varNames()[0],
std::move(*argument.expr),
makeABTFunction("cellFoldValues_F"_sd,
std::move(*lambdaArg.expr),
makeVariable(*argument.sourceCell))),
TypeSignature::kBlockType.include(TypeSignature::kBooleanType)
.include(argument.typeSignature.intersect(TypeSignature::kNothingType)),
{}};
}
// If the body of the lambda is just a scalar constant, create a block
// of the same size of the block argument filled with that value.
if (!TypeSignature::kBlockType.isSubset(lambdaArg.typeSignature)) {
lambdaArg.expr = makeABTFunction(
"valueBlockNewFill"_sd,
std::move(*lambdaArg.expr),
makeABTFunction("valueBlockSize"_sd, makeVariable(lambda->varName())));
lambdaArg.typeSignature =
TypeSignature::kBlockType.include(lambdaArg.typeSignature);
lambdaArg.sourceCell = boost::none;
}
return {makeLet(lambda->varName(),
std::move(*argument.expr),
makeABTFunction("cellFoldValues_F"_sd,
std::move(*lambdaArg.expr),
makeVariable(*argument.sourceCell))),
TypeSignature::kBlockType.include(TypeSignature::kBooleanType)
.include(argument.typeSignature.intersect(TypeSignature::kNothingType)),
{}};
}
}

View File

@ -3,14 +3,14 @@
traverseF(s1, lambda(l10.0) { (l10.0 == 3) }, Nothing)
-- COMPILED EXPRESSION:
[0x0000-0x002b] stackSize: 1, maxStackSize: 1
[0x0000-0x002c] stackSize: 1, maxStackSize: 1
0x0000: jmp(target: 0x001c);
0x0005: allocStack(size:1);
0x000a: pushConstVal(value: 3);
0x0014: eq(popLhs: 0, moveFromLhs: 0, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x001b: ret();
0x001c: pushAccessVal(accessor: <accessor>);
0x0025: traverseFImm(k: False, target: 0x0005);
0x0025: traverseFImm(providePosition: False, k: False, target: 0x0005);
-- EXECUTE VARIATION:

View File

@ -0,0 +1,25 @@
# Golden test output of SBELambdaTest/TraverseF_OpEqFirstArrayItem
-- INPUT EXPRESSION:
traverseF(s1, lambda(l10.0 l10.1) { ((l10.1 == 0ll) && (l10.0 == 3)) }, Nothing)
-- COMPILED EXPRESSION:
[0x0000-0x0056] stackSize: 1, maxStackSize: 1
0x0000: jmp(target: 0x0046);
0x0005: allocStack(size:1);
0x000a: pushConstVal(value: 0ll);
0x0014: eq(popLhs: 0, moveFromLhs: 0, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x001b: jmpNothing(target: 0x0045);
0x0020: jmpFalse(target: 0x003b);
0x0025: pushConstVal(value: 3);
0x002f: eq(popLhs: 0, moveFromLhs: 0, offsetLhs: 1, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x0036: jmp(target: 0x0045);
0x003b: pushConstVal(value: false);
0x0045: ret();
0x0046: pushAccessVal(accessor: <accessor>);
0x004f: traverseFImm(providePosition: True, k: False, target: 0x0005);
-- EXECUTE VARIATION:
SLOTS: [1: [1, 2, 3, 4]]
RESULT: false

View File

@ -13,7 +13,7 @@
-- COMPILED EXPRESSION:
[0x0000-0x0063] stackSize: 1, maxStackSize: 4
[0x0000-0x0064] stackSize: 1, maxStackSize: 4
0x0000: pushAccessVal(accessor: <accessor>);
0x0009: pushConstVal(value: 10);
0x0013: pushConstVal(value: 20);
@ -23,18 +23,18 @@
0x0031: eq(popLhs: 0, moveFromLhs: 0, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x0038: ret();
0x0039: pushLocalVal(arg: 2);
0x003e: traverseFImm(k: False, target: 0x0022);
0x0044: jmpNothing(target: 0x005d);
0x0049: jmpTrue(target: 0x0058);
0x004e: pushLocalVal(arg: 0);
0x0053: jmp(target: 0x005d);
0x0058: pushLocalVal(arg: 1);
0x005d: swap();
0x005e: pop();
0x005f: swap();
0x0060: pop();
0x0061: swap();
0x0062: pop();
0x003e: traverseFImm(providePosition: False, k: False, target: 0x0022);
0x0045: jmpNothing(target: 0x005e);
0x004a: jmpTrue(target: 0x0059);
0x004f: pushLocalVal(arg: 0);
0x0054: jmp(target: 0x005e);
0x0059: pushLocalVal(arg: 1);
0x005e: swap();
0x005f: pop();
0x0060: swap();
0x0061: pop();
0x0062: swap();
0x0063: pop();
-- EXECUTE VARIATION:

View File

@ -3,14 +3,14 @@
traverseP(s1, lambda(l10.0) { (l10.0 + 1) }, Nothing)
-- COMPILED EXPRESSION:
[0x0000-0x002b] stackSize: 1, maxStackSize: 1
[0x0000-0x002c] stackSize: 1, maxStackSize: 1
0x0000: jmp(target: 0x001c);
0x0005: allocStack(size:1);
0x000a: pushConstVal(value: 1);
0x0014: add(popLhs: 0, moveFromLhs: 0, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x001b: ret();
0x001c: pushAccessVal(accessor: <accessor>);
0x0025: traversePImm(k: Nothing, target: 0x0005);
0x0025: traversePImm(providePosition: False, k: Nothing, target: 0x0005);
-- EXECUTE VARIATION:

View File

@ -0,0 +1,29 @@
# Golden test output of SBELambdaTest/TraverseP_AddOneToFirstArrayItem
-- INPUT EXPRESSION:
traverseP(s1, lambda(l10.0 l10.1) {
if (l10.1 == 0ll)
then (l10.0 + 1)
else l10.0
}, Nothing)
-- COMPILED EXPRESSION:
[0x0000-0x0051] stackSize: 1, maxStackSize: 1
0x0000: jmp(target: 0x0041);
0x0005: allocStack(size:1);
0x000a: pushConstVal(value: 0ll);
0x0014: eq(popLhs: 0, moveFromLhs: 0, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x001b: jmpNothing(target: 0x0040);
0x0020: jmpTrue(target: 0x002f);
0x0025: pushLocalVal(arg: 1);
0x002a: jmp(target: 0x0040);
0x002f: pushConstVal(value: 1);
0x0039: add(popLhs: 0, moveFromLhs: 0, offsetLhs: 1, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x0040: ret();
0x0041: pushAccessVal(accessor: <accessor>);
0x004a: traversePImm(providePosition: True, k: Nothing, target: 0x0005);
-- EXECUTE VARIATION:
SLOTS: [1: [1, 2, 3]]
RESULT: [2, 2, 3]

View File

@ -1,12 +1,12 @@
# Golden test output of SBEVM/CodeFragmentPrintStable
[0x0000-0x0048] stackSize: -1, maxStackSize: 0
[0x0000-0x004b] stackSize: -1, maxStackSize: 0
0x0000: fillEmptyImm(k: Null);
0x0002: fillEmptyImm(k: False);
0x0004: fillEmptyImm(k: True);
0x0006: traversePImm(k: Nothing, target: 0x00aa);
0x000c: traversePImm(k: 1, target: 0x00aa);
0x0012: traverseFImm(k: True, target: 0x00bb);
0x0018: getFieldImm(popParam: 1, moveFromParam: 1, offsetParam: 0, value: "Hello world!");
0x0027: add(popLhs: 1, moveFromLhs: 1, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x002a: dateTruncImm(unit: 4, binSize: 1, timezone: TimeZone(name=America/New_York), startOfWeek: 1);
0x0006: traversePImm(providePosition: False, k: Nothing, target: 0x00aa);
0x000d: traversePImm(providePosition: False, k: 1, target: 0x00aa);
0x0014: traverseFImm(providePosition: False, k: True, target: 0x00bb);
0x001b: getFieldImm(popParam: 1, moveFromParam: 1, offsetParam: 0, value: "Hello world!");
0x002a: add(popLhs: 1, moveFromLhs: 1, offsetLhs: 0, popRhs: 1, moveFromRhs: 1, offsetRhs: 0);
0x002d: dateTruncImm(unit: 4, binSize: 1, timezone: TimeZone(name=America/New_York), startOfWeek: 1);