@dataclass
class VLMVerificationResult:
    """Outcome of a single VLM-based alert verification made through the Groundlight cloud API.

    The first seven fields are required and positional; the token-usage and cost
    fields are optional and default to ``None``.
    """

    # Identification of the verification and of the request that produced it.
    id: str
    query: str
    model_id: str

    # The model's answer: one of "YES", "NO" or "UNSURE".
    verdict: str
    # Confidence in the verdict, in the range 0.0–1.0.
    confidence: float
    # Free-text explanation accompanying the verdict.
    reasoning: str

    # Creation timestamp, kept as the string it arrived as.
    created_at: str

    # Usage / cost accounting; None when not supplied.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_cost_usd: Optional[float] = None
def ask_vlm(
    self,
    media: Union[
        "np.ndarray",
        List["np.ndarray"],
        str,
        bytes,
        "Image.Image",
        BytesIO,
        BufferedReader,
    ],
    query: str,
    model_id: Optional[str] = None,
    timeout: float = 15.0,
) -> VLMVerificationResult:
    """Verify one or more images against a natural-language query using a cloud VLM.

    Calls the Groundlight ``POST /v1/vlm-verifications`` endpoint. The VLM runs in the
    Groundlight cloud (AWS Bedrock) — no local inference.

    The server makes no assumptions about what the images are — your ``query`` should
    describe them. Images are presented to the model labeled ``Image 1``, ``Image 2``,
    ... in the order given, so the query can refer to them.

    **Example usage**::

        gl = Groundlight()

        # Single image
        result = gl.ask_vlm(frame, query="Is there a fire in this image?")
        if result.verdict == "YES":
            emit_alert()

        # Full frame + cropped ROI — describe each in the query
        result = gl.ask_vlm(
            media=[full_frame, roi_crop],
            query="Image 1 is the full camera frame; image 2 is the cropped region "
                  "a detector flagged. Is there really a fire?",
        )
        print(result.confidence, result.reasoning)

    :param media: One image or a list of 1 to 8 images. Accepted formats per image:

        - filename (string) of a JPEG/PNG file
        - raw bytes or BytesIO / BufferedReader
        - numpy array (H, W, 3) in BGR order (OpenCV convention)
        - PIL Image

    :param query: Natural-language prompt describing the media and what to verify,
        e.g. ``"Is there a fire visible in the image? Reason step by step."``
    :param model_id: Friendly alias of the VLM to use. The server is the source
        of truth; passing an unrecognised alias returns HTTP 400. Currently
        supported aliases:

        - ``"gpt-5.4"`` — OpenAI GPT-5.4 via Bedrock Responses API (default)
        - ``"claude-sonnet-4.5"`` — Anthropic Claude Sonnet 4.5
        - ``"claude-haiku-3"`` — Anthropic Claude Haiku 3
        - ``"nova-pro"`` — Amazon Nova Pro
        - ``"nova-lite"`` — Amazon Nova Lite
        - ``"llama3.2-90b"`` — Meta Llama 3.2 90B
        - ``"llama3.2-11b"`` — Meta Llama 3.2 11B

        Omit to use the server-configured default (currently ``"gpt-5.4"``).
    :param timeout: Request timeout in seconds (default 15 s).

    :return: :class:`VLMVerificationResult` with ``verdict`` (``"YES"`` / ``"NO"`` /
        ``"UNSURE"``), ``confidence``, ``reasoning``, and token cost fields. Fields
        missing (or ``null``) in the response fall back to neutral defaults:
        empty strings, ``"UNSURE"``, ``0.0`` confidence, and ``None`` cost values.
    :raises ValueError: If no media item, or more than 8 media items, are supplied.
    :raises requests.HTTPError: On non-2xx response (400 for invalid model alias
        or undecodable image bytes; 502 if the upstream VLM is unavailable).
    """
    max_media_items = 8  # documented per-call limit for this method

    # Normalise: single image → list
    media_items = media if isinstance(media, list) else [media]
    # Reject an empty list up front instead of sending a request with no images.
    if not media_items:
        raise ValueError("ask_vlm requires at least one media item.")
    if len(media_items) > max_media_items:
        raise ValueError(f"ask_vlm supports at most {max_media_items} media items.")

    # Convert each image to JPEG bytes via the existing SDK utility. Every part uses the
    # same multipart field name ("media"); the filename carries the position.
    media_files: List[Tuple[str, Tuple[str, bytes, str]]] = [
        ("media", (f"image_{index}.jpg", parse_supported_image_types(item).read(), "image/jpeg"))
        for index, item in enumerate(media_items)
    ]

    # query and model_id are sent as multipart form fields (not query-string
    # params): the prompt can be long and must not end up in URLs or access logs.
    form_data = {"query": query}
    if model_id:
        form_data["model_id"] = model_id

    # A millisecond timestamp alone can repeat across concurrent calls, so a short
    # random suffix keeps request ids distinguishable.
    request_id = f"ask_vlm_{int(time.time() * 1000)}_{os.urandom(4).hex()}"
    headers = {
        "x-api-token": self.api_client.configuration.api_key["ApiToken"],
        "X-Request-Id": request_id,
        "x-sdk-language": "python",
    }

    url = f"{self.endpoint}/v1/vlm-verifications"

    resp = requests.post(
        url,
        data=form_data,
        files=media_files,
        headers=headers,
        timeout=timeout,
        verify=self.api_client.configuration.verify_ssl,
    )
    resp.raise_for_status()
    data = resp.json()

    # `or {}` (rather than a .get default) also covers an explicit JSON null.
    result_block = data.get("result") or {}
    cost_block = data.get("cost") or {}
    raw_confidence = result_block.get("confidence")
    return VLMVerificationResult(
        id=data.get("id", ""),
        query=data.get("query", query),
        model_id=data.get("model_id", model_id or ""),
        verdict=result_block.get("verdict", "UNSURE"),
        # float(None) would raise TypeError; treat a null confidence as 0.0.
        confidence=float(raw_confidence) if raw_confidence is not None else 0.0,
        reasoning=result_block.get("reasoning", ""),
        created_at=data.get("created_at", ""),
        input_tokens=cost_block.get("input_tokens"),
        output_tokens=cost_block.get("output_tokens"),
        total_cost_usd=cost_block.get("total_cost_usd"),
    )
