diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..92259ac6b 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -146,7 +146,7 @@ async def _get_metadata( ) ip_addresses = ( - {ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]} + {ip["type"]: [ip["ipAddress"]] for ip in ret_dict["ipAddresses"]} if "ipAddresses" in ret_dict else {} ) @@ -156,20 +156,22 @@ async def _get_metadata( if ret_dict.get("pscEnabled"): # Find PSC instance DNS name in the dns_names field psc_dns_names = [ - d["name"] + d["name"].rstrip(".") for d in ret_dict.get("dnsNames", []) if d["connectionType"] == "PRIVATE_SERVICE_CONNECT" and d["dnsScope"] == "INSTANCE" ] - dns_name = psc_dns_names[0] if psc_dns_names else None + # Sort: .sql-psc.goog first + psc_dns_names.sort(key=lambda x: x.endswith(".sql-psc.goog"), reverse=True) # Fall back do dns_name field if dns_names is not set - if dns_name is None: + if not psc_dns_names: dns_name = ret_dict.get("dnsName", None) + if dns_name: + psc_dns_names = [dns_name.rstrip(".")] - # Remove trailing period from DNS name. Required for SSL in Python - if dns_name: - ip_addresses["PSC"] = dns_name.rstrip(".") + if psc_dns_names: + ip_addresses["PSC"] = psc_dns_names return { "ip_addresses": ip_addresses, @@ -177,6 +179,43 @@ async def _get_metadata( "database_version": ret_dict["databaseVersion"], } + async def resolve_connect_settings( + self, + dns_name: str, + location: str, + ) -> dict[str, Any]: + """Asynchronously calls the resolveConnectSettings endpoint to resolve a + PSC DNS name to a connection name. + + Args: + dns_name (str): The DNS name of the Cloud SQL instance. + location (str): The region/location of the instance. + + Returns: + A dictionary containing the resolve response (e.g. connectionName). + """ + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + + url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/dns/{dns_name}/locations/{location}:resolveConnectSettings" + + resp = await self._client.get(url, headers=headers) + if resp.status >= 500: + resp = await retry_50x(self._client.get, url, headers=headers) + try: + ret_dict = await resp.json() + if resp.status >= 400: + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + except Exception: + pass + finally: + resp.raise_for_status() + + return ret_dict + async def _get_ephemeral( self, project: str, diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index c9e48935f..9e160ac6f 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -122,6 +122,17 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str: """Returns the first IP address for the instance, according to the preference supplied by ip_type. If no IP addressess with the given preference are found, an error is raised.""" + if ip_type.value in self.ip_addrs: + return self.ip_addrs[ip_type.value][0] + raise CloudSQLIPTypeError( + "Cloud SQL instance does not have any IP addresses matching " + f"preference: {ip_type.value}" + ) + + def get_preferred_ips(self, ip_type: IPTypes) -> list[str]: + """Returns all IP addresses for the instance, according to the preference + supplied by ip_type. If no IP addressess with the given preference are found, + an error is raised.""" if ip_type.value in self.ip_addrs: return self.ip_addrs[ip_type.value] raise CloudSQLIPTypeError( diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 798969c2c..326f06d62 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -18,6 +18,7 @@ import asyncio from functools import partial +import ipaddress import logging import os import socket @@ -49,6 +50,27 @@ logger = logging.getLogger(name=__name__) + +def _is_ip_address(ip: str) -> bool: + try: + ipaddress.ip_address(ip) + return True + except ValueError: + return False + + +def _get_fallback_ips( + current_ips: list[str], ip_addresses: dict[str, list[str]] +) -> list[str]: + if not current_ips: + return current_ips + if _is_ip_address(current_ips[0]): + return current_ips + fallback = ip_addresses.get("PRIVATE") + if not fallback: + fallback = ip_addresses.get("PRIMARY") + return fallback if fallback else current_ips + ASYNC_DRIVERS = ["asyncpg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" @@ -316,6 +338,8 @@ async def connect_async( user_agent=self._user_agent, driver=driver, ) + if hasattr(self._resolver, "set_client"): + self._resolver.set_client(self._client) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) conn_name = await self._resolver.resolve(instance_connection_string) @@ -384,13 +408,14 @@ async def connect_async( conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) - ip_address = conn_info.get_preferred_ip(ip_type) + preferred_ips = conn_info.get_preferred_ips(ip_type) except Exception: # with an error from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error await self._remove_cached(str(conn_name), enable_iam_auth) raise + targets = [] # If the connector is configured with a custom DNS name, attempt to use # that DNS name to connect to the instance. Fall back to the metadata IP # address if the DNS name does not resolve to an IP address. @@ -398,26 +423,35 @@ async def connect_async( try: ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) if ips: - ip_address = ips[0] + targets.extend(ips) logger.debug( f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + f"'{conn_info.conn_name.domain_name}' resolved to '{ips}', " "using it to connect" ) else: + fallback_ips = _get_fallback_ips( + preferred_ips, conn_info.ip_addrs + ) logger.debug( f"['{instance_connection_string}']: Custom DNS name " f"'{conn_info.conn_name.domain_name}' resolved but returned no " - f"entries, using '{ip_address}' from instance metadata" + f"entries, using '{fallback_ips[0]}' from instance metadata" ) + targets.extend(fallback_ips) except Exception as e: + fallback_ips = _get_fallback_ips( + preferred_ips, conn_info.ip_addrs + ) logger.debug( f"['{instance_connection_string}']: Custom DNS name " f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " - f"address: {e}, using '{ip_address}' from instance metadata" + f"address: {e}, using '{fallback_ips[0]}' from instance metadata" ) + targets.extend(fallback_ips) + else: + targets.extend(preferred_ips) - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: formatted_user = format_database_user( @@ -428,32 +462,56 @@ async def connect_async( f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" ) kwargs["user"] = formatted_user + try: - # async drivers are unblocking and can be awaited directly - if driver in ASYNC_DRIVERS: - return await connector( - ip_address, - await conn_info.create_ssl_context(enable_iam_auth), - **kwargs, - ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + last_ex = None + for target_ip in targets: + logger.debug(f"['{conn_info.conn_name}']: Connecting to {target_ip}:3307") + try: + # async drivers are unblocking and can be awaited directly + if driver in ASYNC_DRIVERS: + conn = await connector( + target_ip, + await conn_info.create_ssl_context(enable_iam_auth), + **kwargs, + ) + last_ex = None + return conn + + # Create socket with SSLContext for sync drivers + ctx = await conn_info.create_ssl_context(enable_iam_auth) + raw_sock = socket.create_connection((target_ip, SERVER_PROXY_PORT)) + try: + sock = ctx.wrap_socket( + raw_sock, + server_hostname=target_ip, + ) + except Exception: + raw_sock.close() + raise + + # If this connection was opened using a domain name, then store it + # for later in case we need to forcibly close it on failover. + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(sock) + # Synchronous drivers are blocking and run using executor + connect_partial = partial( + connector, + target_ip, + sock, + **kwargs, + ) + conn = await self._loop.run_in_executor(None, connect_partial) + last_ex = None + return conn + except Exception as e: + logger.debug( + f"['{conn_info.conn_name}']: Connection to {target_ip} failed: {e}" + ) + last_ex = e + + if last_ex: + raise last_ex except Exception: # with any exception, we attempt a force refresh, then throw the error diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index e255f328a..51e685426 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +import re +from typing import Any, List import dns.asyncresolver @@ -24,6 +25,10 @@ from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError +PSC_DNS_PATTERN = re.compile( + r"^([a-f0-9]{12})\.([^.]+)\.([a-z0-9]+-[a-z0-9]+)\.(sql|sql-psa|sql-psc)\.goog\.?$" +) + class DefaultResolver: """DefaultResolver simply validates and parses instance connection name.""" @@ -38,54 +43,115 @@ class DnsResolver(dns.asyncresolver.Resolver): TXT records in DNS. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._client: Any = None + + def set_client(self, client: Any) -> None: + self._client = client + async def resolve(self, dns: str) -> ConnectionName: # type: ignore - try: - conn_name = _parse_connection_name(dns) - except ValueError: - # The connection name was not project:region:instance format. - # Check if connection name is a valid DNS domain name - if _is_valid_domain(dns): - # Attempt to query a TXT record to get connection name. - conn_name = await self.query_dns(dns) - else: + current = dns + visited = {current} + + # Max 10 iterations to prevent infinite CNAME loops + for _ in range(10): + try: + domain_val = dns if current != dns else "" + conn_name = _parse_connection_name_with_domain_name( + current, domain_val + ) + return conn_name + except ValueError: + pass + + dns_normalized = current.rstrip(".") + match = PSC_DNS_PATTERN.match(dns_normalized.lower()) + if match: + region = match.group(3) + if self._client is None: + raise ValueError( + "SQLAdmin client is not configured in the resolver." + ) + + dns_name_with_dot = dns_normalized + "." + resp = await self._client.resolve_connect_settings( + dns_name_with_dot, region + ) + resolved_conn_name = resp["connectionName"] + return _parse_connection_name_with_domain_name( + resolved_conn_name, dns + ) + + if not _is_valid_domain(current): raise ValueError( "Arg `instance_connection_string` must have " "format: PROJECT:REGION:INSTANCE or be a valid DNS domain " f"name, got {dns}." ) - return conn_name - async def resolve_a_record(self, dns: str) -> List[str]: - try: - # Attempt to query the A records. - records = await super().resolve(dns, "A", raise_on_no_answer=True) - # return IP addresses as strings - return [record.to_text() for record in records] - except Exception: - # On any error, return empty list - return [] + cname_found = False + try: + cname = await self.resolve_cname(current) + cname_found = True + except DnsResolutionError: + pass + + if cname_found: + if cname in visited: + raise DnsResolutionError(f"CNAME loop detected for `{dns}`") + visited.add(cname) + current = cname + continue + + try: + rdata = await self.resolve_txt(current) + except DnsResolutionError as e: + raise DnsResolutionError( + f"Unable to resolve TXT record for `{dns}`" + ) from e - async def query_dns(self, dns: str) -> ConnectionName: - try: - # Attempt to query the TXT records. - records = await super().resolve(dns, "TXT", raise_on_no_answer=True) - # Sort the TXT record values alphabetically, strip quotes as record - # values can be returned as raw strings - rdata = [record.to_text().strip('"') for record in records] rdata.sort() - # Attempt to parse records, returning the first valid record. for record in rdata: try: - conn_name = _parse_connection_name_with_domain_name(record, dns) + conn_name = _parse_connection_name_with_domain_name( + record, dns + ) return conn_name except Exception: continue - # If all records failed to parse, throw error + raise DnsResolutionError( - f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" + f"Unable to parse TXT record for `{current}` -> `{rdata[0]}`" + if rdata + else f"Unable to resolve TXT record for `{current}`" ) - # Don't override above DnsResolutionError - except DnsResolutionError: - raise + + raise DnsResolutionError( + f"CNAME loop detected or max resolution depth reached for `{dns}`" + ) + + async def resolve_cname(self, dns: str) -> str: + try: + answers = await super().resolve(dns, "CNAME", raise_on_no_answer=True) + return str(answers[0].target).rstrip(".") + except Exception as e: + raise DnsResolutionError( + f"Unable to resolve CNAME record for `{dns}`" + ) from e + + async def resolve_txt(self, dns: str) -> List[str]: + try: + answers = await super().resolve(dns, "TXT", raise_on_no_answer=True) + return [record.to_text().strip('"') for record in answers] except Exception as e: - raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e + raise DnsResolutionError( + f"Unable to resolve TXT record for `{dns}`" + ) from e + + async def resolve_a_record(self, dns: str) -> List[str]: + try: + records = await super().resolve(dns, "A", raise_on_no_answer=True) + return [record.to_text() for record in records] + except Exception: + return [] diff --git a/requirements-test.txt b/requirements-test.txt index 296878dd8..1e07996cc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -11,3 +11,4 @@ asyncpg==0.31.0 python-tds==1.17.1 aioresponses==0.7.8 pytest-aiohttp==1.1.0 +aiohttp<3.14 diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 66bf64a32..a61ded9b7 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -227,6 +227,7 @@ def __init__( "PRIMARY": "127.0.0.1", "PRIVATE": "10.0.0.1", }, + dns_names: list = ["abcde.12345.us-central1.sql.goog"], legacy_dns_name: bool = False, cert_before: datetime = datetime.datetime.now(datetime.timezone.utc), cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc) @@ -237,6 +238,7 @@ def __init__( self.name = name self.db_version = db_version self.ip_addrs = ip_addrs + self.dns_names = dns_names self.psc_enabled = False self.cert_before = cert_before self.cert_expiration = cert_expiration @@ -265,14 +267,15 @@ async def connect_settings(self, request: Any) -> web.Response: "databaseVersion": self.db_version, } if self.legacy_dns_name: - response["dnsName"] = "abcde.12345.us-central1.sql.goog" + response["dnsName"] = self.dns_names[0] if self.dns_names else None else: response["dnsNames"] = [ { - "name": "abcde.12345.us-central1.sql.goog", + "name": name, "connectionType": "PRIVATE_SERVICE_CONNECT", "dnsScope": "INSTANCE", } + for name in self.dns_names ] return web.Response(content_type="application/json", body=json.dumps(response)) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index cfe509470..ecccdbd51 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -38,8 +38,8 @@ async def test_get_metadata_no_psc(fake_client: CloudSQLClient) -> None: ) assert resp["database_version"] == "POSTGRES_15" assert resp["ip_addresses"] == { - "PRIMARY": "127.0.0.1", - "PRIVATE": "10.0.0.1", + "PRIMARY": ["127.0.0.1"], + "PRIVATE": ["10.0.0.1"], } assert isinstance(resp["server_ca_cert"], str) @@ -58,9 +58,9 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None: ) assert resp["database_version"] == "POSTGRES_15" assert resp["ip_addresses"] == { - "PRIMARY": "127.0.0.1", - "PRIVATE": "10.0.0.1", - "PSC": "abcde.12345.us-central1.sql.goog", + "PRIMARY": ["127.0.0.1"], + "PRIVATE": ["10.0.0.1"], + "PSC": ["abcde.12345.us-central1.sql.goog"], } assert isinstance(resp["server_ca_cert"], str) @@ -80,9 +80,9 @@ async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> ) assert resp["database_version"] == "POSTGRES_15" assert resp["ip_addresses"] == { - "PRIMARY": "127.0.0.1", - "PRIVATE": "10.0.0.1", - "PSC": "abcde.12345.us-central1.sql.goog", + "PRIMARY": ["127.0.0.1"], + "PRIVATE": ["10.0.0.1"], + "PSC": ["abcde.12345.us-central1.sql.goog"], } assert isinstance(resp["server_ca_cert"], str) @@ -290,3 +290,35 @@ async def test_get_ephemeral_error_parsing_json( assert exc_info.value.status == 404 assert exc_info.value.message == "Not Found" await client.close() + + +@pytest.mark.asyncio +async def test_get_metadata_multiple_psc_dns_sorted(fake_client: CloudSQLClient) -> None: + """ + Test _get_metadata returns successfully with multiple PSC IP types sorted. + """ + fake_client.instance.psc_enabled = True + fake_client.instance.legacy_dns_name = False + fake_client.instance.dns_names = [ + "dns1.sql.goog", + "dns2.sql-psc.goog", + "dns3.sql.goog", + ] + try: + resp = await fake_client._get_metadata( + "test-project", + "test-region", + "test-instance", + ) + assert resp["database_version"] == "POSTGRES_15" + assert resp["ip_addresses"] == { + "PRIMARY": ["127.0.0.1"], + "PRIVATE": ["10.0.0.1"], + "PSC": ["dns2.sql-psc.goog", "dns1.sql.goog", "dns3.sql.goog"], + } + assert isinstance(resp["server_ca_cert"], str) + finally: + fake_client.instance.psc_enabled = False + fake_client.instance.legacy_dns_name = False + fake_client.instance.dns_names = ["abcde.12345.us-central1.sql.goog"] + diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a09b5b72f..a37246f36 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -659,3 +659,63 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback( fake_client.instance.ip_addrs = original_ips +@pytest.mark.asyncio +async def test_Connector_connect_async_custom_dns_resolver_fallback_psc_to_private_ip( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async falls back to Private IP if CNAME/PSC DNS resolution fails.""" + + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" + ) as mock_resolve_a: + # DNS resolution fails + mock_resolve_a.return_value = [] + + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve" + ) as mock_resolve: + conn_name_with_domain = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + mock_resolve.return_value = conn_name_with_domain + + async with Connector( + credentials=fake_credentials, + loop=asyncio.get_running_loop(), + resolver=DnsResolver, + ip_type="PSC", # Use PSC IP type + ) as connector: + connector._client = fake_client + + original_ips = fake_client.instance.ip_addrs + # Configure instance to be PSC enabled, but also have a PRIVATE IP fallback! + fake_client.instance.psc_enabled = True + fake_client.instance.ip_addrs = { + "PSC": "1ad3b5d73f10.3oxon2yfo9tob.us-east1.sql.goog", + "PRIVATE": "10.0.0.1", + } + + try: + with patch( + "google.cloud.sql.connector.asyncpg.connect" + ) as mock_connect: + mock_connect.return_value = True + + connection = await connector.connect_async( + "db.example.com", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + + # Verify mock_connect fell back to PRIVATE IP "10.0.0.1"! + args, _ = mock_connect.call_args + assert args[0] == "10.0.0.1" + assert connection is True + finally: + # Restore original IPs + fake_client.instance.ip_addrs = original_ips + fake_client.instance.psc_enabled = False + + diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 3699ddc2d..4eb6ce26c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -249,13 +249,13 @@ async def test_get_preferred_ip_CloudSQLIPTypeError(cache: RefreshAheadCache) -> when missing Public or Private IP addresses. """ instance_metadata: ConnectionInfo = await cache._current - instance_metadata.ip_addrs = {"PRIVATE": "1.1.1.1"} + instance_metadata.ip_addrs = {"PRIVATE": ["1.1.1.1"]} # test error when Public IP is missing with pytest.raises(CloudSQLIPTypeError): instance_metadata.get_preferred_ip(IPTypes.PUBLIC) # test error when Private IP is missing - instance_metadata.ip_addrs = {"PRIMARY": "0.0.0.0"} + instance_metadata.ip_addrs = {"PRIMARY": ["0.0.0.0"]} with pytest.raises(CloudSQLIPTypeError): instance_metadata.get_preferred_ip(IPTypes.PRIVATE) diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index c649e8e58..385fa5925 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import AsyncMock + import dns.message import dns.rdataclass import dns.rdatatype @@ -165,4 +167,114 @@ async def test_DnsResolver_resolve_a_record_empty() -> None: mock_resolve.side_effect = Exception("DNS Error") resolver = DnsResolver() result = await resolver.resolve_a_record("db.example.com") - assert result == [] \ No newline at end of file + assert result == [] + + + +async def test_DnsResolver_with_direct_psc_dns_name() -> None: + """Test DnsResolver resolves direct PSC DNS name using resolve_connect_settings.""" + dns_name = "0123456789ab.fedcba9876543.europe-north2.sql-psc.goog" + real_conn_name = ConnectionName( + "my-project", "europe-north2", "my-instance", dns_name + ) + + mock_client = AsyncMock() + mock_client.resolve_connect_settings.return_value = { + "connectionName": "my-project:europe-north2:my-instance" + } + + resolver = DnsResolver() + resolver.set_client(mock_client) + + result = await resolver.resolve(dns_name) + + assert result == real_conn_name + # Verify mock_client was called with correct trailing dot DNS name! + mock_client.resolve_connect_settings.assert_awaited_once_with( + dns_name + ".", "europe-north2" + ) + + +async def test_DnsResolver_with_cname_resolving_to_psc_dns_name() -> None: + """Test DnsResolver resolves CNAME to PSC DNS and returns proper connection name.""" + dns_name = "db.example.com" + cname_target = "0123456789ab.fedcba9876543.europe-north2.sql-psc.goog" + real_conn_name = ConnectionName( + "my-project", "europe-north2", "my-instance", dns_name + ) + + mock_client = AsyncMock() + mock_client.resolve_connect_settings.return_value = { + "connectionName": "my-project:europe-north2:my-instance" + } + + resolver = DnsResolver() + resolver.set_client(mock_client) + + # Patch resolver CNAME and TXT methods + with patch.object( + resolver, "resolve_cname", AsyncMock(return_value=cname_target) + ), patch.object( + resolver, "resolve_txt", AsyncMock(side_effect=Exception("No TXT")) + ): + + result = await resolver.resolve(dns_name) + + assert result == real_conn_name + mock_client.resolve_connect_settings.assert_awaited_once_with( + cname_target + ".", "europe-north2" + ) + + +async def test_DnsResolver_with_recursive_cnames_to_psc_dns_name() -> None: + """Test DnsResolver resolves recursive CNAMEs to PSC DNS successfully.""" + dns_name = "name1.example.com" + cname2 = "name2.example.com" + cname_target = "0123456789ab.fedcba9876543.europe-north2.sql-psc.goog" + real_conn_name = ConnectionName( + "my-project", "europe-north2", "my-instance", dns_name + ) + + mock_client = AsyncMock() + mock_client.resolve_connect_settings.return_value = { + "connectionName": "my-project:europe-north2:my-instance" + } + + resolver = DnsResolver() + resolver.set_client(mock_client) + + # Mock Lookup CNAME sequence + cname_mock = AsyncMock( + side_effect=lambda name: cname2 if name == dns_name else cname_target + ) + + with patch.object(resolver, "resolve_cname", cname_mock), patch.object( + resolver, "resolve_txt", AsyncMock(side_effect=Exception("No TXT")) + ): + + result = await resolver.resolve(dns_name) + + assert result == real_conn_name + mock_client.resolve_connect_settings.assert_awaited_once_with( + cname_target + ".", "europe-north2" + ) + + +async def test_DnsResolver_cname_loop_throws_error() -> None: + """Test DnsResolver throws error if a CNAME loop is detected.""" + dns_name = "name1.example.com" + cname2 = "name2.example.com" + + resolver = DnsResolver() + + cname_mock = AsyncMock( + side_effect=lambda name: cname2 if name == dns_name else dns_name + ) + + with patch.object(resolver, "resolve_cname", cname_mock), patch.object( + resolver, "resolve_txt", AsyncMock(side_effect=Exception("No TXT")) + ): + + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve(dns_name) + assert "CNAME loop detected" in str(exc_info.value) \ No newline at end of file