Source code for supervisely.api.nn.train_api

"""API for training neural network models in Supervisely."""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union

import supervisely.io.env as env
from supervisely import ProjectMeta
from supervisely._utils import logger
from supervisely.api.api import Api
from supervisely.api.nn.utils import (
    find_agent,
    find_apps_by_framework,
    find_team_by_path,
    get_artifacts_dir_and_checkpoint_name,
    get_experiment_info_by_task_id,
    run_train_app,
)
from supervisely.nn.experiments import get_experiment_info_by_artifacts_dir


class Model:
    """
    Reference to the model a training app should start from.

    This class normalizes the user input into a structure that can be embedded into an
    application state (`guiState["model"]`) for a training app.

    The training UI supports two model sources:
    - **Pretrained models**: referenced by `"framework/model_name"` (e.g. `"RT-DETRv2/RT-DETRv2-M"`).
    - **Custom models**: referenced by a path to a checkpoint in Team Files (starts with `"/"`),
      or by a path without leading slash if it exists in Team Files.
    """

    def __init__(
        self,
        api: "Api",
        source: str,
        framework: Optional[str] = None,
        model_name: Optional[str] = None,
        team_id: Optional[int] = None,
        artifacts_dir: Optional[str] = None,
        checkpoint_name: Optional[str] = None,
    ):
        """
        :param api: :class:`~supervisely.api.api.Api` object to use for API connection.
        :type api: :class:`~supervisely.api.api.Api`
        :param source: Source of the model: "Pretrained models" or "Custom models".
        :type source: str
        :param framework: Framework of the model.
        :type framework: str, optional
        :param model_name: Name of the model.
        :type model_name: str, optional
        :param team_id: Team id of the model.
        :type team_id: int, optional
        :param artifacts_dir: Artifacts directory of the model.
        :type artifacts_dir: str, optional
        :param checkpoint_name: Checkpoint name of the model.
        :type checkpoint_name: str, optional
        :raises ValueError: If `source` is "Custom models" and experiment info cannot be
            retrieved for `artifacts_dir`.
        """
        self.api = api
        self.source: str = source
        self.framework: Optional[str] = framework
        self.model_name: Optional[str] = model_name
        self.team_id: Optional[int] = team_id
        self.artifacts_dir: Optional[str] = artifacts_dir
        self.checkpoint_name: Optional[str] = checkpoint_name
        # Always define `task_id` so the attribute also exists for pretrained models;
        # `_init` overwrites it for custom models.
        self.task_id: Optional[int] = None
        self._init()

    def _init(self):
        """
        Finalize initialization for custom models.

        For custom checkpoints, we resolve experiment metadata to determine:
        - `task_id` of the experiment that produced artifacts.
        - `framework` name used by that experiment.

        Does nothing for any other `source`.

        :raises ValueError: If experiment info cannot be retrieved for `artifacts_dir`.
        """
        if self.source != "Custom models":
            return

        if self.team_id is None:
            self.team_id = find_team_by_path(self.api, self.artifacts_dir, team_id=self.team_id)
        experiment_info = get_experiment_info_by_artifacts_dir(
            self.api, self.team_id, self.artifacts_dir
        )
        if not experiment_info:
            raise ValueError(
                f"Failed to retrieve experiment info for artifacts_dir: '{self.artifacts_dir}'"
            )
        self.task_id = experiment_info.task_id
        # The framework recorded by the experiment overrides whatever the caller passed.
        self.framework = experiment_info.framework_name

    @classmethod
    def parse(cls, api: "Api", model: str) -> "Model":
        """
        Parse a user-provided model identifier into a `Model`.

        :param api: Supervisely API client.
        :type api: :class:`~supervisely.api.api.Api`
        :param model: Either a checkpoint path starting with "/" (Team Files), a checkpoint path
            without leading slash (will be checked in Team Files), or a pretrained model name in
            format "framework/model_name" (e.g. "RT-DETRv2/RT-DETRv2-M", "YOLO/YOLO26s-det").
        :type model: str
        :returns: Parsed model reference ready to be embedded into training app state.
        :rtype: :class:`~supervisely.api.nn.train_api.Model`
        :raises ValueError: If the string cannot be parsed or required metadata cannot be resolved.
        """
        if not isinstance(model, str) or not model:
            raise ValueError(f"Model identifier must be a non-empty string, got: {model!r}")

        checkpoint = None
        team_id = None
        if model.startswith("/"):
            checkpoint = model
        else:
            # A path given without the leading slash is still treated as a checkpoint
            # when it can be located in Team Files.
            candidate = f"/{model}"
            found_team_id = find_team_by_path(
                api, candidate, team_id=None, raise_not_found=False
            )
            if found_team_id is not None:
                checkpoint = candidate
                team_id = found_team_id
                logger.debug(f"Found checkpoint in team {team_id}")

        if checkpoint is not None:
            artifacts_dir, checkpoint_name = get_artifacts_dir_and_checkpoint_name(checkpoint)
            return cls(
                api=api,
                source="Custom models",
                artifacts_dir=artifacts_dir,
                checkpoint_name=checkpoint_name,
                team_id=team_id,
            )

        # Only the first "/" separates the framework from the model name.
        framework, separator, model_name = model.partition("/")
        if not separator or not framework or not model_name:
            raise ValueError(
                f"Failed to parse model '{model}': expected a checkpoint path in Team Files "
                "or a pretrained model name in format 'framework/model_name'"
            )
        return cls(
            api=api, source="Pretrained models", framework=framework, model_name=model_name
        )

    def app_state(self) -> Dict[str, Any]:
        """
        Return a JSON-serializable dict for `guiState["model"]`.

        :returns: `source` and `model_name` for pretrained models; `source`, `task_id`
            and `checkpoint` for any other source.
        :rtype: Dict[str, Any]
        """
        if self.source == "Pretrained models":
            return {
                "source": self.source,
                "model_name": self.model_name,
            }
        return {
            "source": self.source,
            "task_id": self.task_id,
            "checkpoint": self.checkpoint_name,
        }


