Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import ssl
from typing import Any, Optional, TYPE_CHECKING

from aiofiles.tempfile import TemporaryDirectory

from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.utils import AsyncTemporaryDirectory
from google.cloud.sql.connector.utils import write_to_file

if TYPE_CHECKING:
Expand Down Expand Up @@ -108,7 +107,7 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont
# tmpdir and its contents are automatically deleted after the CA cert
# and ephemeral cert are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
async with TemporaryDirectory() as tmpdir:
async with AsyncTemporaryDirectory() as tmpdir:
ca_filename, cert_filename, key_filename = await write_to_file(
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
)
Expand Down
46 changes: 39 additions & 7 deletions google/cloud/sql/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
limitations under the License.
"""

import aiofiles
import asyncio
import tempfile

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
Expand Down Expand Up @@ -57,6 +59,33 @@ async def generate_keys() -> tuple[bytes, str]:
return priv_key, pub_key


class AsyncTemporaryDirectory:
"""Async context manager wrapper around tempfile.TemporaryDirectory."""

async def __aenter__(self) -> str:
self.temp_dir = await asyncio.to_thread(tempfile.TemporaryDirectory)
return self.temp_dir.name

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await asyncio.to_thread(self.temp_dir.cleanup)


def _write_files(
ca_filename: str,
serverCaCert: str,
cert_filename: str,
ephemeralCert: str,
key_filename: str,
priv_key: bytes,
) -> None:
with open(ca_filename, "w+") as ca_out:
ca_out.write(serverCaCert)
with open(cert_filename, "w+") as ephemeral_out:
ephemeral_out.write(ephemeralCert)
with open(key_filename, "wb") as priv_out:
priv_out.write(priv_key)


async def write_to_file(
dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes
) -> tuple[str, str, str]:
Expand All @@ -68,12 +97,15 @@ async def write_to_file(
cert_filename = f"{dir_path}/cert.pem"
key_filename = f"{dir_path}/priv.pem"

async with aiofiles.open(ca_filename, "w+") as ca_out:
await ca_out.write(serverCaCert)
async with aiofiles.open(cert_filename, "w+") as ephemeral_out:
await ephemeral_out.write(ephemeralCert)
async with aiofiles.open(key_filename, "wb") as priv_out:
await priv_out.write(priv_key)
await asyncio.to_thread(
_write_files,
ca_filename,
serverCaCert,
cert_filename,
ephemeralCert,
key_filename,
priv_key,
)

return (ca_filename, cert_filename, key_filename)

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"aiofiles",
"aiohttp",
"cryptography>=42.0.0",
"dnspython>=2.0.0",
Expand Down
19 changes: 17 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
"""

import asyncio
import inspect
import os
import socket
import ssl
from threading import Thread
from typing import Any, AsyncGenerator

from aiofiles.tempfile import TemporaryDirectory
import aiohttp
from aiohttp import web
from cryptography.hazmat.primitives import serialization
import pytest # noqa F401 Needed to run the tests
Expand All @@ -32,9 +33,23 @@
from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.utils import AsyncTemporaryDirectory
from google.cloud.sql.connector.utils import generate_keys
from google.cloud.sql.connector.utils import write_to_file

# Monkeypatch aiohttp.ClientResponse.__init__ to support aioresponses with aiohttp >= 3.11/3.14
_original_client_response_init = aiohttp.ClientResponse.__init__


def _patched_client_response_init(self, *args, **kwargs):
sig = inspect.signature(_original_client_response_init)
if "stream_writer" in sig.parameters and "stream_writer" not in kwargs:
kwargs["stream_writer"] = kwargs.get("writer", None)
return _original_client_response_init(self, *args, **kwargs)


aiohttp.ClientResponse.__init__ = _patched_client_response_init

SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]


Expand Down Expand Up @@ -101,7 +116,7 @@ async def start_proxy_server(instance: FakeCSQLInstance) -> None:
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
async with TemporaryDirectory() as tmpdir:
async with AsyncTemporaryDirectory() as tmpdir:
server_filename, _, key_filename = await write_to_file(
tmpdir, instance.server_cert_pem, "", server_key_bytes
)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import ssl
from typing import Any, Callable, Literal, Optional

from aiofiles.tempfile import TemporaryDirectory
from aiohttp import web
from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand All @@ -36,6 +35,7 @@
from google.auth.credentials import TokenState

from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
from google.cloud.sql.connector.utils import AsyncTemporaryDirectory
from google.cloud.sql.connector.utils import generate_keys
from google.cloud.sql.connector.utils import write_to_file

Expand Down Expand Up @@ -205,7 +205,7 @@ async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
# load ssl.SSLContext with certs
async with TemporaryDirectory() as tmpdir:
async with AsyncTemporaryDirectory() as tmpdir:
ca_filename, cert_filename, key_filename = await write_to_file(
tmpdir, instance.server_cert_pem, ephemeral_cert, client_private
)
Expand Down
Loading