"""
Credential encryption utilities for cache storage.
Provides Fernet symmetric encryption for proxy credentials at rest (L2/L3 tiers).
Uses environment variable PROXYWHIRL_CACHE_ENCRYPTION_KEY for key management.
Supports key rotation via MultiFernet with PROXYWHIRL_CACHE_KEY_PREVIOUS.
"""
from __future__ import annotations
import os
from cryptography.fernet import Fernet, MultiFernet
from pydantic import SecretStr
# Names exported by ``from <module> import *``: the encryptor class plus the key helpers.
__all__ = ["CredentialEncryptor", "get_encryption_keys", "create_multi_fernet", "rotate_key"]
def _read_env_key(env_var: str) -> bytes | None:
    """Return the Fernet key stored in environment variable *env_var* as bytes.

    Args:
        env_var: Name of the environment variable to read.

    Returns:
        The UTF-8 encoded key, or None when the variable is unset or empty.

    Raises:
        ValueError: If the variable is set but is not a valid Fernet key.
    """
    raw = os.environ.get(env_var)
    if not raw:
        return None
    key = raw.encode("utf-8")
    try:
        # Constructing a Fernet instance is the validation step: it rejects
        # values that are not 32 url-safe base64-encoded bytes.
        Fernet(key)
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"Invalid Fernet key format in {env_var}: {e}. "
            "Key must be 32 url-safe base64-encoded bytes. "
            "Generate a valid key with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
        ) from e
    return key


def get_encryption_keys() -> list[bytes]:
    """Get all valid encryption keys for MultiFernet.

    Returns keys in priority order: the current key first, then the previous
    key. This allows decryption of data encrypted with either key while new
    data is always encrypted with the current (first) key.

    Keys are read from:
        - ``PROXYWHIRL_CACHE_ENCRYPTION_KEY``: current key. If unset or empty,
          a random key is generated with ``Fernet.generate_key()``. A new one
          is generated on *every* call, so data encrypted with it cannot be
          decrypted by keys obtained from a later call.
        - ``PROXYWHIRL_CACHE_KEY_PREVIOUS``: optional previous key. It is
          skipped when it equals the current key.

    Returns:
        List of Fernet keys as bytes. Always contains at least one key.
        The first key is current; any further key is for backward compatibility.

    Raises:
        ValueError: If either environment variable holds an invalid Fernet key.

    Example:
        >>> keys = get_encryption_keys()
        >>> 1 <= len(keys) <= 2
        True
    """
    current = _read_env_key("PROXYWHIRL_CACHE_ENCRYPTION_KEY")
    keys: list[bytes] = [current if current is not None else Fernet.generate_key()]

    previous = _read_env_key("PROXYWHIRL_CACHE_KEY_PREVIOUS")
    # A previous key identical to the current one would add nothing.
    if previous is not None and previous != keys[0]:
        keys.append(previous)
    return keys
def create_multi_fernet() -> MultiFernet:
    """Build a MultiFernet from every key returned by ``get_encryption_keys()``.

    Encryption always uses the first (current) key. Decryption tries each key
    in the same order, so tokens made with the previous key remain readable.

    Note: when ``PROXYWHIRL_CACHE_ENCRYPTION_KEY`` is unset,
    ``get_encryption_keys()`` generates a fresh random current key on each call.
    Two instances built by separate calls then cannot decrypt each other's tokens.

    Returns:
        MultiFernet wrapping one Fernet per configured key.

    Raises:
        ValueError: If any configured key has an invalid format.

    Example:
        >>> mf = create_multi_fernet()
        >>> encrypted = mf.encrypt(b"secret")
        >>> mf.decrypt(encrypted)
        b'secret'
    """
    return MultiFernet(list(map(Fernet, get_encryption_keys())))
