From 9813c7766924d67ff901b4c877e179e72aa996d0 Mon Sep 17 00:00:00 2001
From: Max Verbinnen <64088654+Max-Verbinnen@users.noreply.github.com>
Date: Wed, 10 Sep 2025 15:36:02 +0100
Subject: [PATCH] SERVER-107790 Design workload for SORT_MERGE (#41084)

GitOrigin-RevId: ab90fe3618437be7a0f4861c8fb7e5647f8c68e0
---
 .../cost_model/calibration_settings.py        | 63 ++++++++++++++++++-
 buildscripts/cost_model/cost_estimator.py     |  1 +
 .../cost_model/execution_tree_classic.py      | 15 ++++-
 .../parameters_extractor_classic.py           |  2 +
 .../cost_model/qsn_costing_parameters.py      |  1 +
 buildscripts/cost_model/start.py              | 22 +++++++
 6 files changed, 98 insertions(+), 6 deletions(-)

diff --git a/buildscripts/cost_model/calibration_settings.py b/buildscripts/cost_model/calibration_settings.py
index adefdf01ffc..337b34db221 100644
--- a/buildscripts/cost_model/calibration_settings.py
+++ b/buildscripts/cost_model/calibration_settings.py
@@ -258,6 +258,40 @@ def create_coll_scan_collection_template(
     return template
 
 
+def create_merge_sort_collection_template(
+    name: str, cardinalities: list[int], num_merge_fields: int = 10
+) -> config.CollectionTemplate:
+    # Generate fields "a", "b", ..., "j" (if num_merge_fields is 10).
+    field_names = [chr(ord("a") + i) for i in range(num_merge_fields)]
+    fields = [
+        config.FieldTemplate(
+            name=field_name,
+            data_type=config.DataType.INTEGER,
+            distribution=RandomDistribution.uniform(
+                RangeGenerator(DataType.INTEGER, 1, num_merge_fields + 1)
+            ),
+            indexed=False,
+        )
+        for field_name in field_names
+    ]
+    fields.append(
+        config.FieldTemplate(
+            name="sort_field",
+            data_type=config.DataType.STRING,
+            distribution=random_strings_distr(10, 1000),
+            indexed=False,
+        )
+    )
+    compound_indexes = [{field_name: 1, "sort_field": 1} for field_name in field_names]
+
+    return config.CollectionTemplate(
+        name=name,
+        fields=fields,
+        compound_indexes=compound_indexes,
+        cardinalities=cardinalities,
+    )
+
+
 collection_caridinalities = list(range(10000, 50001, 10000))
 
 c_int_05 = config.CollectionTemplate(
@@ -321,13 +355,25 @@ sort_collections = create_coll_scan_collection_template(
     cardinalities=[5, 10, 50, 75, 100, 150, 300, 400, 500, 750, 1000],
     payload_size=10,
 )
+merge_sort_collections = create_merge_sort_collection_template(
+    "merge_sort",
+    cardinalities=[5, 10, 50, 75, 100, 150, 300, 400, 500, 750, 1000],
+    num_merge_fields=10,
+)
 
 # Data Generator settings
 data_generator = config.DataGeneratorConfig(
     enabled=True,
     create_indexes=True,
     batch_size=10000,
-    collection_templates=[index_scan, coll_scan, sort_collections, c_int_05, c_arr_01],
+    collection_templates=[
+        index_scan,
+        coll_scan,
+        sort_collections,
+        merge_sort_collections,
+        c_int_05,
+        c_arr_01,
+    ],
     write_mode=config.WriteMode.REPLACE,
     collection_name_with_card=True,
 )
@@ -373,8 +419,19 @@ qsn_nodes = [
     config.QsNodeCalibrationConfig(type="AND_HASH"),
     config.QsNodeCalibrationConfig(type="AND_SORTED"),
     config.QsNodeCalibrationConfig(type="OR"),
-    config.QsNodeCalibrationConfig(type="MERGE_SORT"),
-    config.QsNodeCalibrationConfig(type="SORT_MERGE"),
+    config.QsNodeCalibrationConfig(
+        type="SORT_MERGE",
+        # Note: n_returned = n_processed - (number of duplicates dropped).
+        variables_override=lambda df: pd.concat(
+            [
+                (df["n_returned"] * np.log2(df["n_input_stages"])).rename(
+                    "n_returned * log2(n_input_stages)"
+                ),
+                df["n_processed"],
+            ],
+            axis=1,
+        ),
+    ),
     config.QsNodeCalibrationConfig(
         name="SORT_DEFAULT",
         type="SORT",
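The override above swaps in two regressors for SORT_MERGE: n_returned * log2(n_input_stages), which presumably captures the per-returned-document cost of popping a heap over k input streams in a k-way merge, and n_processed, which also counts duplicates that were merged and then dropped. A minimal sketch of the feature matrix the lambda yields, on made-up stats rows (only pandas and numpy are assumed; the numbers are illustrative):

    import numpy as np
    import pandas as pd

    # Hypothetical calibration rows for two SORT_MERGE executions.
    df = pd.DataFrame(
        {
            "n_returned": [90, 450],
            "n_processed": [100, 500],
            "n_input_stages": [2, 8],
        }
    )

    # Same expression as the variables_override lambda in the patch.
    features = pd.concat(
        [
            (df["n_returned"] * np.log2(df["n_input_stages"])).rename(
                "n_returned * log2(n_input_stages)"
            ),
            df["n_processed"],
        ],
        axis=1,
    )
    print(features)
    # The calibrator can then fit execution_time against these two columns,
    # i.e. roughly cost = w1 * n_returned * log2(k) + w2 * n_processed.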
diff --git a/buildscripts/cost_model/cost_estimator.py b/buildscripts/cost_model/cost_estimator.py
index 048e2daf724..e0f5407bc70 100644
--- a/buildscripts/cost_model/cost_estimator.py
+++ b/buildscripts/cost_model/cost_estimator.py
@@ -45,6 +45,7 @@ class ExecutionStats:
     execution_time: int
     n_returned: int
     n_processed: int
+    n_input_stages: int
     seeks: Optional[int]
 
 
diff --git a/buildscripts/cost_model/execution_tree_classic.py b/buildscripts/cost_model/execution_tree_classic.py
index 5dd0bac51c3..aef2fdb6dc2 100644
--- a/buildscripts/cost_model/execution_tree_classic.py
+++ b/buildscripts/cost_model/execution_tree_classic.py
@@ -42,6 +42,7 @@ class Node:
     execution_time_nanoseconds: int
     n_returned: int
     n_processed: int
+    n_input_stages: int
     seeks: Optional[int]
     children: list[Node]
 
@@ -54,7 +55,7 @@ class Node:
     def print(self, level=0):
         """Pretty print the execution tree"""
         print(
-            f'{"| " * level}{self.stage}, totalExecutionTime: {self.execution_time_nanoseconds:,}ns, seeks: {self.seeks}, nReturned: {self.n_returned}, nProcessed: {self.n_processed}'
+            f'{"| " * level}{self.stage}, totalExecutionTime: {self.execution_time_nanoseconds:,}ns, seeks: {self.seeks}, nReturned: {self.n_returned}, nProcessed: {self.n_processed}, nInputStages: {self.n_input_stages}'
         )
         for child in self.children:
             child.print(level + 1)
@@ -76,7 +77,6 @@ def process_stage(stage: dict[str, Any]) -> Node:
         "AND_HASH": process_intersection,
         "AND_SORTED": process_intersection,
         "OR": process_or,
-        "MERGE_SORT": process_mergesort,
         "SORT_MERGE": process_mergesort,
         "SORT": process_sort,
         "LIMIT": process_passthrough,
@@ -134,7 +134,13 @@ def process_intersection(stage: dict[str, Any]) -> Node:
 
 def process_mergesort(stage: dict[str, Any]) -> Node:
     children = [process_stage(child) for child in stage["inputStages"]]
-    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
+    # The number of processed documents is not just `stage["nReturned"]`, because that
+    # does not include any duplicate documents that had to be processed and then dropped.
+    return Node(
+        **get_common_fields(stage),
+        n_processed=sum(child.n_returned for child in children),
+        children=children,
+    )
 
 
 def process_skip(stage: dict[str, Any]) -> Node:
@@ -151,5 +157,8 @@ def get_common_fields(json_stage: dict[str, Any]) -> dict[str, Any]:
         "stage": json_stage["stage"],
         "execution_time_nanoseconds": json_stage["executionTimeNanos"],
         "n_returned": json_stage["nReturned"],
+        "n_input_stages": 1
+        if "inputStage" in json_stage
+        else len(json_stage.get("inputStages", [])),
         "seeks": json_stage.get("seeks"),
     }
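With this accounting, a SORT_MERGE node's n_processed can exceed its own nReturned whenever the merge deduplicates. A sketch of what the two helpers above compute on a trimmed, hypothetical executionStats fragment (real explain output carries many more fields):

    # Two-way merge: each index scan returns five keys, and one document
    # matches both $or branches.
    stage = {
        "stage": "SORT_MERGE",
        "executionTimeNanos": 52_000,
        "nReturned": 9,
        "inputStages": [
            {"stage": "IXSCAN", "executionTimeNanos": 20_000, "nReturned": 5, "seeks": 1},
            {"stage": "IXSCAN", "executionTimeNanos": 21_000, "nReturned": 5, "seeks": 1},
        ],
    }

    # get_common_fields(stage)["n_input_stages"] == 2: the stage carries an
    # "inputStages" list rather than a single "inputStage".
    # process_mergesort(stage).n_processed == 5 + 5 == 10, while nReturned
    # stays 9: the duplicate was processed by the merge but returned once.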
diff --git a/buildscripts/cost_model/parameters_extractor_classic.py b/buildscripts/cost_model/parameters_extractor_classic.py
index 61ef5bb2346..4e9d75478d0 100644
--- a/buildscripts/cost_model/parameters_extractor_classic.py
+++ b/buildscripts/cost_model/parameters_extractor_classic.py
@@ -96,6 +96,8 @@ def get_execution_stats(
             execution_time=enode.get_execution_time(),
             n_returned=enode.n_returned,
             n_processed=enode.n_processed,
+            # This will be 0 if there are no input stages.
+            n_input_stages=enode.n_input_stages,
             # Seeks will be None for any node but IXSCAN.
             seeks=enode.seeks,
         )
diff --git a/buildscripts/cost_model/qsn_costing_parameters.py b/buildscripts/cost_model/qsn_costing_parameters.py
index 87ef9134449..be7c7c4af71 100644
--- a/buildscripts/cost_model/qsn_costing_parameters.py
+++ b/buildscripts/cost_model/qsn_costing_parameters.py
@@ -99,6 +99,7 @@ class ParametersBuilderClassic:
                 "stage",
                 "execution_time",
                 "n_processed",
+                "n_input_stages",
                 "seeks",
                 "note",
                 "keys_length_in_bytes",
diff --git a/buildscripts/cost_model/start.py b/buildscripts/cost_model/start.py
index cd13371248b..f8148f1ab96 100644
--- a/buildscripts/cost_model/start.py
+++ b/buildscripts/cost_model/start.py
@@ -250,6 +250,27 @@ async def execute_sorts(database: DatabaseInstance, collections: Sequence[Collec
     )
 
 
+async def execute_merge_sorts(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
+    collections = [c for c in collections if c.name.startswith("merge_sort")]
+    fields = collections[0].fields
+
+    requests = []
+    for num_merge_inputs in range(2, len(fields)):
+        requests.append(
+            Query(
+                find_cmd={
+                    "filter": {"$or": [{f.name: 1} for f in fields[:num_merge_inputs]]},
+                    "sort": {"sort_field": 1},
+                },
+                note="SORT_MERGE",
+            )
+        )
+
+    await workload_execution.execute(
+        database, main_config.workload_execution, collections, requests
+    )
+
+
 async def main():
     """Entry point function."""
     script_directory = os.path.abspath(os.path.dirname(__file__))
@@ -270,6 +291,7 @@ async def main():
         execute_limits,
         execute_skips,
         execute_sorts,
+        execute_merge_sorts,
     ]
     for execute_query in execution_query_functions:
         await execute_query(database, generator.collection_infos)
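The shape of this workload is what makes SORT_MERGE the winning plan: each $or branch is covered by a compound index {field: 1, "sort_field": 1}, so every branch can be read back already ordered by sort_field and the branches only need to be merged, not re-sorted; sweeping the branch count from 2 to 10 sweeps n_input_stages, the very term the new cost formula calibrates. A hypothetical way to verify the plan shape by hand (assumes pymongo, a running mongod with the generated data, and guesses at the database and collection names, which the calibration config actually controls):

    from pymongo import MongoClient

    # Illustrative names only; collection_name_with_card=True appends the
    # cardinality to the template name, e.g. "merge_sort_500".
    coll = MongoClient("mongodb://localhost:27017")["calibration"]["merge_sort_500"]

    # Same shape as the workload's find_cmd: one $or branch per indexed
    # field, sorted on the shared suffix of the compound indexes.
    cursor = coll.find(
        {"$or": [{"a": 1}, {"b": 1}, {"c": 1}]},
        sort=[("sort_field", 1)],
    )
    plan = cursor.explain()["queryPlanner"]["winningPlan"]
    print(plan)  # a SORT_MERGE stage over per-branch IXSCANs should appear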