mirror of https://github.com/mongodb/mongo
301 lines
10 KiB
Python
301 lines
10 KiB
Python
"""Helper tools to get OAuth credentials using the PKCE flow."""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
from random import choice
|
|
from string import ascii_lowercase
|
|
from typing import Any, Callable, Optional, Tuple
|
|
from urllib.parse import parse_qs, urlsplit
|
|
from webbrowser import open as web_open
|
|
|
|
import requests
|
|
from oauthlib.oauth2 import BackendApplicationClient
|
|
from pkce import generate_pkce_pair
|
|
from pydantic import ValidationError
|
|
from pydantic.main import BaseModel
|
|
from requests_oauthlib import OAuth2Session
|
|
from buildscripts.util.fileops import read_yaml_file
|
|
|
|
# Static HTML page, pre-encoded as UTF-8 bytes, that the redirect handler writes back to the
# browser after the OAuth callback has been processed. Its onload script closes the window.
AUTH_HANDLER_RESPONSE = """\
<html>
<head>
<title>Authentication Status</title>
<script>
window.onload = function() {
window.close();
}
</script>
</head>
<body>
<p>The authentication flow has completed.</p>
</body>
</html>
""".encode("utf-8")
|
|
|
|
|
|
class Configs:
    """
    Collect configurations necessary for authentication process.

    The class attributes hold the default values. Each instance copies them, replacing a
    default with the matching constructor argument whenever that argument is truthy.
    """

    # pylint: disable=invalid-name

    AUTH_DOMAIN = "corp.mongodb.com/oauth2/aus4k4jv00hWjNnps297"
    CLIENT_ID = "0oa5zf9ps4N3JKWIJ297"
    REDIRECT_PORT = 8989
    SCOPE = "kanopy+openid+profile"

    def __init__(self, client_credentials_scope: Optional[str] = None,
                 client_credentials_user_name: Optional[str] = None,
                 auth_domain: Optional[str] = None, client_id: Optional[str] = None,
                 redirect_port: Optional[int] = None, scope: Optional[str] = None):
        """
        Initialize configs instance.

        :param client_credentials_scope: Scope stored for the client-credentials flow.
        :param client_credentials_user_name: User name stored for the client-credentials flow.
        :param auth_domain: Override for the default auth domain.
        :param client_id: Override for the default client id.
        :param redirect_port: Override for the default redirect port.
        :param scope: Override for the default scope.
        """
        # `or` is deliberate: None and any other falsy value ("" or 0) fall back to the
        # class-level default instead of overriding it.
        self.AUTH_DOMAIN = auth_domain or self.AUTH_DOMAIN
        self.CLIENT_ID = client_id or self.CLIENT_ID
        self.REDIRECT_PORT = redirect_port or self.REDIRECT_PORT
        self.SCOPE = scope or self.SCOPE

        # These two have no class-level default and are stored exactly as given.
        self.CLIENT_CREDENTIALS_SCOPE = client_credentials_scope
        self.CLIENT_CREDENTIALS_USER_NAME = client_credentials_user_name
|
|
|
|
|
|
class OAuthCredentials(BaseModel):
    """OAuth access token and its associated metadata."""

    # Token lifetime; are_expired() adds it to created_time as a number of seconds.
    expires_in: int
    access_token: str
    # Moment the token was obtained; compared against datetime.now() in are_expired().
    created_time: datetime
    user_name: str

    def are_expired(self) -> bool:
        """
        Check whether the current OAuth credentials are expired or not.

        The expiry instant is ``created_time`` plus ``expires_in`` seconds; the credentials
        count as expired once that instant is earlier than ``datetime.now()``.

        :return: Whether the credentials are expired or not.
        """
        expiry = self.created_time + timedelta(seconds=self.expires_in)
        return expiry < datetime.now()

    @classmethod
    def get_existing_credentials_from_file(cls, file_path: str) -> Optional[OAuthCredentials]:
        """
        Try to get OAuth credentials from a file location.

        Will return None if credentials either don't exist, are malformed, or are expired.

        :param file_path: Location to check for OAuth credentials.
        :return: Valid OAuth credentials or None if valid credentials don't exist
        """
        try:
            raw = read_yaml_file(file_path)
        except OSError:
            # Missing or unreadable file: there are no stored credentials.
            return None

        # The loaded document is unpacked as keyword arguments below, so anything that is
        # not a mapping (e.g. what a loader yields for an empty file) cannot be credentials.
        if not isinstance(raw, dict):
            return None

        try:
            creds = cls(**raw)
        except (ValidationError, TypeError):
            # ValidationError: fields missing or of the wrong type.
            # TypeError: mapping keys that cannot be used as keyword names.
            return None

        has_all_fields = (creds.access_token and creds.created_time and creds.expires_in
                          and creds.user_name)
        if has_all_fields and not creds.are_expired():
            return creds
        return None
