diff --git a/ldclient/impl/shims/__init__.py b/ldclient/impl/shims/__init__.py new file mode 100644 index 00000000..278157b1 --- /dev/null +++ b/ldclient/impl/shims/__init__.py @@ -0,0 +1,18 @@ +""" +Async-only helpers backing the async data source, event processor, and data +system. + +``shims.aio`` holds hand-maintained async concurrency helpers (``AsyncEvent``, +``AsyncQueue``, ``AsyncTaskRunner``, ``AsyncRepeatingTask``, ``AsyncWorkerPool``, +etc.) that wrap fiddly asyncio plumbing the async shells would otherwise inline +repeatedly. The sync shells use the equivalent stdlib/SDK primitives +(``threading``, ``queue.Queue``, ``RepeatingTask``, ``FixedThreadPool``, +urllib3) directly, so there is no sync twin for these. + +``shims.aio_transport`` wraps an aiohttp ``ClientSession`` behind a +``TransportResponse`` (see ``shims.transport_types``) so async data-source +callers can inspect a response after the request context has closed and so the +SSL/session setup is written once. The sync side talks to urllib3 directly. + +These are covered by ``ldclient/testing/impl/test_shims.py``. +""" diff --git a/ldclient/impl/shims/aio.py b/ldclient/impl/shims/aio.py new file mode 100644 index 00000000..61722f8e --- /dev/null +++ b/ldclient/impl/shims/aio.py @@ -0,0 +1,367 @@ +""" +Async concurrency helpers used by the async data source, event processor, and +data system shells. Each wraps a piece of fiddly asyncio plumbing (timeout-aware +waits, queue exception normalization, an interval-from-start repeating task, a +bounded task pool) that the shells would otherwise inline repeatedly. The sync +shells use the equivalent stdlib/SDK primitives (``threading.Event``/``Lock``, +``queue.Queue``, ``ReadWriteLock``, ``RepeatingTask``, ``FixedThreadPool``) +directly, so these helpers have no sync twin. +""" + +import asyncio +import inspect +import time +from contextlib import contextmanager +from queue import Empty as QueueEmpty # noqa: F401 (shared timeout exception) +from queue import Full as QueueFull # noqa: F401 (shared capacity exception) +from typing import Any, Callable, Optional, Set + +from ldclient.impl.util import log + + +class AsyncEvent: + """Wraps ``asyncio.Event``, adding a ``wait(timeout)`` that returns False on + timeout instead of raising, to match ``threading.Event.wait``.""" + + def __init__(self): + self._event = asyncio.Event() + + def set(self) -> None: + self._event.set() + + def clear(self) -> None: + self._event.clear() + + def is_set(self) -> bool: + return self._event.is_set() + + async def wait(self, timeout: Optional[float] = None) -> bool: + if timeout is None: + await self._event.wait() + return True + try: + await asyncio.wait_for(self._event.wait(), timeout) + return True + except asyncio.TimeoutError: + return False + + +class AsyncLock: + """Wraps ``asyncio.Lock`` as an async context manager.""" + + def __init__(self): + self._lock = asyncio.Lock() + + async def __aenter__(self): + await self._lock.acquire() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self._lock.release() + return False + + def locked(self) -> bool: + return self._lock.locked() + + +class AsyncRWLock: + """No-op read/write lock with the surface of + ``ldclient.impl.rwlock.ReadWriteLock``. Code that holds the lock without + awaiting in between cannot be preempted within a single event loop, so no + real locking is needed; the sync twin is the real ``ReadWriteLock``.""" + + @contextmanager + def read(self): + yield self + + @contextmanager + def write(self): + yield self + + +class AsyncQueue: + """Wraps ``asyncio.Queue`` with a ``get(timeout)`` that raises the shared + ``QueueEmpty`` on timeout, and a ``put_nowait`` that raises the shared + ``QueueFull`` when a bounded queue is at capacity.""" + + def __init__(self, maxsize: int = 0): + self._queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize) + + async def put(self, item: Any) -> None: + await self._queue.put(item) + + def put_nowait(self, item: Any) -> None: + try: + self._queue.put_nowait(item) + except asyncio.QueueFull: + raise QueueFull() from None + + async def get(self, timeout: Optional[float] = None, block: bool = True) -> Any: + if not block: + try: + return self._queue.get_nowait() + except asyncio.QueueEmpty: + raise QueueEmpty() from None + if timeout is None: + return await self._queue.get() + try: + return await asyncio.wait_for(self._queue.get(), timeout) + except asyncio.TimeoutError: + raise QueueEmpty() + + def empty(self) -> bool: + return self._queue.empty() + + +# The handle type returned by spawn_handle. +TaskHandle = asyncio.Task + + +def _log_task_exception(task: asyncio.Task) -> None: + if not task.cancelled() and task.exception() is not None: + log.error("Unhandled exception in background task", exc_info=task.exception()) + + +def spawn_handle(name: str, fn: Callable) -> TaskHandle: + """Starts ``fn()`` as a background task and returns the task handle. + Unhandled exceptions are logged.""" + task = asyncio.ensure_future(fn()) + try: + task.set_name(name) + except AttributeError: + pass + task.add_done_callback(_log_task_exception) + return task + + +async def join_handle(handle: TaskHandle, timeout: float) -> None: + """Waits up to ``timeout`` seconds for a spawned task to finish, mirroring + ``Thread.join(timeout)``: the task's exception (if any) is not re-raised. + On timeout the task is cancelled so it does not leak.""" + try: + await asyncio.wait_for(asyncio.shield(handle), timeout) + except asyncio.TimeoutError: + handle.cancel() + except asyncio.CancelledError: + # The awaited task was cancelled elsewhere; do not treat that as a + # cancellation of the caller. + if handle.cancelled(): + return + raise + except Exception: + pass + + +async def resolve(value: Any) -> Any: + """Awaits and returns ``value`` if it is awaitable, returns it directly + otherwise, letting shared code consume results from duck-typed + sync-or-async components.""" + if inspect.isawaitable(value): + return await value + return value + + +_STOP = object() + + +class _SyncGenAdapter: + """Async iterator over a synchronous generator; each ``next()`` call runs + on an executor thread so it does not block the event loop.""" + + def __init__(self, gen): + self._gen = gen + + def __aiter__(self): + return self + + async def __anext__(self): + def _next(): + try: + return next(self._gen) + except StopIteration: + return _STOP + + loop = asyncio.get_running_loop() + value = await loop.run_in_executor(None, _next) + if value is _STOP: + raise StopAsyncIteration + return value + + +def iterate(gen: Any) -> Any: + """Adapts a sync-or-async generator for async iteration. Async generators + are returned unchanged; synchronous generators are driven on executor + threads.""" + if hasattr(gen, '__aiter__'): + return gen + return _SyncGenAdapter(gen) + + +class AsyncCallbackScheduler: + """Bridges sync notification paths to async callbacks: ``call`` schedules + a coroutine callback onto the event loop captured at construction time, + logging any unhandled exception. Safe to invoke from any thread.""" + + def __init__(self): + self._loop = asyncio.get_running_loop() + + def call(self, fn: Callable, *args) -> None: + future = asyncio.run_coroutine_threadsafe(fn(*args), self._loop) + future.add_done_callback(self._log_exception) + + @staticmethod + def _log_exception(future) -> None: + if not future.cancelled() and future.exception() is not None: + log.error("Unhandled exception in scheduled callback", exc_info=future.exception()) + + +class AsyncTaskRunner: + """Spawns named background tasks and stops them all on demand.""" + + def __init__(self): + self._tasks: Set[asyncio.Task] = set() + self._stopped = False + + def spawn(self, name: str, fn: Callable) -> asyncio.Task: + """Starts ``fn()`` as a background task and returns the task handle. + Unhandled exceptions are logged.""" + task = asyncio.ensure_future(fn()) + try: + task.set_name(name) + except AttributeError: + pass + task.add_done_callback(self._on_done) + self._tasks.add(task) + return task + + def _on_done(self, task: asyncio.Task) -> None: + self._tasks.discard(task) + if not task.cancelled() and task.exception() is not None: + log.error("Unhandled exception in background task", exc_info=task.exception()) + + def is_stopped(self) -> bool: + return self._stopped + + async def stop_all(self, timeout: float = 1) -> None: + """Cancels all running tasks and waits for them to finish, logging a + warning for any that do not terminate within ``timeout`` seconds.""" + self._stopped = True + tasks = list(self._tasks) + if not tasks: + return + for task in tasks: + task.cancel() + _, pending = await asyncio.wait(tasks, timeout=timeout) + for task in pending: + log.warning("Task %s did not terminate in time", task.get_name()) + + +class AsyncRepeatingTask: + """Calls a callback repeatedly at fixed intervals on a background task. + Mirrors the semantics of ``ldclient.impl.repeating_task.RepeatingTask``: + the interval is measured from the start of each invocation, exceptions + from the callback are logged, and ``stop()`` prevents any further + invocations but cannot be undone.""" + + def __init__(self, label: str, interval: float, initial_delay: float, callable: Callable): + self.__label = label + self.__interval = interval + self.__initial_delay = initial_delay + self.__action = callable + self.__stop = AsyncEvent() + self.__task: Optional[asyncio.Task] = None + + def start(self): + """Starts the background task. Like a thread, the task can only be + started once.""" + if self.__task is not None: + raise RuntimeError("tasks can only be started once") + self.__task = asyncio.ensure_future(self._run()) + try: + self.__task.set_name(f"{self.__label}.repeating") + except AttributeError: + pass + + def stop(self): + """Tells the background task to stop. It cannot be restarted after this.""" + self.__stop.set() + task = self.__task + # When stop() is called from within the action itself, let the loop + # exit via the stop event rather than cancelling the current task. + if task is not None and task is not asyncio.current_task(): + task.cancel() + + async def _run(self): + try: + if self.__initial_delay > 0: + if await self.__stop.wait(self.__initial_delay): + return + stopped = self.__stop.is_set() + while not stopped: + next_time = time.time() + self.__interval + try: + result = self.__action() + if inspect.isawaitable(result): + await result + except asyncio.CancelledError: + raise + except Exception as e: + log.exception("Unexpected exception on worker task: %s" % e) + delay = next_time - time.time() + if delay > 0: + stopped = await self.__stop.wait(delay) + else: + # Yield to the event loop between back-to-back invocations + await asyncio.sleep(0) + stopped = self.__stop.is_set() + except asyncio.CancelledError: + pass + + +class AsyncWorkerPool: + """A fixed-size pool of concurrent tasks that rejects jobs when its limit + is reached. Matches the contract of + ``ldclient.impl.fixed_thread_pool.FixedThreadPool``.""" + + def __init__(self, size: int, name: str): + self._size = size + self._name = name + self._busy: Set[asyncio.Task] = set() + self._event = AsyncEvent() + self._stopped = False + + def execute(self, jobFn: Callable) -> bool: + """Schedules a job for execution if the pool is not already at its + limit, and returns True if successful; returns False if all workers + are busy.""" + if self._stopped or len(self._busy) >= self._size: + return False + task = asyncio.ensure_future(self._run_job(jobFn)) + self._busy.add(task) + return True + + async def _run_job(self, jobFn: Callable) -> None: + try: + result = jobFn() + if inspect.isawaitable(result): + await result + except Exception: + log.warning('Unhandled exception in worker thread', exc_info=True) + finally: + task = asyncio.current_task() + if task is not None: + self._busy.discard(task) + self._event.set() + + async def wait(self) -> None: + """Waits until all currently busy workers have completed their jobs.""" + while len(self._busy) > 0: + self._event.clear() + if len(self._busy) == 0: + return + await self._event.wait() + + def stop(self) -> None: + """Tells the pool to reject any further jobs; active jobs run to + completion.""" + self._stopped = True diff --git a/ldclient/impl/shims/aio_transport.py b/ldclient/impl/shims/aio_transport.py new file mode 100644 index 00000000..e8983379 --- /dev/null +++ b/ldclient/impl/shims/aio_transport.py @@ -0,0 +1,154 @@ +""" +Async HTTP transport shim, wrapping an aiohttp ``ClientSession``. The sync +twin in ``ldclient.impl.shims.sync_transport`` wraps a urllib3 +``PoolManager`` with an identical method surface. +""" + +import ssl +from typing import Optional, Union + +import aiohttp +import certifi +from ld_eventsource.async_client import AsyncSSEClient +from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy +from ld_eventsource.config.error_strategy import ErrorStrategy +from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy + +from ldclient.impl.http import _base_headers +from ldclient.impl.shims.transport_types import TransportResponse +from ldclient.impl.util import log + +# Allows up to 5 minutes to elapse without any data sent across the stream. +# Heartbeats sent as comments will keep this from triggering. +STREAM_READ_TIMEOUT = 5 * 60 + +MAX_RETRY_DELAY = 30 +BACKOFF_RESET_INTERVAL = 60 +JITTER_RATIO = 0.5 + + +def make_client_session(config, http_options=None) -> aiohttp.ClientSession: + """Creates an ``aiohttp.ClientSession`` configured from the SDK config's + HTTP options (CA certs, client certs, SSL verification, proxy trust). + ``http_options`` overrides the config's HTTP options when given.""" + http_config = http_options if http_options is not None else config.http + ssl_ctx = ssl.create_default_context(cafile=http_config.ca_certs or certifi.where()) + if http_config.cert_file: + ssl_ctx.load_cert_chain(http_config.cert_file) + if http_config.disable_ssl_verification: + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + log.warning("TLS verification disabled") + + connector = aiohttp.TCPConnector(ssl=ssl_ctx, limit_per_host=10) + return aiohttp.ClientSession( + connector=connector, + trust_env=(http_config.http_proxy is None), + ) + + +class AsyncHTTPTransport: + """Performs HTTP requests over an aiohttp ``ClientSession``. + + If no session is supplied, one is created lazily from the config on first + use and owned (closed) by the transport; a supplied session remains owned + by the caller. ``http_options`` overrides the config's HTTP options for + lazy session construction and default request timeouts. + """ + + def __init__( + self, + config, + client: Optional[aiohttp.ClientSession] = None, + http_options=None, + ): + self._config = config + self._http_options = http_options if http_options is not None else config.http + self._client = client + self._owns_client = client is None + self._proxy = self._http_options.http_proxy or None + + async def request( + self, + method: str, + uri: str, + headers: Optional[dict] = None, + body: Optional[Union[bytes, str]] = None, + connect_timeout: Optional[float] = None, + read_timeout: Optional[float] = None, + ) -> TransportResponse: + """Performs a request.""" + if self._client is None: + self._client = make_client_session(self._config, self._http_options) + timeout = aiohttp.ClientTimeout( + connect=connect_timeout if connect_timeout is not None else self._http_options.connect_timeout, + sock_read=read_timeout if read_timeout is not None else self._http_options.read_timeout, + ) + async with self._client.request( + method, + uri, + headers=headers, + data=body, + timeout=timeout, + proxy=self._proxy, + ) as response: + text = await response.text(encoding='UTF-8', errors='replace') + return TransportResponse(response.status, response.headers, text) + + async def close(self) -> None: + """Closes the underlying session if this transport created it.""" + if self._owns_client and self._client is not None: + await self._client.close() + self._client = None + + +class AsyncSSEFactory: + """Creates configured ``AsyncSSEClient`` instances for streaming connections. + + A supplied aiohttp session remains owned by the caller; the SSE client + never closes it. Callers are expected to supply a session built from the + SDK's HTTP options (via ``make_client_session``) so the streaming + connection uses the configured certs, SSL settings, and proxy trust. + ``http_options`` overrides the config's HTTP options for connection + timeouts and proxy settings. + """ + + def __init__(self, config, session: Optional[aiohttp.ClientSession] = None, proxy: Optional[str] = None, http_options=None): + self._config = config + self._session = session + self._http_options = http_options if http_options is not None else config.http + self._proxy = proxy if proxy is not None else (self._http_options.http_proxy or None) + + def create(self, url: str, initial_retry_delay: float, query_params=None) -> AsyncSSEClient: + """Builds an SSE client for the given stream URL. Headers, timeouts, + proxy settings, and the retry/backoff policy come from the SDK config. + ``query_params`` is an optional zero-argument callable evaluated on + each (re)connect to produce additional query string parameters.""" + base_headers = _base_headers(self._config) + aiohttp_request_options: dict = { + "timeout": aiohttp.ClientTimeout( + total=None, + connect=self._http_options.connect_timeout, + sock_read=STREAM_READ_TIMEOUT, + ) + } + if self._proxy: + aiohttp_request_options["proxy"] = self._proxy + return AsyncSSEClient( + connect=AsyncConnectStrategy.http( + url=url, + headers=base_headers, + session=self._session, + aiohttp_request_options=aiohttp_request_options, + query_params=query_params, + ), + error_strategy=ErrorStrategy.always_continue(), # we'll make error-handling decisions when we see a Fault + initial_retry_delay=initial_retry_delay, + retry_delay_strategy=RetryDelayStrategy.default( + max_delay=MAX_RETRY_DELAY, + backoff_multiplier=2, + jitter_multiplier=JITTER_RATIO, + ), + retry_delay_reset_threshold=BACKOFF_RESET_INTERVAL, + logger=log, + ) diff --git a/ldclient/impl/shims/transport_types.py b/ldclient/impl/shims/transport_types.py new file mode 100644 index 00000000..e39d243c --- /dev/null +++ b/ldclient/impl/shims/transport_types.py @@ -0,0 +1,19 @@ +""" +Shared types for the async/sync transport shims. +""" + +from typing import Mapping + + +class TransportResponse: + """A minimal uniform HTTP response: status code, headers, and decoded body. + + ``headers`` is whatever case-insensitive mapping the underlying HTTP + library produced, so lookups like ``headers.get('ETag')`` work regardless + of the casing the server used. + """ + + def __init__(self, status: int, headers: Mapping[str, str], body: str): + self.status = status + self.headers = headers + self.body = body diff --git a/ldclient/testing/impl/test_shims.py b/ldclient/testing/impl/test_shims.py new file mode 100644 index 00000000..4964129a --- /dev/null +++ b/ldclient/testing/impl/test_shims.py @@ -0,0 +1,392 @@ +""" +Tests for the async concurrency and HTTP transport helpers in +``ldclient/impl/shims/``. ``shims.aio`` holds the async concurrency +primitives; ``shims.aio_transport`` wraps an aiohttp session for HTTP and SSE. +""" + +import asyncio +import ssl +import threading +import time + +import aiohttp +import pytest +from ld_eventsource.async_client import AsyncSSEClient + +from ldclient.config import Config, HTTPConfig +from ldclient.impl.shims import aio +from ldclient.impl.shims.aio_transport import ( + AsyncHTTPTransport, + AsyncSSEFactory, + make_client_session +) + + +async def _async_wait_until(predicate, timeout=2.0): + deadline = time.time() + timeout + while not predicate(): + assert time.time() < deadline, "timed out waiting for condition" + await asyncio.sleep(0.01) + + +# --------------------------------------------------------------------------- +# aio.AsyncEvent +# --------------------------------------------------------------------------- + +class TestAsyncEvent: + @pytest.mark.asyncio + async def test_set_clear_is_set(self): + event = aio.AsyncEvent() + assert not event.is_set() + event.set() + assert event.is_set() + assert await event.wait(1) is True + event.clear() + assert not event.is_set() + + @pytest.mark.asyncio + async def test_wait_timeout_returns_false(self): + event = aio.AsyncEvent() + start = time.time() + assert await event.wait(0.05) is False + assert time.time() - start >= 0.04 + + @pytest.mark.asyncio + async def test_wait_wakes_when_set(self): + event = aio.AsyncEvent() + + async def setter(): + await asyncio.sleep(0.02) + event.set() + + task = asyncio.ensure_future(setter()) + assert await event.wait(2) is True + await task + + +# --------------------------------------------------------------------------- +# aio.AsyncLock +# --------------------------------------------------------------------------- + +class TestAsyncLock: + @pytest.mark.asyncio + async def test_context_manager_tracks_locked_state(self): + lock = aio.AsyncLock() + assert not lock.locked() + async with lock: + assert lock.locked() + assert not lock.locked() + + @pytest.mark.asyncio + async def test_mutual_exclusion(self): + lock = aio.AsyncLock() + counter = {'value': 0, 'concurrent': 0, 'max_concurrent': 0} + + async def work(): + async with lock: + counter['concurrent'] += 1 + counter['max_concurrent'] = max(counter['max_concurrent'], counter['concurrent']) + await asyncio.sleep(0.01) + counter['value'] += 1 + counter['concurrent'] -= 1 + + await asyncio.gather(*(work() for _ in range(5))) + assert counter['value'] == 5 + assert counter['max_concurrent'] == 1 + + +# --------------------------------------------------------------------------- +# aio.AsyncRepeatingTask +# --------------------------------------------------------------------------- + +class TestAsyncRepeatingTask: + @pytest.mark.asyncio + async def test_fires_repeatedly_then_stops_cleanly(self): + counts = {'n': 0} + + async def action(): + counts['n'] += 1 + + task = aio.AsyncRepeatingTask("test.repeating", 0.01, 0, action) + task.start() + await _async_wait_until(lambda: counts['n'] >= 3) + task.stop() + await asyncio.sleep(0.05) + snapshot = counts['n'] + await asyncio.sleep(0.05) + assert counts['n'] == snapshot + + @pytest.mark.asyncio + async def test_initial_delay_respected(self): + counts = {'n': 0} + + async def action(): + counts['n'] += 1 + + task = aio.AsyncRepeatingTask("test.repeating", 0.01, 0.1, action) + task.start() + await asyncio.sleep(0.03) + assert counts['n'] == 0 + task.stop() + + @pytest.mark.asyncio + async def test_continues_after_action_exception(self): + counts = {'n': 0} + + async def action(): + counts['n'] += 1 + raise RuntimeError("boom") + + task = aio.AsyncRepeatingTask("test.repeating", 0.01, 0, action) + task.start() + await _async_wait_until(lambda: counts['n'] >= 2) + task.stop() + + @pytest.mark.asyncio + async def test_stop_from_within_action(self): + counts = {'n': 0} + holder = {} + + async def action(): + counts['n'] += 1 + holder['task'].stop() + + holder['task'] = aio.AsyncRepeatingTask("test.repeating", 0.01, 0, action) + holder['task'].start() + await asyncio.sleep(0.1) + assert counts['n'] == 1 + + @pytest.mark.asyncio + async def test_second_start_raises(self): + async def action(): + pass + + task = aio.AsyncRepeatingTask("test.repeating", 0.01, 0, action) + task.start() + with pytest.raises(RuntimeError): + task.start() + task.stop() + + +# --------------------------------------------------------------------------- +# aio.AsyncCallbackScheduler +# --------------------------------------------------------------------------- + +class TestAsyncCallbackScheduler: + @pytest.mark.asyncio + async def test_call_schedules_coroutine_with_args(self): + scheduler = aio.AsyncCallbackScheduler() + received = [] + + async def cb(value): + received.append(value) + + scheduler.call(cb, 'value') + await _async_wait_until(lambda: received == ['value']) + + @pytest.mark.asyncio + async def test_call_works_from_worker_thread(self): + scheduler = aio.AsyncCallbackScheduler() + received = [] + + async def cb(value): + received.append(value) + + thread = threading.Thread(target=lambda: scheduler.call(cb, 'threaded')) + thread.start() + thread.join() + await _async_wait_until(lambda: received == ['threaded']) + + @pytest.mark.asyncio + async def test_call_swallows_callback_exception(self): + scheduler = aio.AsyncCallbackScheduler() + completed = aio.AsyncEvent() + + async def boom(): + completed.set() + raise ValueError("boom") + + scheduler.call(boom) # must not raise + assert await completed.wait(2) + await asyncio.sleep(0.05) # let the done callback run + + +# --------------------------------------------------------------------------- +# aio.AsyncTaskRunner +# --------------------------------------------------------------------------- + +class TestAsyncTaskRunner: + @pytest.mark.asyncio + async def test_spawn_runs_function(self): + runner = aio.AsyncTaskRunner() + done = aio.AsyncEvent() + + async def fn(): + done.set() + + runner.spawn("test.task", fn) + assert await done.wait(2) + await runner.stop_all() + assert runner.is_stopped() + + @pytest.mark.asyncio + async def test_stop_all_cancels_running_tasks(self): + runner = aio.AsyncTaskRunner() + started = aio.AsyncEvent() + + async def forever(): + started.set() + await asyncio.sleep(60) + + task = runner.spawn("test.forever", forever) + await started.wait(2) + await runner.stop_all() + assert task.done() + + @pytest.mark.asyncio + async def test_stop_all_honors_timeout_for_stubborn_task(self): + runner = aio.AsyncTaskRunner() + started = aio.AsyncEvent() + give_up = asyncio.Event() + + async def stubborn(): + while not give_up.is_set(): + try: + started.set() + await asyncio.sleep(60) + except asyncio.CancelledError: + continue # refuse to die until give_up is set + + task = runner.spawn("test.stubborn", stubborn) + await started.wait(2) + start = time.time() + await runner.stop_all(timeout=0.1) + assert time.time() - start < 2 + assert not task.done() # the stubborn task outlived stop_all + give_up.set() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +# --------------------------------------------------------------------------- +# aio.resolve +# --------------------------------------------------------------------------- + +class TestResolve: + @pytest.mark.asyncio + async def test_resolve_plain_value(self): + assert await aio.resolve(42) == 42 + + @pytest.mark.asyncio + async def test_resolve_awaitable(self): + async def coro(): + return 42 + + assert await aio.resolve(coro()) == 42 + + +# --------------------------------------------------------------------------- +# aio_transport.make_client_session +# --------------------------------------------------------------------------- + +class TestMakeClientSession: + @pytest.mark.asyncio + async def test_default_verifies_tls_and_trusts_env(self): + config = Config(sdk_key='sdk-key') + session = make_client_session(config) + try: + ctx = session.connector._ssl + assert isinstance(ctx, ssl.SSLContext) + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + assert session.trust_env is True + finally: + await session.close() + + @pytest.mark.asyncio + async def test_disable_ssl_verification_relaxes_context(self): + config = Config(sdk_key='sdk-key', http=HTTPConfig(disable_ssl_verification=True)) + session = make_client_session(config) + try: + ctx = session.connector._ssl + assert ctx.check_hostname is False + assert ctx.verify_mode == ssl.CERT_NONE + finally: + await session.close() + + @pytest.mark.asyncio + async def test_proxy_disables_env_trust(self): + config = Config(sdk_key='sdk-key', http=HTTPConfig(http_proxy='http://my-proxy:1234')) + session = make_client_session(config) + try: + assert session.trust_env is False + finally: + await session.close() + + @pytest.mark.asyncio + async def test_http_options_override_config(self): + config = Config(sdk_key='sdk-key') + override = HTTPConfig(disable_ssl_verification=True) + session = make_client_session(config, http_options=override) + try: + ctx = session.connector._ssl + assert ctx.verify_mode == ssl.CERT_NONE + finally: + await session.close() + + +# --------------------------------------------------------------------------- +# aio_transport.AsyncHTTPTransport ownership +# --------------------------------------------------------------------------- + +class TestAsyncHTTPTransportOwnership: + @pytest.mark.asyncio + async def test_closes_session_it_created(self): + config = Config(sdk_key='sdk-key') + transport = AsyncHTTPTransport(config) + # Force lazy session creation without making a network request. + transport._client = make_client_session(config) + transport._owns_client = True + session = transport._client + await transport.close() + assert session.closed is True + assert transport._client is None + + @pytest.mark.asyncio + async def test_leaves_caller_supplied_session_open(self): + config = Config(sdk_key='sdk-key') + session = make_client_session(config) + try: + transport = AsyncHTTPTransport(config, client=session) + await transport.close() + assert session.closed is False + finally: + await session.close() + + +# --------------------------------------------------------------------------- +# aio_transport.AsyncSSEFactory +# --------------------------------------------------------------------------- + +class TestAsyncSSEFactory: + @pytest.mark.asyncio + async def test_create_with_supplied_session_returns_client_and_leaves_it_open(self): + config = Config(sdk_key='sdk-key') + session = make_client_session(config) + try: + factory = AsyncSSEFactory(config, session=session) + client = factory.create('http://localhost:1/stream', 1.0) + assert isinstance(client, AsyncSSEClient) + assert session.closed is False + finally: + await session.close() + + @pytest.mark.asyncio + async def test_create_without_session_returns_client(self): + config = Config(sdk_key='sdk-key') + factory = AsyncSSEFactory(config) + client = factory.create('http://localhost:1/stream', 1.0) + assert isinstance(client, AsyncSSEClient) diff --git a/pyproject.toml b/pyproject.toml index 8672bb94..46b629bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ Repository = "https://github.com/launchdarkly/python-server-sdk" Documentation = "https://launchdarkly-python-sdk.readthedocs.io/en/latest/" [project.optional-dependencies] +async = ["aiohttp>=3.9,<4"] redis = ["redis>=2.10.5"] consul = ["python-consul>=1.0.1"] dynamodb = ["boto3>=1.9.71"] @@ -45,6 +46,7 @@ test-filesource = ["pyyaml>=5.3.1", "watchdog>=3.0.0"] dev = [ "mock>=2.0.0", "pytest>=8.0.0,<10", + "pytest-asyncio>=0.23", "redis>=2.10.5,<9.0.0", "boto3>=1.9.71,<2.0.0", "coverage>=4.4", @@ -58,6 +60,7 @@ dev = [ "types-mock>=5.0", "types-redis>=4.0", "types-setuptools>=68.0", + "aiohttp>=3.8.0", ] contract-tests = [ "Flask<4", @@ -85,6 +88,7 @@ multi_line_output = 3 [tool.pytest.ini_options] addopts = ["-ra"] +asyncio_mode = "strict" [build-system]