diff --git a/packages/bigframes/bigframes/_config/compute_options.py b/packages/bigframes/bigframes/_config/compute_options.py index 027566ae075f..66b500e1d366 100644 --- a/packages/bigframes/bigframes/_config/compute_options.py +++ b/packages/bigframes/bigframes/_config/compute_options.py @@ -168,6 +168,19 @@ class ComputeOptions: int | None: Number of rows, if set. """ + enable_peek_cache: bool = False + """ + If enabled, peeking at a relation will pull a larger local sample (e.g. 10k rows) + and cache it locally. Subsequent compatible operations on the relation will run + locally on the cached sample, enabling fast interactive iteration. + """ + + peek_cache_size: int = 10000 + """ + The size of the local sample to pull and cache when peeking at a relation. + Defaults to 10000. + """ + semantic_ops_confirmation_threshold: Optional[int] = 0 """ Deprecated. diff --git a/packages/bigframes/bigframes/session/bq_caching_executor.py b/packages/bigframes/bigframes/session/bq_caching_executor.py index dede318d8132..973a6be37239 100644 --- a/packages/bigframes/bigframes/session/bq_caching_executor.py +++ b/packages/bigframes/bigframes/session/bq_caching_executor.py @@ -167,6 +167,8 @@ def __init__( labels=dict(labels), ) self._function_manager = function_manager + from bigframes.session.peek_cache import PeekCache + self._peek_cache = PeekCache() def to_sql( self, @@ -209,6 +211,56 @@ async def _execute_async( execution_spec: ex_spec.ExecutionSpec, ) -> executor.ExecuteResult: await self._publisher.publish_async(bigframes.core.events.ExecutionStarted()) + + enable_peek_cache = ( + execution_spec.bigquery_config.enable_peek_cache + if execution_spec.bigquery_config + else False + ) + + if execution_spec.peek is not None and enable_peek_cache: + from bigframes.session.peek_cache import substitute_peek_cached_subplans + rewritten_node = substitute_peek_cached_subplans(array_value.node, self._peek_cache) + if rewritten_node != array_value.node: + rewritten_array_value = bigframes.core.ArrayValue(rewritten_node) + maybe_result = await self._try_execute_semi_executors( + rewritten_array_value, execution_spec + ) + if maybe_result is not None: + return maybe_result + + sample_size = ( + execution_spec.bigquery_config.peek_cache_size + if execution_spec.bigquery_config + else 10000 + ) + actual_sample_size = max(execution_spec.peek, sample_size) + cache_execution_spec = dataclasses.replace(execution_spec, peek=actual_sample_size) + + bq_result = await self._execute_bigquery( + array_value, + cache_execution_spec, + ) + + arrow_table = await asyncio.to_thread(bq_result.batches().to_arrow_table) + managed_table = local_data.ManagedArrowTable.from_pyarrow(arrow_table, bq_result.schema) + self._peek_cache.put(array_value.node, managed_table) + + sliced_table = arrow_table.slice(0, execution_spec.peek) + result: executor.ExecuteResult = executor.LocalExecuteResult( + sliced_table, + bq_result.schema, + execution_metadata=bq_result.execution_metadata, + ) + + await self._publisher.publish_async( + bigframes.core.events.EventEnvelope( + event=bigframes.core.events.ExecutionFinished(result=result), + cell_execution_count=execution_spec.cell_execution_count, + ) + ) + return result + maybe_result = await self._try_execute_semi_executors( array_value, execution_spec ) diff --git a/packages/bigframes/bigframes/session/execution_spec.py b/packages/bigframes/bigframes/session/execution_spec.py index 89de6eec9021..00e12b4ad84e 100644 --- a/packages/bigframes/bigframes/session/execution_spec.py +++ b/packages/bigframes/bigframes/session/execution_spec.py @@ -27,6 +27,8 @@ class BqComputeOptions: enable_multi_query_execution: bool = True maximum_bytes_billed: Optional[int] = None extra_query_labels: tuple[tuple[str, str], ...] = () + enable_peek_cache: bool = False + peek_cache_size: int = 10000 @classmethod def from_compute_options(cls, compute_options: ComputeOptions) -> BqComputeOptions: @@ -34,6 +36,8 @@ def from_compute_options(cls, compute_options: ComputeOptions) -> BqComputeOptio enable_multi_query_execution=compute_options.enable_multi_query_execution, maximum_bytes_billed=compute_options.maximum_bytes_billed, extra_query_labels=tuple(compute_options.extra_query_labels.items()), + enable_peek_cache=compute_options.enable_peek_cache, + peek_cache_size=compute_options.peek_cache_size, ) def push_labels(self, labels: Mapping[str, str]) -> BqComputeOptions: diff --git a/packages/bigframes/bigframes/session/peek_cache.py b/packages/bigframes/bigframes/session/peek_cache.py new file mode 100644 index 000000000000..f3f7ca1157bf --- /dev/null +++ b/packages/bigframes/bigframes/session/peek_cache.py @@ -0,0 +1,96 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import OrderedDict +import threading +from typing import Optional + +from bigframes.core import local_data, nodes + + +class PeekCache: + """ + Thread-safe LRU cache for storing local samples of query relations. + This enables fast iteration on subsequent compatible operations. + """ + + def __init__(self, capacity: int = 100): + self.capacity = capacity + self._cache: OrderedDict[nodes.BigFrameNode, local_data.ManagedArrowTable] = OrderedDict() + self._lock = threading.Lock() + + def get(self, key: nodes.BigFrameNode) -> Optional[local_data.ManagedArrowTable]: + with self._lock: + if key not in self._cache: + return None + # Move to end (most recently used) + self._cache.move_to_end(key) + return self._cache[key] + + def put(self, key: nodes.BigFrameNode, value: local_data.ManagedArrowTable) -> None: + with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + self._cache[key] = value + if len(self._cache) > self.capacity: + self._cache.popitem(last=False) + + def clear(self) -> None: + with self._lock: + self._cache.clear() + + +def substitute_peek_cached_subplans( + root: nodes.BigFrameNode, + peek_cache: PeekCache, +) -> nodes.BigFrameNode: + """ + Recursively replaces subplans in the tree that have a cached local sample + in the peek cache with a ReadLocalNode, provided that all ancestors + of the subplan are compatible with running on a sample. + """ + # Intermediate nodes that preserve the semantic validity of a sample. + # WindowOpNode, AggregateNode, OrderByNode, JoinNode, etc. are excluded + # because evaluating them on a sample breaks semantic contracts. + _COMPATIBLE_ANCESTOR_CLASSES = ( + nodes.SelectionNode, + nodes.ProjectionNode, + nodes.FilterNode, + nodes.PromoteOffsetsNode, + ) + + def traverse(node: nodes.BigFrameNode, ancestors_compatible: bool) -> nodes.BigFrameNode: + if ancestors_compatible: + cached_sample = peek_cache.get(node) + if cached_sample is not None: + # Replace the node with a ReadLocalNode containing the cached sample + scan_list = nodes.ScanList( + tuple(nodes.ScanItem(field.id, field.id.sql) for field in node.fields) + ) + session = node.session if node.session is not None else root.session + return nodes.ReadLocalNode( + local_data_source=cached_sample, + scan_list=scan_list, + session=session, + ) + + # If we didn't replace, recursively transform children + is_current_compatible = isinstance(node, _COMPATIBLE_ANCESTOR_CLASSES) + next_ancestors_compatible = ancestors_compatible and is_current_compatible + + return node.transform_children(lambda child: traverse(child, next_ancestors_compatible)) + + return traverse(root, True) diff --git a/packages/bigframes/tests/unit/session/test_peek_cache.py b/packages/bigframes/tests/unit/session/test_peek_cache.py new file mode 100644 index 000000000000..a92cd03b198c --- /dev/null +++ b/packages/bigframes/tests/unit/session/test_peek_cache.py @@ -0,0 +1,240 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from unittest import mock + +import google.cloud.bigquery +import pyarrow as pa +import pytest + +import bigframes +from bigframes.core import identifiers, local_data, nodes +from bigframes.session import bq_caching_executor, execution_spec, executor +from bigframes.session.peek_cache import PeekCache, substitute_peek_cached_subplans +from bigframes.testing import mocks + + +def test_peek_cache_lru(): + cache = PeekCache(capacity=2) + session = mocks.create_bigquery_session() + + # Create some mock nodes and data sources + table1 = pa.Table.from_pydict({"a": [1, 2]}) + table2 = pa.Table.from_pydict({"b": [3, 4]}) + table3 = pa.Table.from_pydict({"c": [5, 6]}) + + ds1 = local_data.ManagedArrowTable.from_pyarrow(table1) + ds2 = local_data.ManagedArrowTable.from_pyarrow(table2) + ds3 = local_data.ManagedArrowTable.from_pyarrow(table3) + + node1 = nodes.ReadLocalNode(ds1, nodes.ScanList(()), session) + node2 = nodes.ReadLocalNode(ds2, nodes.ScanList(()), session) + node3 = nodes.ReadLocalNode(ds3, nodes.ScanList(()), session) + + cache.put(node1, ds1) + cache.put(node2, ds2) + + # Access node1 to make it most recently used, leaving node2 as least recently used (LRU) + assert cache.get(node1) == ds1 + + # Put node3, which should evict node2 + cache.put(node3, ds3) + + assert cache.get(node2) is None + assert cache.get(node1) == ds1 + assert cache.get(node3) == ds3 + + +def test_substitute_peek_cached_subplans(): + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + + # Create a simple leaf node + leaf = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col_a"), "a"),)), + session=session, + ) + + # Cache the leaf node + cache = PeekCache() + cached_table = pa.Table.from_pydict({"col_a": [100, 200]}) + cached_ds = local_data.ManagedArrowTable.from_pyarrow(cached_table) + cache.put(leaf, cached_ds) + + # Now perform the tree substitution + rewritten = substitute_peek_cached_subplans(leaf, cache) + + # The leaf should be replaced by a new ReadLocalNode containing cached_ds + assert isinstance(rewritten, nodes.ReadLocalNode) + assert rewritten.local_data_source == cached_ds + assert rewritten.session == session + assert len(rewritten.scan_list.items) == 1 + assert rewritten.scan_list.items[0].id == identifiers.ColumnId("col_a") + assert rewritten.scan_list.items[0].source_id == "col_a" + + +def test_executor_peek_cache_integration(): + # Mock all arguments to BigQueryCachingExecutor + bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True) + bqclient.project = "test-project" + storage_manager = mock.Mock() + bqstoragereadclient = mock.Mock() + loader = mock.Mock() + publisher = mock.AsyncMock() + function_manager = mock.Mock() + + executor_obj = bq_caching_executor.BigQueryCachingExecutor( + bqclient=bqclient, + storage_manager=storage_manager, + bqstoragereadclient=bqstoragereadclient, + loader=loader, + publisher=publisher, + function_manager=function_manager, + ) + + table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + session = mocks.create_bigquery_session() + + node = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col"), "col"),)), + session=session, + ) + arr_value = bigframes.core.ArrayValue(node) + + # Mock _execute_bigquery of the executor to return a mock 3-row table + mock_bq_table = pa.Table.from_pydict({"col": [10, 20, 30]}) + mock_bq_result = executor.LocalExecuteResult(mock_bq_table, arr_value.schema) + + execute_bq_mock = mock.AsyncMock(return_value=mock_bq_result) + executor_obj._execute_bigquery = execute_bq_mock + + # Enable peek cache options + compute_options = bigframes.options.compute + compute_options.enable_peek_cache = True + compute_options.peek_cache_size = 3 + + # Call execute with peek=1 (cache miss path) + spec = execution_spec.ExecutionSpec(peek=1).with_compute_options(compute_options) + result = asyncio.run(executor_obj._execute_async(arr_value, spec)) + + # Verify BQ was called with peek=3 (cache size) + assert execute_bq_mock.call_count == 1 + called_spec = execute_bq_mock.call_args[0][1] + assert called_spec.peek == 3 + + # Verify returned result has exactly 1 row + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 1 + assert result_table["col"].to_pylist() == [10] + + # Verify peek cache has been populated with the 3-row table + cached_entry = executor_obj._peek_cache.get(node) + assert cached_entry is not None + assert cached_entry.to_pyarrow_table()["col"].to_pylist() == [10, 20, 30] + + # Call execute again with peek=2 (cache hit path) + execute_bq_mock.reset_mock() + spec2 = execution_spec.ExecutionSpec(peek=2).with_compute_options(compute_options) + result2 = asyncio.run(executor_obj._execute_async(arr_value, spec2)) + + # Verify BQ was NOT called + assert execute_bq_mock.call_count == 0 + + # Verify returned result has exactly 2 rows + result_table2 = pa.Table.from_batches(result2.batches().arrow_batches) + assert result_table2.num_rows == 2 + assert result_table2["col"].to_pylist() == [10, 20] + + +def test_peek_cache_thread_safety(): + import threading + + cache = PeekCache(capacity=100) + session = mocks.create_bigquery_session() + + # Create dummy nodes and data sources + num_items = 50 + num_threads = 10 + nodes_list = [] + for i in range(num_items): + table = pa.Table.from_pydict({"col": [i]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + node = nodes.ReadLocalNode(ds, nodes.ScanList(()), session) + nodes_list.append((node, ds)) + + def worker(worker_id): + for i in range(100): + node, ds = nodes_list[(worker_id + i) % num_items] + cache.put(node, ds) + cache.get(node) + + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # The cache should be in a consistent state and not exceed capacity + assert len(cache._cache) <= 100 + + +def test_substitute_peek_cached_subplans_incompatible_ancestors(): + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + + # Leaf node (cached) + leaf = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col_a"), "a"),)), + session=session, + ) + + cache = PeekCache() + cached_table = pa.Table.from_pydict({"col_a": [100, 200]}) + cached_ds = local_data.ManagedArrowTable.from_pyarrow(cached_table) + cache.put(leaf, cached_ds) + + # Scenario A: Path has only compatible nodes: FilterNode -> Leaf + # FilterNode is a compatible ancestor. + plan_compatible = nodes.FilterNode( + child=leaf, + predicate=bigframes.core.expression.ScalarConstantExpression(True), # Dummy expression + ) + + rewritten_compatible = substitute_peek_cached_subplans(plan_compatible, cache) + # The leaf child of FilterNode should be replaced by ReadLocalNode with cached_ds + assert isinstance(rewritten_compatible, nodes.FilterNode) + assert isinstance(rewritten_compatible.child, nodes.ReadLocalNode) + assert rewritten_compatible.child.local_data_source == cached_ds + + # Scenario B: Path has an incompatible node: ReversedNode -> Leaf + # ReversedNode is an incompatible ancestor. + plan_incompatible = nodes.ReversedNode(child=leaf) + + rewritten_incompatible = substitute_peek_cached_subplans(plan_incompatible, cache) + # The leaf child should NOT be replaced by ReadLocalNode + assert isinstance(rewritten_incompatible, nodes.ReversedNode) + assert rewritten_incompatible.child == leaf + assert rewritten_incompatible.child.local_data_source == ds