|
|
|
|
|
|
class _RedirectServer(HTTPServer):
    """HTTP server that carries the PKCE flow state needed by its request handler."""

    # Filled in by the request handler once a token has been obtained; None until then.
    pkce_credentials: Optional[OAuthCredentials] = None
    auth_domain: str
    client_id: str
    redirect_uri: str
    code_verifier: str

    def __init__(
            self,
            server_address: Tuple[str, int],
            handler: Callable[..., BaseHTTPRequestHandler],
            redirect_uri: str,
            auth_domain: str,
            client_id: str,
            code_verifier: str,
    ):
        """
        Store the PKCE parameters on the server, then start the underlying HTTP server.

        :param server_address: (host, port) pair handed to HTTPServer.
        :param handler: Request handler class handed to HTTPServer.
        :param redirect_uri: Redirect URI used in the authorization request.
        :param auth_domain: Domain of the auth server.
        :param client_id: Id of the OAuth client.
        :param code_verifier: PKCE code verifier matching the issued code challenge.
        """
        # The PKCE state is attached before the base class initializer runs.
        self.code_verifier = code_verifier
        self.client_id = client_id
        self.auth_domain = auth_domain
        self.redirect_uri = redirect_uri
        super().__init__(server_address, handler)
|
|
|
|
|
|
class _Handler(BaseHTTPRequestHandler):
    """Request handler class to use when trying to get OAuth credentials."""

    # pylint: disable=invalid-name

    server: _RedirectServer

    # Upper bound, in seconds, for each outbound request to the auth server, so that an
    # unresponsive endpoint cannot block the flow forever.
    _REQUEST_TIMEOUT_SECS = 30

    def _set_response(self) -> None:
        """Send a 200 status line and the HTML content-type header to the client."""
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()

    def log_message(self, log_format: Any, *args: Any) -> None:  # pylint: disable=unused-argument,arguments-differ
        """
        Discard HTTP Server internal messages instead of logging them.

        :param log_format: The format to use when logging messages.
        :param args: Positional values for the format string.
        """
        return None

    def do_GET(self) -> None:
        """
        Handle the callback response from the auth server.

        Exchanges the authorization code from the query string for an access token, looks up
        the user name, stores the resulting credentials on the server, and replies with the
        completion page.

        :raises ValueError: If the code, the token data, or the user name cannot be obtained.
        """
        params = parse_qs(urlsplit(self.path).query)
        code = params.get("code")

        if not code:
            raise ValueError("Could not get authorization code when signing in to Okta")

        url = f"https://{self.server.auth_domain}/v1/token"
        body = {
            "grant_type": "authorization_code",
            "client_id": self.server.client_id,
            "code_verifier": self.server.code_verifier,
            # parse_qs maps every key to a list of values; send the code itself.
            "code": code[0],
            "redirect_uri": self.server.redirect_uri,
        }

        token_resp = requests.post(url, data=body, timeout=self._REQUEST_TIMEOUT_SECS).json()

        access_token = token_resp.get("access_token")
        expires_in = token_resp.get("expires_in")

        if not access_token or not expires_in:
            raise ValueError("Could not get access token or expires_in data about access token")

        headers = {"Authorization": f"Bearer {access_token}"}
        user_info = requests.get(f"https://{self.server.auth_domain}/v1/userinfo",
                                 headers=headers, timeout=self._REQUEST_TIMEOUT_SECS).json()

        # A missing claim is reported the same way as a malformed one.
        preferred_username = user_info.get("preferred_username")
        if not preferred_username:
            raise ValueError("Could not get user_name of current user")

        # Expected shape is "<user>@<domain>"; only the part before the "@" is kept.
        split_username = preferred_username.split("@")

        if len(split_username) != 2:
            raise ValueError("Could not get user_name of current user")

        self.server.pkce_credentials = OAuthCredentials(
            access_token=access_token,
            expires_in=expires_in,
            created_time=datetime.now(),
            user_name=split_username[0],
        )
        self._set_response()
        self.wfile.write(AUTH_HANDLER_RESPONSE)
