mirror of https://github.com/mongodb/mongo
187 lines
7.7 KiB
Python
187 lines
7.7 KiB
Python
# Copyright (C) 2022-present MongoDB, Inc.
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the Server Side Public License, version 1,
|
|
# as published by MongoDB, Inc.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# Server Side Public License for more details.
|
|
#
|
|
# You should have received a copy of the Server Side Public License
|
|
# along with this program. If not, see
|
|
# <http://www.mongodb.com/licensing/server-side-public-license>.
|
|
#
|
|
# As a special exception, the copyright holders give permission to link the
|
|
# code of portions of this program with the OpenSSL library under certain
|
|
# conditions as described in each individual source file and distribute
|
|
# linked combinations including the program with the OpenSSL library. You
|
|
# must comply with the Server Side Public License in all respects for
|
|
# all of the code used other than as permitted herein. If you modify file(s)
|
|
# with this exception, you may extend this exception to your version of the
|
|
# file(s), but you are not obligated to do so. If you do not wish to do so,
|
|
# delete this exception statement from your version. If you delete this
|
|
# exception statement from all source files in the program, then also delete
|
|
# it in the license file.
|
|
#
|
|
"""A wrapper with useful methods over MongoDB database."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import subprocess
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, Mapping, NewType, Sequence
|
|
|
|
from config import DatabaseConfig, RestoreMode
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
|
|
__all__ = ["DatabaseInstance", "Pipeline"]
|
|
"""MongoDB Aggregate's Pipeline"""
|
|
Pipeline = NewType("Pipeline", Sequence[Mapping[str, Any]])
|
|
|
|
|
|
class DatabaseInstance:
|
|
"""MongoDB Database wrapper."""
|
|
|
|
def __init__(self, config: DatabaseConfig) -> None:
|
|
"""Initialize wrapper."""
|
|
self.config = config
|
|
self.client = AsyncIOMotorClient(config.connection_string)
|
|
self.database = self.client[config.database_name]
|
|
|
|
def __enter__(self):
|
|
if self.config.restore_from_dump == RestoreMode.ALWAYS or (
|
|
self.config.restore_from_dump == RestoreMode.ONLY_NEW
|
|
and self.config.database_name not in self.client.list_database_names()
|
|
):
|
|
self.restore()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
if self.config.dump_on_exit:
|
|
self.dump()
|
|
|
|
async def drop(self):
|
|
"""Drop the database."""
|
|
await self.client.drop_database(self.config.database_name)
|
|
|
|
def restore(self):
|
|
"""Restore the database from the 'self.dump_directory'."""
|
|
subprocess.run(
|
|
["mongorestore", "--nsInclude", f"{self.config.database_name}.*", "--drop"],
|
|
shell=True,
|
|
check=True,
|
|
cwd=self.config.dump_path,
|
|
)
|
|
|
|
def dump(self):
|
|
"""Dump the database into 'self.dump_directory'."""
|
|
subprocess.run(
|
|
["mongodump", "--db", self.config.database_name], cwd=self.config.dump_path, check=True
|
|
)
|
|
|
|
async def set_parameter(self, name: str, value: any) -> None:
|
|
"""Set MongoDB Parameter."""
|
|
await self.client.admin.command({"setParameter": 1, name: value})
|
|
|
|
async def get_parameter(self, name: str) -> any:
|
|
return (await self.client.admin.command({"getParameter": 1, name: 1}))[name]
|
|
|
|
async def enable_sbe(self, state: bool) -> None:
|
|
"""Enable new query execution engine. Throw pymongo.errors.OperationFailure in case of failure."""
|
|
await self.set_parameter(
|
|
"internalQueryFrameworkControl", "trySbeEngine" if state else "forceClassicEngine"
|
|
)
|
|
|
|
async def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]:
|
|
"""Return explain for the given pipeline."""
|
|
return await self.database.command(
|
|
"explain",
|
|
{"aggregate": collection_name, "pipeline": pipeline, "cursor": {}},
|
|
verbosity="executionStats",
|
|
)
|
|
|
|
async def hide_index(self, collection_name: str, index_name: str) -> None:
|
|
"""Hide the given index from the query optimizer."""
|
|
await self.database.command(
|
|
{"collMod": collection_name, "index": {"name": index_name, "hidden": True}}
|
|
)
|
|
|
|
async def unhide_index(self, collection_name: str, index_name: str) -> None:
|
|
"""Make the given index visible for the query optimizer."""
|
|
await self.database.command(
|
|
{"collMod": collection_name, "index": {"name": index_name, "hidden": False}}
|
|
)
|
|
|
|
async def hide_all_indexes(self, collection_name: str) -> None:
|
|
"""Hide all indexes of the given collection from the query optimizer."""
|
|
for index in self.database[collection_name].list_indexes():
|
|
if index["name"] != "_id_":
|
|
await self.hide_index(collection_name, index["name"])
|
|
|
|
async def unhide_all_indexes(self, collection_name: str) -> None:
|
|
"""Make all indexes of the given collection visible fpr the query optimizer."""
|
|
for index in self.database[collection_name].list_indexes():
|
|
if index["name"] != "_id_":
|
|
await self.unhide_index(collection_name, index["name"])
|
|
|
|
async def drop_collection(self, collection_name: str) -> None:
|
|
"""Drop collection."""
|
|
await self.database[collection_name].drop()
|
|
|
|
async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None:
|
|
"""Insert documents into the collection with the given name."""
|
|
if len(docs) > 0:
|
|
await self.database[collection_name].insert_many(docs, ordered=False)
|
|
|
|
async def get_all_documents(self, collection_name: str):
|
|
"""Get all documents from the collection with the given name."""
|
|
return await self.database[collection_name].find({}).to_list(length=None)
|
|
|
|
async def get_stats(self, collection_name: str):
|
|
"""Get collection statistics."""
|
|
return await self.database.command("collstats", collection_name)
|
|
|
|
async def get_average_document_size(self, collection_name: str) -> float:
|
|
"""Get average document size for the given collection."""
|
|
stats = await self.get_stats(collection_name)
|
|
avg_size = stats.get("avgObjSize")
|
|
return avg_size if avg_size is not None else 0
|
|
|
|
|
|
class DatabaseParameter:
|
|
"""A utility class to work with MongoDB parameters."""
|
|
|
|
def __init__(self, database: DatabaseInstance, parameter_name: str) -> None:
|
|
"""Initialize the class."""
|
|
self.database = database
|
|
self.parameter_name = parameter_name
|
|
self.original_value = None
|
|
|
|
async def set(self, value):
|
|
"""Set the parameter's value."""
|
|
await self.database.set_parameter(self.parameter_name, value)
|
|
|
|
async def remember(self):
|
|
"""Store the current value of the parameter so it can be restored lately."""
|
|
self.original_value = await self.database.get_parameter(self.parameter_name)
|
|
|
|
async def restore(self):
|
|
"""Restore the remebered value of the parameter."""
|
|
if self.original_value is not None:
|
|
await self.set(self.original_value)
|
|
else:
|
|
raise ValueError(f'The parameter "{self.parameter_name}" has not been remembered.')
|
|
|
|
|
|
@asynccontextmanager
|
|
async def get_database_parameter(database: DatabaseInstance, parameter_name: str):
|
|
"""Create a new instance of a context manager on top of DatabaseParameter. It restores the original value on teardown. Useful when we need temporarily change a parameter."""
|
|
param = DatabaseParameter(database, parameter_name)
|
|
await param.remember()
|
|
try:
|
|
yield param
|
|
finally:
|
|
await param.restore()
|