"""
Streaming inference sessions for deployed models.
"""
import json
import time
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Iterator, List, Literal, Tuple, Union
import numpy as np
import requests
from requests import Timeout
from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
from tqdm.auto import tqdm
import supervisely.io.env as env
from supervisely._utils import get_valid_kwargs, logger
from supervisely.api.api import Api
from supervisely.imaging._video import ALLOWED_VIDEO_EXTENSIONS
from supervisely.imaging.image import SUPPORTED_IMG_EXTS, write_bytes
from supervisely.io.fs import (
dir_exists,
file_exists,
get_file_ext,
get_file_size,
list_files,
list_files_recursively,
)
from supervisely.io.network_exceptions import process_requests_exception
from supervisely.nn.model.prediction import Prediction
from supervisely.project.project import Dataset, OpenMode, Project
from supervisely.project.project_meta import ProjectMeta
def value_generator(value):
"""
Yield the same value forever.
Used internally to provide per-item kwargs during iteration.
"""
while True:
yield value
[docs]
class PredictionSession:
"""
Asynchronous inference session that yields :class:`~supervisely.nn.model.prediction.Prediction`.
The session starts inference immediately during construction and becomes an iterator.
Use it directly when you need streaming results or progress control, or use higher-level helpers
like :meth:`~supervisely.nn.model.model_api.ModelAPI.predict_detached` method.
"""
[docs]
class Iterator:
"""Internal iterator that fetches pending results in chunks."""
def __init__(self, total, session: "PredictionSession", tqdm: tqdm = None):
"""
:param total: Total items.
:type total: int
:param session: PredictionSession.
:type session: :class:`~supervisely.nn.model.prediction_session.PredictionSession`
:param tqdm: Optional progress bar.
:type tqdm: tqdm
"""
self.total = total
self.session = session
self.results_queue = []
self.tqdm = tqdm
def __len__(self) -> int:
return self.total
def __iter__(self) -> Iterator:
return self
def __next__(self) -> Dict[str, Any]:
if not self.results_queue:
pending_results = self.session._wait_for_pending_results(tqdm=self.tqdm)
self.results_queue += pending_results
if not self.results_queue:
raise StopIteration
pred = self.results_queue.pop(0)
return pred
def __init__(
self,
url: str,
input: Union[np.ndarray, str, PathLike, list] = None,
image_id: Union[List[int], int] = None,
video_id: Union[List[int], int] = None,
dataset_id: Union[List[int], int] = None,
project_id: Union[List[int], int] = None,
api: "Api" = None,
tracking: bool = None,
tracking_config: dict = None,
**kwargs: dict,
):
"""
Exactly one of the following inputs must be provided: ``input``, ``image_id``, ``video_id``,
``dataset_id``, ``project_id`` (or their plural forms in ``kwargs``: ``image_ids`` etc.).
:param url: Deployment base URL (e.g. ``https://.../net/<sessionToken>``).
:type url: str
:param input: Local input (NumPy image, image/video path, directory, or list of them).
:type input: numpy.ndarray or str or PathLike or list, optional
:param image_id: Image id (or list in ``image_ids`` kwarg).
:type image_id: int or List[int], optional
:param video_id: Video id (single video supported).
:type video_id: int or List[int], optional
:param dataset_id: Dataset id(s). Requires ``api`` to resolve datasets/projects.
:type dataset_id: int or List[int], optional
:param project_id: Project id (single project supported).
:type project_id: int or List[int], optional
:param api: API client used for downloading by id / resolving datasets/projects.
:type api: :class:`~supervisely.api.api.Api`, optional
:param tracking: Enable video tracking (if supported by deployment).
:type tracking: bool, optional
:param tracking_config: Tracking configuration dict (may include ``tracker`` key).
:type tracking_config: dict, optional
:param kwargs: Additional inference settings (confidence, classes, batch size, etc.).
:type kwargs: dict
:raises AssertionError: If more than one input source is provided.
:raises ValueError: For unsupported inputs or invalid combinations.
:Usage Example:
.. code-block:: python
import os
from dotenv import load_dotenv
import supervisely as sly
from supervisely.nn.model.prediction_session import PredictionSession
# Load secrets and create API object from .env file (recommended)
# Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
if sly.is_development():
load_dotenv(os.path.expanduser("~/supervisely.env"))
api = sly.Api.from_env()
session = PredictionSession(
url='https://app.supervisely.com/net/<sessionToken>',
image_id=123,
api=api,
conf=0.25,
)
for pred in session:
_ = pred.boxes
_ = pred.scores
"""
extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
assert (
sum(
[
x is not None
for x in [
input,
image_id,
video_id,
dataset_id,
project_id,
*[kwargs.get(extra_input, None) for extra_input in extra_input_args],
]
]
)
== 1
), "Exactly one of input, image_ids, video_id, dataset_id, project_id or image_id must be provided."
self._iterator = None
self._base_url = url
self.inference_request_uuid = None
self.input = input
self.api = api
self.api_token = self._get_api_token()
self._model_meta = None
self.final_result = None
if "stride" in kwargs:
kwargs["step"] = kwargs["stride"]
if "start_frame" in kwargs:
kwargs["start_frame_index"] = kwargs["start_frame"]
if "num_frames" in kwargs:
kwargs["frames_count"] = kwargs["num_frames"]
self.kwargs = kwargs
if kwargs.get("show_progress", False) and "tqdm" not in kwargs:
kwargs["tqdm"] = tqdm()
self.tqdm = kwargs.pop("tqdm", None)
self.inference_settings = {
k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
}
if tracking is True:
model_info = self._get_session_info()
if not model_info.get("tracking_on_videos_support", False):
raise ValueError("Tracking is not supported by this model")
if tracking_config is None:
self.tracker = "botsort"
self.tracker_settings = {}
else:
cfg = dict(tracking_config)
self.tracker = cfg.pop("tracker", "botsort")
self.tracker_settings = cfg
else:
self.tracker = None
self.tracker_settings = None
if "classes" in kwargs:
self.inference_settings["classes"] = kwargs["classes"]
# TODO: remove "settings", it is the same as inference_settings
if "settings" in kwargs:
self.inference_settings.update(kwargs["settings"])
if "inference_settings" in kwargs:
self.inference_settings.update(kwargs["inference_settings"])
# extra input args
image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
video_ids = self._set_var_from_kwargs("video_ids", kwargs, video_id)
dataset_ids = self._set_var_from_kwargs("dataset_ids", kwargs, dataset_id)
project_ids = self._set_var_from_kwargs("project_ids", kwargs, project_id)
source = next(
x
for x in [
input,
image_id,
video_id,
dataset_id,
project_id,
image_ids,
video_ids,
dataset_ids,
project_ids,
]
if x is not None
)
self.kwargs["source"] = source
self.prediction_kwargs_iterator = value_generator({})
if not isinstance(input, list):
input = [input]
if isinstance(input[0], np.ndarray):
# input is numpy array
self._predict_images(input, **kwargs)
elif isinstance(input[0], (str, PathLike)):
if len(input) > 1:
# if the input is a list of paths, assume they are images
for x in input:
if not isinstance(x, (str, PathLike)):
raise ValueError("Input must be a list of strings or PathLike objects.")
self._iterator = self._predict_images_bytes(input, **kwargs)
else:
if dir_exists(input[0]):
try:
project = Project(str(input[0]), mode=OpenMode.READ)
except Exception:
project = None
image_paths = []
if project is not None:
for dataset in project.datasets:
dataset: Dataset
for _, image_path, _ in dataset.items():
image_paths.append(image_path)
else:
# if the input is a directory, assume it contains images
recursive = kwargs.get("recursive", False)
if recursive:
image_paths = list_files_recursively(
input[0], valid_extensions=SUPPORTED_IMG_EXTS
)
else:
image_paths = list_files(input[0], valid_extensions=SUPPORTED_IMG_EXTS)
if len(image_paths) == 0:
raise ValueError("Directory is empty.")
self._iterator = self._predict_images(image_paths, **kwargs)
elif file_exists(input[0]):
ext = get_file_ext(input[0])
if ext == "":
raise ValueError("File has no extension.")
if ext.lower() in SUPPORTED_IMG_EXTS:
self._iterator = self._predict_images(input, **kwargs)
elif ext.lower() in ALLOWED_VIDEO_EXTENSIONS:
kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
self._iterator = self._predict_videos(input, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
else:
raise ValueError(
f"Unsupported file extension: {ext}. Supported extensions are: {SUPPORTED_IMG_EXTS + ALLOWED_VIDEO_EXTENSIONS}"
)
else:
raise ValueError(f"File or directory does not exist: {input[0]}")
elif image_ids is not None:
self._iterator = self._predict_images(image_ids, **kwargs)
elif video_ids is not None:
if len(video_ids) > 1:
raise ValueError("Only one video id can be provided.")
kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
self._iterator = self._predict_videos(video_ids, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
elif dataset_ids is not None:
kwargs = get_valid_kwargs(
kwargs,
self._predict_datasets,
exclude=["dataset_ids"],
)
self._iterator = self._predict_datasets(dataset_ids, **kwargs)
elif project_ids is not None:
if len(project_ids) > 1:
raise ValueError("Only one project id can be provided.")
kwargs = get_valid_kwargs(
kwargs,
self._predict_projects,
exclude=["project_ids"],
)
self._iterator = self._predict_projects(project_ids, **kwargs)
else:
raise ValueError(
"Unknown input type. Supported types are: numpy array, path to a file or directory, ImageInfo, VideoInfo, ProjectInfo, DatasetInfo."
)
def _set_var_from_kwargs(self, key, kwargs, default):
value = kwargs.get(key, default)
if value is None:
return None
if not isinstance(value, list):
value = [value]
return value
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
if exc_type is not None:
return False
def __next__(self):
try:
prediction_json = self._iterator.__next__()
this_kwargs = next(self.prediction_kwargs_iterator)
prediction = Prediction.from_json(
prediction_json, **self.kwargs, **this_kwargs, model_meta=self.model_meta
)
return prediction
except StopIteration:
self._on_infernce_end()
raise
except Exception:
self.stop()
raise
def next(self):
return self.__next__()
def __iter__(self):
return self
def __len__(self):
return len(self._iterator)
def _get_api_token(self):
if self.api is not None:
return self.api.token
return env.api_token(raise_not_found=False)
def _get_json_body(self):
body = {"state": {}, "context": {}}
if self.inference_request_uuid is not None:
body["state"]["inference_request_uuid"] = self.inference_request_uuid
if self.inference_settings:
body["state"]["settings"] = self.inference_settings
if self.api_token is not None:
body["api_token"] = self.api_token
if "model_prediction_suffix" in self.kwargs:
body["state"]["model_prediction_suffix"] = self.kwargs["model_prediction_suffix"]
return body
def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
if kwargs.get("headers") is None:
kwargs["headers"] = {}
if self.api is not None:
retries = min(self.api.retry_count, retries)
if "x-api-key" not in kwargs["headers"]:
kwargs["headers"]["x-api-key"] = self.api.token
url = self._base_url.rstrip("/") + "/" + method.lstrip("/")
if "timeout" not in kwargs:
kwargs["timeout"] = 60
for retry_idx in range(retries):
response = None
try:
logger.trace(f"POST {url}")
response = requests.post(url, *args, **kwargs)
if response.status_code != requests.codes.ok: # pylint: disable=no-member
Api._raise_for_status(response)
return response
except requests.RequestException as exc:
process_requests_exception(
logger,
exc,
method,
url,
verbose=True,
swallow_exc=True,
sleep_sec=5,
response=response,
retry_info={"retry_idx": retry_idx + 1, "retry_limit": retries},
)
if retry_idx + 1 == retries:
raise exc
def _get_session_info(self) -> Dict[str, Any]:
method = "get_session_info"
r = self._post(method, json=self._get_json_body())
return r.json()
def _get_inference_progress(self):
method = "get_inference_progress"
r = self._post(method, json=self._get_json_body())
return r.json()
def _get_inference_status(self):
method = "get_inference_status"
r = self._post(method, json=self._get_json_body())
return r.json()
def _stop_async_inference(self):
method = "stop_inference"
r = self._post(
method,
json=self._get_json_body(),
)
logger.info("Inference will be stopped on the server")
return r.json()
def _clear_inference_request(self):
method = "clear_inference_request"
r = self._post(
method,
json=self._get_json_body(),
)
logger.info("Inference request will be cleared on the server")
return r.json()
def _get_final_result(self):
method = "get_inference_result"
r = self._post(
method,
json=self._get_json_body(),
)
return r.json()
def _on_infernce_end(self):
if self.inference_request_uuid is None:
return
try:
self.final_result = self._get_final_result()
except Exception as e:
logger.debug("Failed to get final result:", exc_info=True)
self._clear_inference_request()
@property
def model_meta(self) -> ProjectMeta:
"""
Lazily fetch output :class:`~supervisely.project.project_meta.ProjectMeta` from the deployment.
:returns: Model meta.
:rtype: :class:`~supervisely.project.project_meta.ProjectMeta`
"""
if self._model_meta is None:
self._model_meta = ProjectMeta.from_json(
self._post("get_model_meta", json=self._get_json_body()).json()
)
return self._model_meta
[docs]
def stop(self):
"""
Stop the current inference request on the server (if running).
Also tries to fetch final result metadata and clears the server-side request state.
"""
if self.inference_request_uuid is None:
logger.debug("No active inference request to stop.")
return
self._stop_async_inference()
self._on_infernce_end()
[docs]
def is_done(self):
"""
Check whether server-side inference is finished.
:returns: True if finished.
:rtype: bool
:raises RuntimeError: If inference has not been started yet.
"""
if self.inference_request_uuid is None:
raise RuntimeError(
"Inference is not started. Please start inference before checking the status."
)
return not self._get_inference_progress()["is_inferring"]
[docs]
def progress(self):
"""
Return numeric progress of the current inference request.
:returns: Progress value as returned by the backend.
:rtype: Any
:raises RuntimeError: If inference has not been started yet.
"""
if self.inference_request_uuid is None:
raise RuntimeError(
"Inference is not started. Please start inference before checking the status."
)
return self._get_inference_progress()["progress"]
[docs]
def status(self):
"""
Return raw status JSON for the current inference request.
:returns: Status dict.
:rtype: dict
:raises RuntimeError: If inference has not been started yet.
"""
if self.inference_request_uuid is None:
raise RuntimeError(
"Inference is not started. Please start inference before checking the status."
)
return self._get_inference_status()
def _pop_pending_results(self) -> Dict[str, Any]:
method = "pop_inference_results"
json_body = self._get_json_body()
return self._post(method, json=json_body).json()
def _update_progress(
self,
tqdm: tqdm,
message: str = None,
current: int = None,
total: int = None,
is_size: bool = False,
):
if tqdm is None:
return
refresh = False
if message is not None and tqdm.desc != message:
tqdm.set_description(message, refresh=False)
refresh = True
if current is not None and tqdm.n != current:
tqdm.n = current
refresh = True
if total is not None and tqdm.total != total:
tqdm.total = total
refresh = True
if is_size and tqdm.unit == "it":
tqdm.unit = "iB"
tqdm.unit_scale = True
tqdm.unit_divisor = 1024
refresh = True
if not is_size and tqdm.unit == "iB":
tqdm.unit = "it"
tqdm.unit_scale = False
tqdm.unit_divisor = 1
refresh = True
if refresh:
tqdm.refresh()
def _update_progress_from_response(self, tqdm: tqdm, response: Dict[str, Any]):
if tqdm is None:
return
json_progress = response.get("progress", None)
if json_progress is None or json_progress.get("message") is None:
json_progress = response.get("preparing_progress", None)
if json_progress is None:
return
message = json_progress.get("message", json_progress.get("status", None))
current = json_progress.get("current", None)
total = json_progress.get("total", None)
is_size = json_progress.get("is_size", False)
self._update_progress(tqdm, message, current, total, is_size)
def _wait_for_inference_start(
self, delay=1, timeout=None, tqdm: tqdm = None
) -> Tuple[dict, bool]:
has_started = False
timeout_exceeded = False
t0 = time.time()
last_stage = None
while not has_started and not timeout_exceeded:
resp = self._get_inference_progress()
stage = resp.get("stage")
if stage != last_stage:
logger.info(stage)
last_stage = stage
has_started = stage not in ["preparing", "preprocess", None]
has_started = has_started or bool(resp.get("result")) or resp["progress"]["total"] != 1
self._update_progress_from_response(tqdm, resp)
if not has_started:
time.sleep(delay)
timeout_exceeded = timeout and time.time() - t0 > timeout
if timeout_exceeded:
self.stop()
raise Timeout("Timeout exceeded. The server didn't start the inference")
return resp, has_started
def _wait_for_pending_results(self, delay=1, timeout=600, tqdm: tqdm = None) -> List[dict]:
logger.debug("waiting pending results...")
has_results = False
timeout_exceeded = False
t0 = time.monotonic()
while not has_results and not timeout_exceeded:
resp = self._pop_pending_results()
self._update_progress_from_response(tqdm, resp)
pending_results = resp["pending_results"]
exception_json = resp["exception"]
if exception_json:
exception_str = f"{exception_json['type']}: {exception_json['message']}"
raise RuntimeError(f"Inference Error: {exception_str}")
has_results = bool(pending_results)
if resp.get("finished", False):
break
if not has_results:
time.sleep(delay)
timeout_exceeded = timeout and time.monotonic() - t0 > timeout
if timeout_exceeded:
self.stop()
raise Timeout("Timeout exceeded. Pending results not received from the server.")
return pending_results
def _start_inference(self, method, **kwargs):
if self.inference_request_uuid:
raise RuntimeError(
"Inference is already running. Please stop it before starting a new one."
)
resp = self._post(method, **kwargs).json()
self.inference_request_uuid = resp["inference_request_uuid"]
try:
resp, has_started = self._wait_for_inference_start(tqdm=self.tqdm)
except:
self.stop()
raise
logger.info(
"Inference has started:",
extra={"inference_request_uuid": resp.get("inference_request_uuid")},
)
frame_iterator = self.Iterator(resp["progress"]["total"], self, tqdm=self.tqdm)
return frame_iterator
def _predict_images(self, images: List, **kwargs: dict):
if isinstance(images[0], bytes):
f = self._predict_images_bytes
elif isinstance(images[0], (str, PathLike)):
f = self._predict_images_paths
elif isinstance(images[0], np.ndarray):
f = self._predict_images_nps
elif isinstance(images[0], int):
f = self._predict_images_ids
else:
raise ValueError(f"Unsupported input type '{type(images[0])}'.")
kwargs = get_valid_kwargs(kwargs, f, exclude=["images"])
return f(images, **kwargs)
def _predict_images_bytes(self, images: List[bytes], batch_size: int = None):
files = [
("files", (f"image_{i}.png", image, "image/png")) for i, image in enumerate(images)
]
state = self._get_json_body()["state"]
if batch_size is not None:
state["batch_size"] = batch_size
method = "inference_batch_async"
uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
return self._start_inference(method, files=uploads)
def _predict_images_paths(self, images: List, batch_size: int = None):
files = []
try:
files = [("files", open(f, "rb")) for f in images]
state = self._get_json_body()["state"]
if batch_size is not None:
state["batch_size"] = batch_size
method = "inference_batch_async"
uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
return self._start_inference(method, files=uploads)
finally:
for _, f in files:
f.close()
def _predict_images_nps(self, images: List, batch_size: int = None):
images = [write_bytes(image, ".png") for image in images]
return self._predict_images_bytes(images, batch_size=batch_size)
def _predict_images_ids(
self,
images: List[int],
batch_size: int = None,
upload_mode: str = None,
output_project_id: int = None,
):
method = "inference_batch_ids_async"
json_body = self._get_json_body()
state = json_body["state"]
state["images_ids"] = images
if batch_size is not None:
state["batch_size"] = batch_size
if upload_mode is not None:
state["upload_mode"] = upload_mode
if output_project_id is not None:
state["output_project_id"] = output_project_id
return self._start_inference(method, json=json_body)
def _predict_videos(
self,
videos: Union[List[int], List[str], List[PathLike]],
start_frame: int = None,
num_frames: int = None,
stride=None,
end_frame=None,
duration=None,
direction: Literal["forward", "backward"] = None,
tracker: Literal["botsort"] = None,
tracker_settings: dict = None,
batch_size: int = None,
):
if len(videos) != 1:
raise ValueError("Only one video can be processed at a time.")
json_body = self._get_json_body()
state = json_body["state"]
for key, value in (
("start_frame", start_frame),
("num_frames", num_frames),
("stride", stride),
("end_frame", end_frame),
("duration", duration),
("direction", direction),
("tracker", tracker),
("tracker_settings", tracker_settings),
("batch_size", batch_size),
):
if value is not None:
state[key] = value
if isinstance(videos[0], int):
method = "inference_video_id_async"
state["video_id"] = videos[0]
return self._start_inference(method, json=json_body)
elif isinstance(videos[0], (str, PathLike)):
video_path = videos[0]
files = []
try:
method = "inference_video_async"
files.append((Path(video_path).name, open(video_path, "rb"), "video/*"))
fields = {
"files": files[-1],
"state": json.dumps(state),
}
encoder = MultipartEncoder(fields)
if self.tqdm is not None:
bytes_read = 0
def _callback(monitor):
nonlocal bytes_read
self.tqdm.update(monitor.bytes_read - bytes_read)
bytes_read = monitor.bytes_read
video_size = get_file_size(video_path)
self._update_progress(self.tqdm, "Uploading video", 0, video_size, is_size=True)
encoder = MultipartEncoderMonitor(encoder, _callback)
return self._start_inference(
method, data=encoder, headers={"Content-Type": encoder.content_type}
)
finally:
for _, f, _ in files:
f.close()
else:
raise ValueError(
f"Unsupported input type '{type(videos[0])}'. Supported types are: int, str, PathLike."
)
def _predict_projects(
self,
project_ids: List[int],
dataset_ids: List[int] = None,
batch_size: int = None,
upload_mode: str = None,
iou_merge_threshold: float = None,
cache_project_on_model: bool = None,
output_project_id: int = None,
):
if len(project_ids) != 1:
raise ValueError("Only one project can be processed at a time.")
method = "inference_project_id_async"
json_body = self._get_json_body()
state = json_body["state"]
state["project_id"] = project_ids[0]
if dataset_ids is not None:
state["dataset_ids"] = dataset_ids
if batch_size is not None:
state["batch_size"] = batch_size
if upload_mode is not None:
state["upload_mode"] = upload_mode
if iou_merge_threshold is not None:
state["iou_merge_threshold"] = iou_merge_threshold
if cache_project_on_model is not None:
state["cache_project_on_model"] = cache_project_on_model
if output_project_id is not None:
state["output_project_id"] = output_project_id
return self._start_inference(method, json=json_body)
def _predict_datasets(
self,
dataset_ids: List[int],
batch_size: int = None,
upload_mode: str = None,
iou_merge_threshold: float = None,
cache_datasets_on_model: bool = None,
):
if self.api is None:
raise ValueError("Api is required to use this method.")
dataset_infos = [self.api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids]
if len(set([info.project_id for info in dataset_infos])) > 1:
raise ValueError("All datasets must belong to the same project.")
return self._predict_projects(
[dataset_infos[0].project_id],
dataset_ids=dataset_ids,
batch_size=batch_size,
upload_mode=upload_mode,
iou_merge_threshold=iou_merge_threshold,
cache_project_on_model=cache_datasets_on_model,
)