|
|
|
|
|
|
class PKCEOauthTools:
    """Basic toolset to get OAuth credentials using the PKCE flow."""

    auth_domain: str
    client_id: str
    redirect_port: int
    redirect_uri: str
    scope: str

    def __init__(self, auth_domain: str, client_id: str, redirect_port: int, scope: str):
        """
        Create a new PKCEOauth tools instance.

        :param auth_domain: The uri of the auth server to get the credentials from.
        :param client_id: The id of the client that you are using to authenticate.
        :param redirect_port: Port to use when setting up the local server for the auth redirect.
        :param scope: The OAuth scopes to request access for.
        """
        self.auth_domain = auth_domain
        self.client_id = client_id
        self.redirect_port = redirect_port
        self.redirect_uri = f"http://localhost:{redirect_port}/"
        self.scope = scope

    def get_pkce_credentials(self, print_auth_url: bool = False) -> OAuthCredentials:
        """
        Try to get an OAuth access token and its associated metadata.

        Starts a local redirect server, sends the user to the authorization url, and waits
        for a single callback request. The server socket is closed before returning,
        whether or not the flow succeeded.

        :param print_auth_url: Whether to print the auth url to the console instead of opening it.
        :return: OAuth credentials and some associated metadata to check if they have expired.
        :raises ValueError: If the callback did not produce credentials.
        """
        code_verifier, code_challenge = generate_pkce_pair()

        state = "".join(choice(ascii_lowercase) for _ in range(10))

        authorization_url = (f"https://{self.auth_domain}/v1/authorize?"
                             f"scope={self.scope}&"
                             f"response_type=code&"
                             f"response_mode=query&"
                             f"client_id={self.client_id}&"
                             f"code_challenge={code_challenge}&"
                             f"state={state}&"
                             f"code_challenge_method=S256&"
                             f"redirect_uri={self.redirect_uri}")

        httpd = _RedirectServer(
            ("", self.redirect_port),
            _Handler,
            self.redirect_uri,
            self.auth_domain,
            self.client_id,
            code_verifier,
        )
        try:
            if print_auth_url:
                print("Please open the below url in a browser and sign in if necessary")
                print(authorization_url)
            else:
                web_open(authorization_url)

            # Serve exactly one request: the redirect carrying the authorization code.
            httpd.handle_request()
        finally:
            # Release the listening socket so the redirect port does not stay bound.
            httpd.server_close()

        if not httpd.pkce_credentials:
            raise ValueError(
                "Could not retrieve Okta credentials to talk to Kanopy with. "
                "Please sign out of Okta in your browser and try runnning this script again")

        return httpd.pkce_credentials
|
|
|
|
|
|
def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAuthCredentials:
    """
    Obtain OAuth credentials for a human user through the PKCE flow.

    :param configs: Configs instance supplying the auth domain, client id, redirect port
        and scope.
    :param print_auth_url: Whether to print the auth url to the console instead of opening it.
    :return: OAuth credentials for the given user.
    """
    return PKCEOauthTools(
        auth_domain=configs.AUTH_DOMAIN,
        client_id=configs.CLIENT_ID,
        redirect_port=configs.REDIRECT_PORT,
        scope=configs.SCOPE,
    ).get_pkce_credentials(print_auth_url)
|
|
|
|
|
|
def get_client_cred_oauth_credentials(client_id: str, client_secret: str,
                                      configs: Configs) -> OAuthCredentials:
    """
    Obtain OAuth credentials for a machine user through the client-credentials flow.

    :param client_id: The client_id of the machine user to authenticate as.
    :param client_secret: The client_secret of the machine user to authenticate as.
    :param configs: Configs instance supplying the auth domain, the scope to request and
        the user name to record on the credentials.
    :return: OAuth credentials for the given machine user.
    :raises ValueError: If the token response lacks an access token or its lifetime.
    """
    session = OAuth2Session(client=BackendApplicationClient(client_id=client_id))
    token_response = session.fetch_token(
        token_url=f"https://{configs.AUTH_DOMAIN}/v1/token",
        client_id=client_id,
        client_secret=client_secret,
        scope=configs.CLIENT_CREDENTIALS_SCOPE,
    )

    access_token = token_response.get("access_token")
    expires_in = token_response.get("expires_in")

    # Both values are required to build usable credentials.
    if not (access_token and expires_in):
        raise ValueError("Could not get access token or expires_in data about access token")

    return OAuthCredentials(
        access_token=access_token,
        expires_in=expires_in,
        created_time=datetime.now(),
        user_name=configs.CLIENT_CREDENTIALS_USER_NAME,
    )
|