def _blank_frame(height=100, width=100):
    """Return an all-black (height, width, 3) uint8 frame to use as dummy media."""
    return np.zeros((height, width, 3), dtype=np.uint8)


def _mock_response(verdict="YES", confidence=0.92, reasoning="Flames visible.", model_id="gpt-5.4"):
    """Build a stand-in for the HTTP response of a successful vlm-verifications call."""
    resp = MagicMock()
    resp.status_code = 201
    resp.json.return_value = {
        "id": "vlmv_test123",
        "type": "vlm_verification",
        "created_at": "2025-06-17T00:00:00Z",
        "query": "Is there a fire?",
        "model_id": model_id,
        "result": {"verdict": verdict, "confidence": confidence, "reasoning": reasoning},
        "cost": {"input_tokens": 400, "output_tokens": 80, "total_cost_usd": 0.0015},
    }
    # Explicit no-op so the success path never raises.
    resp.raise_for_status = MagicMock()
    return resp


def test_returns_vlm_verification_result(gl: Groundlight):
    """Result fields are correctly unpacked from the server response JSON."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        result = gl.ask_vlm(media=_blank_frame(), query="Is there a fire?")

    assert isinstance(result, VLMVerificationResult)
    assert result.verdict == "YES"
    assert result.confidence == pytest.approx(0.92)
    assert result.id == "vlmv_test123"
    assert result.input_tokens == 400
    assert result.total_cost_usd == pytest.approx(0.0015)


def test_numpy_image_encoded_as_jpeg_multipart(gl: Groundlight):
    """A numpy array is converted to JPEG and sent as a multipart 'media' part."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        gl.ask_vlm(media=_blank_frame(480, 640), query="Is there a fire?")

    _, kwargs = mock_requests.post.call_args
    files = kwargs["files"]
    assert len(files) == 1
    assert files[0][0] == "media"
    _, data, ctype = files[0][1]
    assert ctype == "image/jpeg"
    assert len(data) > 0
    # Check the payload itself, not just its label: JPEG data starts with the SOI marker.
    assert data[:2] == b"\xff\xd8"


def test_query_sent_as_form_field_not_url_param(gl: Groundlight):
    """query and model_id go in the multipart body — never the URL — so the prompt
    doesn't leak into access logs."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response(model_id="nova-pro")
        gl.ask_vlm(media=_blank_frame(), query="Is there a fire?", model_id="nova-pro")

    args, kwargs = mock_requests.post.call_args
    assert kwargs["data"]["query"] == "Is there a fire?"
    assert kwargs["data"]["model_id"] == "nova-pro"
    assert "params" not in kwargs or not kwargs.get("params")
    # The URL itself must carry neither the prompt nor the model alias.
    assert "fire" not in args[0]
    assert "nova-pro" not in args[0]


def test_more_than_eight_media_raises(gl: Groundlight):
    """Supplying more than 8 media items raises ValueError before any network call."""
    # Patch requests so a validation regression can never reach the real network,
    # and so the "before any network call" claim is actually checked.
    with mock.patch("groundlight.client.requests") as mock_requests:
        with pytest.raises(ValueError, match="at most 8"):
            gl.ask_vlm(media=[_blank_frame()] * 9, query="test")

    mock_requests.post.assert_not_called()


def test_url_has_correct_path(gl: Groundlight):
    """sanitize_endpoint_url strips the trailing slash from self.endpoint, so the path
    must include a leading '/' — without it the URL becomes '...device-apiv1/...'."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        gl.ask_vlm(media=_blank_frame(), query="test")

    args, _ = mock_requests.post.call_args
    # endswith (rather than `in`) also rejects anything appended after the path.
    assert args[0].endswith("/device-api/v1/vlm-verifications")