diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py index cc3dbe0ef36..efd2eb2ea4f 100644 --- a/buildscripts/idl/idl/ast.py +++ b/buildscripts/idl/idl/ast.py @@ -272,6 +272,9 @@ class Field(common.SourceLocation): # Properties specific to fields inlined from chained_structs self.chained_struct_field = None # type: Field + # If this field is a nested chained struct, add the parent field which this field is chained from. + self.nested_chained_parent = None # type: Field + # Internal fields - not generated by parser self.serialize_op_msg_request_only = False # type: bool self.constructed = False # type: bool diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py index ee3c4df4e88..d82cd10123c 100644 --- a/buildscripts/idl/idl/binder.py +++ b/buildscripts/idl/idl/binder.py @@ -1269,9 +1269,8 @@ def _bind_field(ctxt, parsed_spec, field): ctxt.add_must_be_query_shape_component(ast_field, ast_field.type.name, ast_field.name) return ast_field - -def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): - # type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct) -> None +def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct, nested_chained_parent=None): + # type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct, ast.Field) -> None """Bind the specified chained struct.""" syntax_symbol = parsed_spec.symbols.resolve_type_from_name( ctxt, ast_struct, chained_struct.name, chained_struct.name @@ -1292,10 +1291,6 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): ast_struct, ast_struct.name, chained_struct.name ) - if struct.chained_structs: - ctxt.add_chained_nested_struct_no_nested_error( - ast_struct, ast_struct.name, chained_struct.name - ) # Configure a field for the chained struct. ast_chained_field = ast.Field(ast_struct.file_name, ast_struct.line, ast_struct.column) @@ -1305,6 +1300,13 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): ast_chained_field.description = struct.description ast_chained_field.chained = True + if struct.chained_structs: + for nested_chained_struct in struct.chained_structs or []: + _bind_chained_struct(ctxt, parsed_spec, ast_struct, nested_chained_struct, ast_chained_field) + + if nested_chained_parent: + ast_chained_field.nested_chained_parent = nested_chained_parent + if not _is_duplicate_field(ctxt, chained_struct.name, ast_struct.fields, ast_chained_field): ast_struct.fields.append(ast_chained_field) else: diff --git a/buildscripts/idl/idl/errors.py b/buildscripts/idl/idl/errors.py index 5c3da9168ce..3e112369b8d 100644 --- a/buildscripts/idl/idl/errors.py +++ b/buildscripts/idl/idl/errors.py @@ -71,7 +71,6 @@ ERROR_ID_BAD_BINDATA_DEFAULT = "ID0026" ERROR_ID_CHAINED_DUPLICATE_FIELD = "ID0029" ERROR_ID_CHAINED_STRUCT_NOT_FOUND = "ID0031" ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT = "ID0032" -ERROR_ID_CHAINED_NO_NESTED_CHAINED = "ID0033" ERROR_ID_BAD_EMPTY_ENUM = "ID0034" ERROR_ID_NO_ARRAY_ENUM = "ID0035" ERROR_ID_ENUM_BAD_TYPE = "ID0036" @@ -666,19 +665,6 @@ class ParserContext(object): % (nested_struct_name, struct_name), ) - def add_chained_nested_struct_no_nested_error(self, location, struct_name, chained_name): - # type: (common.SourceLocation, str, str) -> None - """Add an error about struct's chaining being a struct with nested chaining.""" - self._add_error( - location, - ERROR_ID_CHAINED_NO_NESTED_CHAINED, - ( - "Struct '%s' is not allowed to nest struct '%s' since it has chained" - + " structs and/or types." - ) - % (struct_name, chained_name), - ) - def add_empty_enum_error(self, node, name): # type: (yaml.nodes.Node, str) -> None """Add an error about an enum without values.""" diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index c07a11d3031..18c30cef9b0 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -701,12 +701,15 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Generate a getter that disables xvalue for view types (i.e. StringData), constructed # optional types, and non-primitive types. if field.chained_struct_field: - member_name = _get_field_member_name(field.chained_struct_field) + chained_struct_getter = _get_field_member_getter_name(field.chained_struct_field) self._writer.write_line( - f"{const_type}{param_type} {method_name}() const {{ return {member_name}.{method_name}(); }}" + f"{const_type}{param_type} {method_name}() const {{ return {chained_struct_getter}().{method_name}(); }}" ) elif field.type.is_struct: + if field.nested_chained_parent: + body = f"return {_get_field_member_getter_name(field.nested_chained_parent)}().{_get_field_member_getter_name(field)}();" + # Support mutable accessors self._writer.write_line(f"const {param_type} {method_name}() const {{ {body} }}") @@ -750,11 +753,12 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Generate the setter for instances of the "getter/setter type", which may not be the same # as the storage type. if field.chained_struct_field: - body = "{}.{}(std::move(value));".format( - _get_field_member_name(field.chained_struct_field), memfn - ) + body = f"{_get_field_member_getter_name(field.chained_struct_field)}().{memfn}(std::move(value));" else: - body = cpp_type_info.get_setter_body(_get_field_member_name(field), validator) + if field.nested_chained_parent: + body = f"{_get_field_member_getter_name(field.nested_chained_parent)}().{memfn}(std::move(value));" + else: + body = cpp_type_info.get_setter_body(_get_field_member_name(field), validator) set_has = _gen_mark_present(field.cpp_name) if is_serial else "" with self._block(f"void {memfn}({setter_type} value) {{", "}"): @@ -1427,7 +1431,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Write member variables for field in struct.fields: - if not field.ignore and not field.chained_struct_field: + if not field.ignore and not field.chained_struct_field and not field.nested_chained_parent: if not (field.type and field.type.internal_only): self.gen_member(field) @@ -1888,15 +1892,14 @@ class _CppSourceFileWriter(_CppFileWriterBase): def validate_and_assign_or_uassert(field, expression): # type: (ast.Field, str) -> None """Perform field value validation post-assignment.""" - field_name = _get_field_member_name(field) if field.validator is None: - self._writer.write_line("%s = %s;" % (field_name, expression)) + self._writer.write_line("%s = %s;" % (_get_field_member_name(field), expression)) return with self._block("{", "}"): self._writer.write_line("auto value = %s;" % (expression)) self._writer.write_line("%s(value);" % (_get_field_member_validator_name(field))) - self._writer.write_line("%s = std::move(value);" % (field_name)) + self._writer.write_line("%s = std::move(value);" % (_get_field_member_name(field))) if field.chained: # Do not generate a predicate check since we always call these deserializers. @@ -1912,7 +1915,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): expression = "%s(%s)" % (method_name, bson_object) self._gen_usage_check(field, bson_element, field_usage_check) - validate_and_assign_or_uassert(field, expression) + + if not field.nested_chained_parent: + validate_and_assign_or_uassert(field, expression) else: predicate = None @@ -1937,9 +1942,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): # No need for explicit validation as setter will throw for us. self._writer.write_line( - "%s.%s(%s);" + "%s().%s(%s);" % ( - _get_field_member_name(field.chained_struct_field), + _get_field_member_getter_name(field.chained_struct_field), _get_field_member_setter_name(field), object_value, ) @@ -2148,7 +2153,6 @@ class _CppSourceFileWriter(_CppFileWriterBase): required_constructor = struct_type_info.get_required_constructor_method() if len(required_constructor.args) != len(constructor.args): - # print(struct.name + ": "+ str(required_constructor.args)) self._gen_constructor(struct, required_constructor, False) def gen_field_list_entry_lookup_methods_struct(self, struct): @@ -2907,6 +2911,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.chained_struct_field: continue + + if field.nested_chained_parent: + continue # The $db injected field should only be inject when serializing to OpMsgRequest. In the # BSON case, it will be injected in the generic command layer. diff --git a/buildscripts/idl/tests/test_binder.py b/buildscripts/idl/tests/test_binder.py index 37ac88fe0f5..327ad62aeac 100644 --- a/buildscripts/idl/tests/test_binder.py +++ b/buildscripts/idl/tests/test_binder.py @@ -1572,6 +1572,30 @@ class TestBinder(testcase.IDLTestcase): ) ) + # Chained struct with nested chained struct + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" + bar1: + description: foo + strict: false + chained_structs: + chained: alias + + foobar: + description: foo + strict: false + chained_structs: + bar1: alias + fields: + f1: string + + """), + ) + ) + def test_chained_struct_negative(self): # type: () -> None """Negative parser chaining test cases.""" @@ -1701,31 +1725,6 @@ class TestBinder(testcase.IDLTestcase): idl.errors.ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT, ) - # Chained struct with nested chained struct - self.assert_bind_fail( - test_preamble - + indent_text( - 1, - textwrap.dedent(""" - bar1: - description: foo - strict: false - chained_structs: - chained: alias - - foobar: - description: foo - strict: false - chained_structs: - bar1: alias - fields: - f1: string - - """), - ), - idl.errors.ERROR_ID_CHAINED_NO_NESTED_CHAINED, - ) - def test_enum_positive(self): # type: () -> None """Positive enum test cases.""" diff --git a/buildscripts/idl/tests/test_generator.py b/buildscripts/idl/tests/test_generator.py index fc2cda7c76c..f4f71348cca 100644 --- a/buildscripts/idl/tests/test_generator.py +++ b/buildscripts/idl/tests/test_generator.py @@ -1012,10 +1012,97 @@ class TestGenerator(testcase.IDLTestcase): header, [ "mongo::VariantStruct _variantStruct;", - "const std::variant& getField1() const { return _variantStruct.getField1(); }", + "const std::variant& getField1() const { return getVariantStruct().getField1(); }", ], ) + def test_nested_chained_structs(self) -> None: + """Test that nested chained structs generate the right setters and getters.""" + header, source = self.assert_generate_with_basic_types( + dedent( + """ + structs: + NestedChainedBase: + description: "Base struct for testing nested chains" + fields: + base_field: int + NestedChainedBottom: + description: "Bottom struct for nested chaining" + chained_structs: + NestedChainedBase: NestedChainedBase + fields: + bottom_field: int + NestedChainedMiddle: + description: "Middle struct for nested chaining" + chained_structs: + NestedChainedBottom: NestedChainedBottom + fields: + middle_field: string + NestedChainedTop: + description: "Top struct for nested chaining" + chained_structs: + NestedChainedMiddle: NestedChainedMiddle + fields: + top_field: bool + """ + ) + ) + self.assertStringsInFile(header, [ + "mongo::NestedChainedBase& getNestedChainedBase() { return getNestedChainedBottom().getNestedChainedBase();", + "void setNestedChainedBase(mongo::NestedChainedBase value) {\n getNestedChainedBottom().setNestedChainedBase(std::move(value));", + "void setBase_field(std::int32_t value) {\n getNestedChainedBase().setBase_field(std::move(value));", + "mongo::NestedChainedBottom& getNestedChainedBottom() { return getNestedChainedMiddle().getNestedChainedBottom();", + "void setNestedChainedBottom(mongo::NestedChainedBottom value) {\n getNestedChainedMiddle().setNestedChainedBottom(std::move(value));", + ]) + self.assertStringsInFile(source, ["getNestedChainedBase().setBase_field(element._numberInt());", + "getNestedChainedBottom().setBottom_field(element._numberInt());", + "getNestedChainedMiddle().setMiddle_field(element.str());", + "_top_field = element.boolean();", + ]) + + header, source = self.assert_generate_with_basic_types( + dedent( + """ + structs: + NestedChainedNoInlineBase: + description: "Base struct for testing nested chains without inline" + strict: false + fields: + base_field: int + NestedChainedNoInlineBottom: + description: "Top struct for nested chaining without inline" + inline_chained_structs: false + strict: false + chained_structs: + NestedChainedNoInlineBase: NestedChainedNoInlineBase + fields: + bottom_field: string + NestedChainedNoInlineTop: + description: "Top struct for nested chaining without inline" + strict: false + inline_chained_structs: false + chained_structs: + NestedChainedNoInlineBottom: NestedChainedNoInlineBottom + fields: + top_field: bool + """ + ) + ) + self.assertStringsInFile( + header, + [ + "mongo::NestedChainedNoInlineBase& getNestedChainedNoInlineBase() { return getNestedChainedNoInlineBottom().getNestedChainedNoInlineBase();", + "void setNestedChainedNoInlineBase(mongo::NestedChainedNoInlineBase value) {\n getNestedChainedNoInlineBottom().setNestedChainedNoInlineBase(std::move(value));", + "mongo::NestedChainedNoInlineBottom& getNestedChainedNoInlineBottom() { return _nestedChainedNoInlineBottom;", + ], + ) + + # Inline setters/getters not generated. + self.assertStringNotInFile( + header, + "void setBase_field(std::int32_t value) {\n getNestedChainedNoInlineBase().setBase_field(std::move(value));", + ) + def test_callback_validators(self) -> None: """Test generation of validators with the 'callback:' tag.""" _, source = self.assert_generate_with_basic_types( diff --git a/docs/idl.md b/docs/idl.md index ec7c530e3e0..d45710e03fc 100644 --- a/docs/idl.md +++ b/docs/idl.md @@ -716,6 +716,9 @@ the struct including them. This means that instead of users have to call `obj.getChainedStruct.getCommonField()`, they can call `obj.getCommonField()` instead. Field storage is not affected as this option is only syntactic sugar. +There can be multiple levels of chained structs. Be wary of circular chaining when choosing to use +multi level chained structs. + ### Struct Reference - `description` - string - A comment to add to the generated C++ diff --git a/src/mongo/idl/idl_test.cpp b/src/mongo/idl/idl_test.cpp index 7c49696528f..d3a4b4e5455 100644 --- a/src/mongo/idl/idl_test.cpp +++ b/src/mongo/idl/idl_test.cpp @@ -5357,6 +5357,53 @@ TEST(IDLTrie, TestPrefixes) { ASSERT_TRUE(TestTrieArgs::hasField("swimmed")); } +TEST(IDLNestedChaining, Parse) { + auto testDoc = BSON("base_field" << 42 << "bottom_field" << 40 << "middle_field" << "hello" + << "top_field" << true); + auto topStruct = NestedChainedTop::parse(testDoc); + + ASSERT_EQUALS(topStruct.getBase_field(), 42); + ASSERT_EQUALS(topStruct.getBottom_field(), 40); + ASSERT_EQUALS(topStruct.getMiddle_field(), "hello"); + ASSERT_EQUALS(topStruct.getTop_field(), true); + + // Test various methods generated from `inlined_chained_structs: true`. + ASSERT_EQUALS(topStruct.getNestedChainedMiddle().getNestedChainedBase().getBase_field(), 42); + ASSERT_EQUALS(topStruct.getNestedChainedBottom().getBottom_field(), 40); + + BSONObj serialized = topStruct.toBSON(); + ASSERT_BSONOBJ_EQ(serialized, testDoc); +} + +TEST(IDLNestedChaining, Initialize) { + auto testDoc = BSON("base_field" << 42 << "bottom_field" << 40 << "middle_field" << "hello" + << "top_field" << true); + + NestedChainedTop newStruct; + newStruct.getNestedChainedMiddle() + .getNestedChainedBottom() + .getNestedChainedBase() + .setBase_field(42); + newStruct.setBottom_field(40); + newStruct.getNestedChainedMiddle().setMiddle_field("hello"); + newStruct.setTop_field(true); + + BSONObj newSerialized = newStruct.toBSON(); + ASSERT_BSONOBJ_EQ(newSerialized, testDoc); +} + +TEST(IDLNestedChaining, NoInline) { + auto testDoc = BSON("base_field" << 42 << "bottom_field" << "hello" << "top_field" << true); + auto topStruct = NestedChainedNoInlineTop::parse(testDoc); + + ASSERT_EQUALS(topStruct.getNestedChainedNoInlineBottom().getBottom_field(), "hello"); + ASSERT_EQUALS(topStruct.getNestedChainedNoInlineBase().getBase_field(), 42); + ASSERT_EQUALS(topStruct.getTop_field(), true); + + BSONObj serialized = topStruct.toBSON(); + ASSERT_BSONOBJ_EQ(serialized, testDoc); +} + template void testBasicTypeSerialization(StringData fieldName, ParseValueType value) { // Positive: parse correct type. diff --git a/src/mongo/idl/unittest.idl b/src/mongo/idl/unittest.idl index 9417683bd61..47e6abb2fab 100644 --- a/src/mongo/idl/unittest.idl +++ b/src/mongo/idl/unittest.idl @@ -462,6 +462,61 @@ structs: fields: value: safeInt64 + ################################################################################################## + # + # Test structs with tested chaining + # + ################################################################################################## + NestedChainedBase: + description: "Base struct for testing nested chains" + fields: + base_field: int + + NestedChainedBottom: + description: "Bottom struct for nested chaining" + chained_structs: + NestedChainedBase: NestedChainedBase + fields: + bottom_field: int + + NestedChainedMiddle: + description: "Middle struct for nested chaining" + chained_structs: + NestedChainedBottom: NestedChainedBottom + fields: + middle_field: string + + NestedChainedTop: + description: "Top struct for nested chaining" + chained_structs: + NestedChainedMiddle: NestedChainedMiddle + fields: + top_field: bool + + NestedChainedNoInlineBase: + description: "Base struct for testing nested chains without inline" + strict: false + fields: + base_field: int + + NestedChainedNoInlineBottom: + description: "Top struct for nested chaining without inline" + inline_chained_structs: false + strict: false + chained_structs: + NestedChainedNoInlineBase: NestedChainedNoInlineBase + fields: + bottom_field: string + + NestedChainedNoInlineTop: + description: "Top struct for nested chaining without inline" + strict: false + inline_chained_structs: false + chained_structs: + NestedChainedNoInlineBottom: NestedChainedNoInlineBottom + fields: + top_field: bool + ################################################################################################## # # Test fields with default values