class _TrainValSplit:
    """Base class for Train/Val split strategies used by training UI."""

    @abstractmethod
    def app_state(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict for `guiState["train_val_split"]`.

        :returns: Dict for training app UI state.
        :rtype: Dict[str, Any]
        """
        raise NotImplementedError()


class RandomSplit(_TrainValSplit):
    """Random train/val split controlled by a percentage."""

    def __init__(self, percent: int = 80, split: str = "train"):
        """
        :param percent: Percent of the dataset to split into train.
        :type percent: int
        :param split: Split method: "train" or "val".
        :type split: str
        """
        self.percent, self.split = percent, split

    def app_state(self):
        """Build the JSON-serializable dict for `guiState["train_val_split"]`."""
        state = {"method": "random"}
        state["split"] = self.split
        state["percent"] = self.percent
        return state

class DatasetsSplit(_TrainValSplit):
    """Train/val split defined by explicit dataset ids inside the project."""

    def __init__(self, train_datasets: List[int], val_datasets: List[int]):
        """
        :param train_datasets: List of dataset ids to split into train.
        :type train_datasets: List[int]
        :param val_datasets: List of dataset ids to split into val.
        :type val_datasets: List[int]
        """
        self.train_datasets, self.val_datasets = train_datasets, val_datasets

    def app_state(self):
        """Build the split settings dict for the training UI."""
        state = {"method": "datasets"}
        state["train_datasets"] = self.train_datasets
        state["val_datasets"] = self.val_datasets
        return state

class TagsSplit(_TrainValSplit):
    """Tag-based split: items with train_tag go to train, items with val_tag go to val."""

    def __init__(
        self,
        train_tag: str,
        val_tag: str,
        untagged_action: Literal["train", "val", "ignore"],
    ):
        """
        :param train_tag: Tag to split into train.
        :type train_tag: str
        :param val_tag: Tag to split into val.
        :type val_tag: str
        :param untagged_action: Action for untagged items: "train", "val", "ignore".
        :type untagged_action: Literal["train", "val", "ignore"]
        """
        self.train_tag, self.val_tag = train_tag, val_tag
        self.untagged_action = untagged_action

    def app_state(self):
        """Build the split settings dict for the training UI."""
        state = {"method": "tags"}
        state["train_tag"] = self.train_tag
        state["val_tag"] = self.val_tag
        state["untagged_action"] = self.untagged_action
        return state

class CollectionsSplit(_TrainValSplit):
    """Train/val split defined by entity collections (train collections vs val collections)."""

    def __init__(self, train_collections: List[int], val_collections: List[int]):
        """
        :param train_collections: List of collection ids to split into train.
        :type train_collections: List[int]
        :param val_collections: List of collection ids to split into val.
        :type val_collections: List[int]
        """
        self.train_collections, self.val_collections = train_collections, val_collections

    def app_state(self):
        """Build the split settings dict for the training UI."""
        state = {"method": "collections"}
        state["train_collections"] = self.train_collections
        state["val_collections"] = self.val_collections
        return state

class TrainApi:
    """
    High-level API to start a training application on the Supervisely instance.

    You can read more about it in the `Training API documentation <https://developer.supervisely.com/advanced-user-guide/automate-with-python-sdk-and-api/training-api>`_.

    This wrapper prepares the `params`/`state` payload expected by the training UI
    application and starts an app task on a given agent.

    Typical usage:
    - Choose a model (pretrained or custom checkpoint)
    - Provide training settings (model, classes, train/val split, hyperparameters, etc.)
    - Start the training app
    """

    # Default `timeout` forwarded to `run_train_app`; callers may override it by passing
    # `timeout=...` through `**kwargs` of the public methods.
    _DEFAULT_TIMEOUT = 100

    def __init__(self, api: "Api"):
        """
        :param api: :class:`~supervisely.api.api.Api` object to use for API connection.
        :type api: :class:`~supervisely.api.api.Api`
        """
        self._api = api

    def _get_app_state(
        self,
        project_id: int,
        model: "Model",
        classes: List[str],
        train_val_split: "_TrainValSplit",
        experiment_name: Optional[str] = None,
        hyperparameters: Optional[str] = None,
        convert_class_shapes: bool = True,
        enable_benchmark: bool = True,
        enable_speedtest: bool = False,
        cache_project: bool = True,
        export_onnx: bool = False,
        export_tensorrt: bool = False,
        autostart: bool = True,
    ) -> Dict[str, Any]:
        """
        Build training application state payload.

        The resulting structure is passed to the training app as `params` (task arguments).
        It follows the `TrainApp`/Training UI expected schema:
        - `state.slyProjectId`: project id
        - `state.guiState`: Training UI state
          (model/classes/split/hyperparameters/options/experiment_name/start_training)

        Notes:
        - `hyperparameters` is expected to be a YAML string (or `None` to keep defaults from
          training application).
        - Benchmark is enabled whenever speed test is requested, and export is enabled
          whenever any export format is requested.

        :param project_id: Project id to train on.
        :type project_id: int
        :param model: Parsed model reference ready to be embedded into training application state.
        :type model: :class:`~supervisely.api.nn.train_api.Model`
        :param classes: Class names to train on (filtered to project meta upstream).
        :type classes: List[str]
        :param train_val_split: Train/Val split strategy for the Training UI.
        :type train_val_split: :class:`~supervisely.api.nn.train_api._TrainValSplit`
        :param experiment_name: Optional experiment name shown in UI and used for artifacts.
        :type experiment_name: str, optional
        :param hyperparameters: Hyperparameters YAML string for the training application.
            If None, Training UI keeps defaults.
        :type hyperparameters: str, optional
        :param convert_class_shapes: Whether to auto-convert shapes to framework requirements.
        :type convert_class_shapes: bool
        :param enable_benchmark: Enable post-training evaluation (Model Benchmark).
        :type enable_benchmark: bool
        :param enable_speedtest: Enable speed test as part of benchmark.
        :type enable_speedtest: bool
        :param cache_project: Cache project on agent before training.
        :type cache_project: bool
        :param export_onnx: Enable export to ONNXRuntime (if supported by the app/framework).
        :type export_onnx: bool
        :param export_tensorrt: Enable export to TensorRT engine (if supported by the app/framework).
        :type export_tensorrt: bool
        :param autostart: If True, training is started automatically after UI state is applied.
        :type autostart: bool
        :returns: Task params payload for the training app.
        :rtype: Dict[str, Any]
        """
        app_state = {
            "state": {
                # 1. Project
                "slyProjectId": project_id,
                "guiState": {
                    # 2. Model
                    "model": model.app_state(),
                    # 3. Classes
                    "classes": classes,
                    # 4. Train/Val Split
                    "train_val_split": train_val_split.app_state(),
                    # 5. Hyperparameters
                    "hyperparameters": hyperparameters,  # yaml string
                    # 6. Options
                    "options": {
                        "convert_class_shapes": convert_class_shapes,
                        "model_benchmark": {
                            "enable": enable_benchmark or enable_speedtest,
                            "speed_test": enable_speedtest,
                        },
                        "cache_project": cache_project,
                        "export": {
                            "enable": export_onnx or export_tensorrt,
                            "ONNXRuntime": export_onnx,
                            "TensorRT": export_tensorrt,
                        },
                    },
                    # 7. Experiment Name
                    "experiment_name": experiment_name,
                    # 8. Start Training
                    "start_training": autostart,  # Starts training automatically
                },
            },
        }
        return app_state

    def run(
        self,
        project_id: int,
        model: str,
        hyperparameters: Optional[str] = None,
        experiment_name: Optional[str] = None,
        classes: Optional[List[str]] = None,
        train_val_split: "Optional[Union[RandomSplit, DatasetsSplit, TagsSplit, CollectionsSplit]]" = None,
        convert_class_shapes: bool = True,
        enable_benchmark: bool = True,
        enable_speedtest: bool = False,
        cache_project: bool = True,
        export_onnx: bool = False,
        export_tensorrt: bool = False,
        autostart: bool = True,
        agent_id: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Start a training application task for a project.

        You can read more about the training API in the `Training API documentation <https://developer.supervisely.com/advanced-user-guide/automate-with-python-sdk-and-api/training-api>`_.

        :param project_id: Project ID to train on.
        :type project_id: int
        :param model: Either a checkpoint path in Team Files
            (e.g. "/experiments/.../checkpoints/best.pth"), or a pretrained model name in format
            "framework/model_name" (e.g. "RT-DETRv2/RT-DETRv2-M", "YOLO/YOLO26s-det").
        :type model: str
        :param hyperparameters: Hyperparameters YAML string for the training app.
            If None, uses defaults from training app.
        :type hyperparameters: str, optional
        :param experiment_name: Optional experiment name used in training app.
            Will be auto-generated if not provided.
        :type experiment_name: str, optional
        :param classes: Optional subset of class names to train on. If provided, names not
            present in project meta are ignored.
        :type classes: List[str], optional
        :param train_val_split: Optional split strategy; defaults to
            :class:`~supervisely.api.nn.train_api.RandomSplit`.
        :type train_val_split: Union[:class:`~supervisely.api.nn.train_api.RandomSplit`, :class:`~supervisely.api.nn.train_api.DatasetsSplit`, :class:`~supervisely.api.nn.train_api.TagsSplit`, :class:`~supervisely.api.nn.train_api.CollectionsSplit`], optional
        :param convert_class_shapes: Whether to convert class shapes to framework requirements
            automatically.
        :type convert_class_shapes: bool, optional
        :param enable_benchmark: Enable post-training evaluation (Model Benchmark) after training.
        :type enable_benchmark: bool, optional
        :param enable_speedtest: Enable speed test as part of benchmark after training.
        :type enable_speedtest: bool, optional
        :param cache_project: Cache project on agent before training to save time on downloading
            project next time.
        :type cache_project: bool, optional
        :param export_onnx: Enable export to ONNXRuntime (if supported by the training app/framework).
        :type export_onnx: bool, optional
        :param export_tensorrt: Enable export to TensorRT engine (if supported by the training
            app/framework).
        :type export_tensorrt: bool, optional
        :param autostart: If True, training is started automatically after all settings are
            applied. If False, training must be started manually from the training app UI by
            clicking the "Start Training" button.
        :type autostart: bool, optional
        :param agent_id: Agent ID where the app task will run. If not provided, an agent is
            looked up for the project's team.
        :type agent_id: int, optional
        :param kwargs: Extra keyword arguments forwarded to `run_train_app`
            (e.g. `timeout` to override the default).
        :returns: Task information dict for the created app task.
        :rtype: Dict[str, Any]
        :raises ValueError: If a suitable training app cannot be found for the detected framework.

        :Usage Example:

            .. code-block:: python

                import os
                from dotenv import load_dotenv

                import supervisely as sly
                from supervisely.api.nn.train_api import TrainApi

                # Load secrets and create API object from .env file (recommended)
                # Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
                if sly.is_development():
                    load_dotenv("local.env")
                    load_dotenv(os.path.expanduser("~/supervisely.env"))

                api = sly.Api.from_env()

                agent_id = sly.env.agent_id()
                project_id = sly.env.project_id()

                train = TrainApi(api)
                train.run(project_id=project_id, model="YOLO/YOLO11n-det", agent_id=agent_id)
        """
        parsed_model = Model.parse(self._api, model)
        return self._run(
            project_id=project_id,
            model=parsed_model,
            hyperparameters=hyperparameters,
            experiment_name=experiment_name,
            classes=classes,
            train_val_split=train_val_split,
            convert_class_shapes=convert_class_shapes,
            enable_benchmark=enable_benchmark,
            enable_speedtest=enable_speedtest,
            cache_project=cache_project,
            export_onnx=export_onnx,
            export_tensorrt=export_tensorrt,
            autostart=autostart,
            agent_id=agent_id,
            **kwargs,
        )

    def _run(
        self,
        project_id: int,
        model: "Model",
        hyperparameters: Optional[str] = None,
        experiment_name: Optional[str] = None,
        classes: Optional[List[str]] = None,
        train_val_split: "Optional[Union[RandomSplit, DatasetsSplit, TagsSplit, CollectionsSplit]]" = None,
        convert_class_shapes: bool = True,
        enable_benchmark: bool = True,
        enable_speedtest: bool = False,
        cache_project: bool = True,
        export_onnx: bool = False,
        export_tensorrt: bool = False,
        autostart: bool = True,
        agent_id: Optional[int] = None,
        **kwargs,
    ) -> Dict:
        """
        Start a training app task for an already parsed model.

        Same contract as :meth:`run`, except that `model` is a parsed
        :class:`~supervisely.api.nn.train_api.Model` rather than a string.

        :returns: Task information as returned by `run_train_app`.
        :rtype: Dict
        :raises ValueError: If a suitable training app cannot be found for `model.framework`.
        """
        project_info = self._api.project.get_info_by_id(project_id)
        workspace_id = project_info.workspace_id
        team_id = project_info.team_id

        project_meta_json = self._api.project.get_meta(project_id)
        project_meta = ProjectMeta.from_json(project_meta_json)
        project_class_names = [obj_class.name for obj_class in project_meta.obj_classes]
        if classes:
            if isinstance(classes, str):
                # A single class name; avoid treating the string as a sequence of characters.
                classes = [classes]
            requested = set(classes)
            unknown = sorted(requested.difference(project_class_names))
            if unknown:
                logger.warning(
                    f"Classes not found in project meta will be ignored: {unknown}"
                )
            # Keep the order of classes as defined in the project meta.
            classes = [name for name in project_class_names if name in requested]
        else:
            classes = project_class_names

        if train_val_split is None:
            train_val_split = RandomSplit()

        module_id = self._require_train_app(model.framework)["id"]

        if agent_id is None:
            agent_id = find_agent(self._api, team_id)

        app_state = self._get_app_state(
            project_id=project_id,
            model=model,
            classes=classes,
            train_val_split=train_val_split,
            experiment_name=experiment_name,
            hyperparameters=hyperparameters,
            convert_class_shapes=convert_class_shapes,
            enable_benchmark=enable_benchmark,
            enable_speedtest=enable_speedtest,
            cache_project=cache_project,
            export_onnx=export_onnx,
            export_tensorrt=export_tensorrt,
            autostart=autostart,
        )
        return self._start_train_app(agent_id, module_id, workspace_id, app_state, kwargs)

    def find_train_app_by_framework(self, framework: str) -> Optional[dict]:
        """
        Find a training app module for the given framework.

        :param framework: Framework name (e.g. "RT-DETRv2", "YOLO", "DEIM").
        :type framework: str
        :returns: Ecosystem module dict (as returned by the API) or None if not found.
        :rtype: Union[dict, None]
        """
        modules = find_apps_by_framework(self._api, framework, ["train"])
        if not modules:
            return None
        # The first module in the returned list is used.
        return modules[0]

    def _require_train_app(self, framework: str) -> dict:
        """Return the training app module for `framework` or raise ValueError if there is none."""
        module = self.find_train_app_by_framework(framework)
        if module is None:
            raise ValueError(f"Failed to detect train app by framework: '{framework}'")
        return module

    def _start_train_app(
        self,
        agent_id: int,
        module_id: int,
        workspace_id: int,
        app_state: Dict[str, Any],
        extra_kwargs: Dict[str, Any],
    ):
        """Call `run_train_app`, applying the default timeout unless the caller supplied one."""
        params = {"timeout": self._DEFAULT_TIMEOUT}
        params.update(extra_kwargs)
        return run_train_app(
            api=self._api,
            agent_id=agent_id,
            module_id=module_id,
            workspace_id=workspace_id,
            app_state=app_state,
            **params,
        )

    def _run_from_task(
        self,
        task_id: int,
        train_mode: str,
        project_id: Optional[int],
        agent_id: Optional[int],
        extra_kwargs: Dict[str, Any],
    ):
        """Start a training app task seeded from a previous task with the given `trainMode`."""
        experiment_info = get_experiment_info_by_task_id(self._api, task_id)
        if experiment_info is None:
            raise ValueError(f"Not found experiment data for task {task_id}")

        module_id = self._require_train_app(experiment_info.framework_name)["id"]

        if project_id is None:
            project_id = experiment_info.project_id
        project_info = self._api.project.get_info_by_id(project_id)

        if agent_id is None:
            agent_id = find_agent(self._api, project_info.team_id)

        app_state = {
            "state": {
                "trainTaskId": task_id,
                "trainMode": train_mode,
                "slyProjectId": project_id,
            }
        }
        return self._start_train_app(
            agent_id, module_id, project_info.workspace_id, app_state, extra_kwargs
        )

    def finetune(
        self,
        task_id: int,
        project_id: Optional[int] = None,
        agent_id: Optional[int] = None,
        **kwargs,
    ) -> Dict:
        """
        Finetune best checkpoint from a previous training task.

        :param task_id: Task ID to finetune best checkpoint from.
        :type task_id: int
        :param project_id: Project ID to train on. If not provided, will use project ID from
            the provided Task ID
        :type project_id: int
        :param agent_id: Agent ID where the app task will run. If not provided, will use the
            most available agent from the current team.
        :type agent_id: int
        :param kwargs: Extra keyword arguments forwarded to `run_train_app`
            (e.g. `timeout` to override the default).
        :returns: Task information dict for the created app task.
        :rtype: Dict[str, Any]
        :raises ValueError: If experiment data is not found for the task or a suitable
            training app cannot be found for the detected framework.

        :Usage Example:

            .. code-block:: python

                import os
                from dotenv import load_dotenv

                import supervisely as sly
                from supervisely.api.nn.train_api import TrainApi

                # Load secrets and create API object from .env file (recommended)
                # Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
                if sly.is_development():
                    load_dotenv("local.env")
                    load_dotenv(os.path.expanduser("~/supervisely.env"))

                api = sly.Api.from_env()

                agent_id = sly.env.agent_id()
                project_id = sly.env.project_id()

                train = TrainApi(api)
                train.finetune(task_id=123, project_id=project_id, agent_id=agent_id)
        """
        return self._run_from_task(task_id, "continue", project_id, agent_id, kwargs)

    def rerun_from_task_id(
        self,
        task_id: int,
        project_id: Optional[int] = None,
        agent_id: Optional[int] = None,
        **kwargs,
    ):
        """
        Rerun a training application task from a previous task.

        :param task_id: Task ID to rerun from.
        :type task_id: int
        :param project_id: Project ID to train on. If not provided, will use project ID from
            the provided Task ID
        :type project_id: int
        :param agent_id: Agent ID where the app task will run. If not provided, will use the
            most available agent from the current team.
        :type agent_id: int
        :param kwargs: Extra keyword arguments forwarded to `run_train_app`
            (e.g. `timeout` to override the default).
        :returns: Task information dict for the created app task.
        :rtype: Dict[str, Any]
        :raises ValueError: If experiment data is not found for the task or a suitable
            training app cannot be found for the detected framework.

        :Usage Example:

            .. code-block:: python

                import os
                from dotenv import load_dotenv

                import supervisely as sly
                from supervisely.api.nn.train_api import TrainApi

                # Load secrets and create API object from .env file (recommended)
                # Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
                if sly.is_development():
                    load_dotenv("local.env")
                    load_dotenv(os.path.expanduser("~/supervisely.env"))

                api = sly.Api.from_env()

                agent_id = sly.env.agent_id()
                project_id = sly.env.project_id()

                train = TrainApi(api)
                train.rerun_from_task_id(task_id=123, project_id=project_id, agent_id=agent_id)
        """
        return self._run_from_task(task_id, "new", project_id, agent_id, kwargs)