diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index c9e48935f..1bd26a4e2 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -20,11 +20,10 @@ import ssl from typing import Any, Optional, TYPE_CHECKING -from aiofiles.tempfile import TemporaryDirectory - from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError +from google.cloud.sql.connector.utils import AsyncTemporaryDirectory from google.cloud.sql.connector.utils import write_to_file if TYPE_CHECKING: @@ -108,7 +107,7 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont # tmpdir and its contents are automatically deleted after the CA cert # and ephemeral cert are loaded into the SSLcontext. The values # need to be written to files in order to be loaded by the SSLContext - async with TemporaryDirectory() as tmpdir: + async with AsyncTemporaryDirectory() as tmpdir: ca_filename, cert_filename, key_filename = await write_to_file( tmpdir, self.server_ca_cert, self.client_cert, self.private_key ) diff --git a/google/cloud/sql/connector/utils.py b/google/cloud/sql/connector/utils.py index dd0aec344..2c355d965 100755 --- a/google/cloud/sql/connector/utils.py +++ b/google/cloud/sql/connector/utils.py @@ -14,7 +14,9 @@ limitations under the License. """ -import aiofiles +import asyncio +import tempfile + from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -57,6 +59,33 @@ async def generate_keys() -> tuple[bytes, str]: return priv_key, pub_key +class AsyncTemporaryDirectory: + """Async context manager wrapper around tempfile.TemporaryDirectory.""" + + async def __aenter__(self) -> str: + self.temp_dir = await asyncio.to_thread(tempfile.TemporaryDirectory) + return self.temp_dir.name + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await asyncio.to_thread(self.temp_dir.cleanup) + + +def _write_files( + ca_filename: str, + serverCaCert: str, + cert_filename: str, + ephemeralCert: str, + key_filename: str, + priv_key: bytes, +) -> None: + with open(ca_filename, "w+") as ca_out: + ca_out.write(serverCaCert) + with open(cert_filename, "w+") as ephemeral_out: + ephemeral_out.write(ephemeralCert) + with open(key_filename, "wb") as priv_out: + priv_out.write(priv_key) + + async def write_to_file( dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes ) -> tuple[str, str, str]: @@ -68,12 +97,15 @@ async def write_to_file( cert_filename = f"{dir_path}/cert.pem" key_filename = f"{dir_path}/priv.pem" - async with aiofiles.open(ca_filename, "w+") as ca_out: - await ca_out.write(serverCaCert) - async with aiofiles.open(cert_filename, "w+") as ephemeral_out: - await ephemeral_out.write(ephemeralCert) - async with aiofiles.open(key_filename, "wb") as priv_out: - await priv_out.write(priv_key) + await asyncio.to_thread( + _write_files, + ca_filename, + serverCaCert, + cert_filename, + ephemeralCert, + key_filename, + priv_key, + ) return (ca_filename, cert_filename, key_filename) diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..bceef3a1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "aiofiles", "aiohttp", "cryptography>=42.0.0", "dnspython>=2.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 83d7a78f3..309d3ceab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,13 +15,14 @@ """ import asyncio +import inspect import os import socket import ssl from threading import Thread from typing import Any, AsyncGenerator -from aiofiles.tempfile import TemporaryDirectory +import aiohttp from aiohttp import web from cryptography.hazmat.primitives import serialization import pytest # noqa F401 Needed to run the tests @@ -32,9 +33,23 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.utils import AsyncTemporaryDirectory from google.cloud.sql.connector.utils import generate_keys from google.cloud.sql.connector.utils import write_to_file +# Monkeypatch aiohttp.ClientResponse.__init__ to support aioresponses with aiohttp >= 3.11/3.14 +_original_client_response_init = aiohttp.ClientResponse.__init__ + + +def _patched_client_response_init(self, *args, **kwargs): + sig = inspect.signature(_original_client_response_init) + if "stream_writer" in sig.parameters and "stream_writer" not in kwargs: + kwargs["stream_writer"] = kwargs.get("writer", None) + return _original_client_response_init(self, *args, **kwargs) + + +aiohttp.ClientResponse.__init__ = _patched_client_response_init + SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -101,7 +116,7 @@ async def start_proxy_server(instance: FakeCSQLInstance) -> None: format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), ) - async with TemporaryDirectory() as tmpdir: + async with AsyncTemporaryDirectory() as tmpdir: server_filename, _, key_filename = await write_to_file( tmpdir, instance.server_cert_pem, "", server_key_bytes ) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 66bf64a32..42c5cc666 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -23,7 +23,6 @@ import ssl from typing import Any, Callable, Literal, Optional -from aiofiles.tempfile import TemporaryDirectory from aiohttp import web from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -36,6 +35,7 @@ from google.auth.credentials import TokenState from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN +from google.cloud.sql.connector.utils import AsyncTemporaryDirectory from google.cloud.sql.connector.utils import generate_keys from google.cloud.sql.connector.utils import write_to_file @@ -205,7 +205,7 @@ async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext: context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False # load ssl.SSLContext with certs - async with TemporaryDirectory() as tmpdir: + async with AsyncTemporaryDirectory() as tmpdir: ca_filename, cert_filename, key_filename = await write_to_file( tmpdir, instance.server_cert_pem, ephemeral_cert, client_private )