def rotate_key(new_key: str) -> None:
    """Promote *new_key* to the current encryption key.

    Steps:
        1. Validate *new_key* as a Fernet key.
        2. If ``PROXYWHIRL_CACHE_ENCRYPTION_KEY`` is set, copy its value into
           ``PROXYWHIRL_CACHE_KEY_PREVIOUS``. If it is unset, any existing
           previous key is left as-is.
        3. Store *new_key* in ``PROXYWHIRL_CACHE_ENCRYPTION_KEY``.

    After rotation, new data is encrypted with *new_key*, while data encrypted
    with the old current key can still be decrypted via the previous key.

    Only ``os.environ`` of the running process is modified. Objects that
    already loaded their keys (e.g. an existing ``CredentialEncryptor``) keep
    using the keys they read at construction time.

    Args:
        new_key: New Fernet key as a url-safe base64-encoded string.

    Raises:
        ValueError: If *new_key* is not a valid Fernet key.

    Example:
        >>> from cryptography.fernet import Fernet
        >>> new_key = Fernet.generate_key().decode()
        >>> rotate_key(new_key)
    """
    try:
        Fernet(new_key.encode("utf-8"))
    except Exception as exc:
        raise ValueError(
            f"Invalid new key format: {exc}. "
            "Key must be 32 url-safe base64-encoded bytes. "
            "Generate a valid key with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
        ) from exc

    env = os.environ
    outgoing = env.get("PROXYWHIRL_CACHE_ENCRYPTION_KEY")
    if outgoing:
        env["PROXYWHIRL_CACHE_KEY_PREVIOUS"] = outgoing
    env["PROXYWHIRL_CACHE_ENCRYPTION_KEY"] = new_key
class CredentialEncryptor:
    """
    Handles encryption/decryption of proxy credentials with key rotation support.

    Uses Fernet symmetric encryption (AES-128-CBC + HMAC) to protect
    credentials stored in L2 (JSONL files) and L3 (SQLite database).
    Supports gradual key rotation via MultiFernet using both current
    and previous keys.

    Attributes:
        key: The key used for encryption (the current key when loaded from the
            environment, otherwise the key passed to the constructor).

    Example:
        >>> encryptor = CredentialEncryptor()
        >>> encrypted = encryptor.encrypt(SecretStr("mypassword"))
        >>> decrypted = encryptor.decrypt(encrypted)
        >>> decrypted.get_secret_value()
        'mypassword'
    """

    def __init__(self, key: bytes | None = None) -> None:
        """Initialize the encryptor with a single Fernet key or a MultiFernet.

        Args:
            key: Optional Fernet key (32 url-safe base64-encoded bytes).
                If None, keys are loaded once via get_encryption_keys()
                (current and previous keys from environment variables). If no
                current key is configured, a random key is generated. WARNING:
                generated keys cannot decrypt existing cached data.

        Raises:
            ValueError: If the provided key, or a key from the environment,
                is not a valid Fernet key.
        """
        if key is None:
            # Load the keys exactly once. With no configured current key,
            # get_encryption_keys() returns a new random key on every call.
            # A second call would leave self.key out of sync with the cipher.
            keys = get_encryption_keys()
            self._cipher = MultiFernet([Fernet(k) for k in keys])
            self.key = keys[0]
        else:
            # Single key provided, use regular Fernet.
            try:
                self._cipher = Fernet(key)
            except (TypeError, ValueError) as e:
                raise ValueError(f"Invalid Fernet key: {e}") from e
            self.key = key

    def encrypt(self, secret: SecretStr) -> bytes:
        """Encrypt a SecretStr to bytes.

        Args:
            secret: SecretStr containing plaintext to encrypt.

        Returns:
            Encrypted bytes suitable for storage in BLOB fields, or ``b""``
            when *secret* is falsy or holds an empty string.

        Raises:
            ValueError: If encryption fails.
        """
        if not secret:
            return b""
        plaintext = secret.get_secret_value()
        if not plaintext:
            return b""
        try:
            return self._cipher.encrypt(plaintext.encode("utf-8"))
        except Exception as e:
            detail = str(e) or type(e).__name__
            raise ValueError(f"Encryption failed: {detail}") from e

    def decrypt(self, encrypted: bytes) -> SecretStr:
        """Decrypt encrypted bytes back to SecretStr.

        No TTL is passed to the cipher, so token age is not checked.

        Args:
            encrypted: Encrypted bytes from storage.

        Returns:
            SecretStr containing the decrypted plaintext. Empty input yields
            ``SecretStr("")``.

        Raises:
            ValueError: If decryption fails (wrong key, corrupted data, or
                plaintext that is not valid UTF-8).
        """
        if not encrypted:
            return SecretStr("")
        try:
            plaintext_bytes = self._cipher.decrypt(encrypted)
            return SecretStr(plaintext_bytes.decode("utf-8"))
        except Exception as e:
            # cryptography's InvalidToken carries no message, so str(e) is
            # empty. Fall back to the exception class name.
            detail = str(e) or type(e).__name__
            raise ValueError(f"Decryption failed: {detail}") from e