TrainApi

class TrainApi(api)

Bases: object

High-level API to start a training application on the Supervisely instance.

You can read more about it in the Training API documentation.

This wrapper prepares the params/state payload expected by the training UI application and starts an app task on a given agent.

Typical usage:

  • Choose a model (pretrained or custom checkpoint)

  • Provide training settings (model, classes, train/val split, hyperparameters, etc.)

  • Start the training app
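The three steps above map onto keyword arguments of run(), documented below. The following is a minimal sketch that assembles them locally before anything is sent to the instance. build_run_kwargs is a hypothetical helper, not part of the Supervisely SDK, and the project ID and class names are placeholders.

```python
from typing import Any, Dict, List, Optional


def build_run_kwargs(
    project_id: int,
    model: str,
    classes: Optional[List[str]] = None,
    hyperparameters: Optional[str] = None,
    agent_id: Optional[int] = None,
) -> Dict[str, Any]:
    """Collect typical-usage choices into keyword arguments for TrainApi.run().

    Hypothetical helper: only arguments that were actually set are included,
    so run() keeps its own defaults for everything else.
    """
    kwargs: Dict[str, Any] = {"project_id": project_id, "model": model}
    optional = {"classes": classes, "hyperparameters": hyperparameters, "agent_id": agent_id}
    kwargs.update({key: value for key, value in optional.items() if value is not None})
    return kwargs


# Placeholder values for illustration only
kwargs = build_run_kwargs(project_id=42, model="YOLO/YOLO11n-det", classes=["car", "person"])
print(kwargs)
```

With a real Api object, the dict would then be unpacked into the call: TrainApi(api).run(**kwargs).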

Parameters:
api

Api object used to connect to the Supervisely instance.

Methods

find_train_app_by_framework

Find a training app module for the given framework.

finetune

Finetune best checkpoint from a previous training task.

rerun_from_task_id

Rerun a training application task from a previous task.

run

Start a training application task for a project.

find_train_app_by_framework(framework)

Find a training app module for the given framework.

Parameters:
framework : str

Framework name (e.g. “RT-DETRv2”, “YOLO”, “DEIM”).

Returns:

Ecosystem module dict (as returned by the API) or None if not found.

Return type:

Union[dict, None]
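Because this method returns None when no app matches, callers usually need to handle that case explicitly. A minimal sketch: require_train_app is a hypothetical helper (not part of the SDK) that turns the optional result into a hard requirement, raising the same ValueError type that finetune, rerun_from_task_id and run document.

```python
from typing import Any, Dict, Optional


def require_train_app(module: Optional[Dict[str, Any]], framework: str) -> Dict[str, Any]:
    """Return the module dict, or raise ValueError if no training app was found.

    Hypothetical helper; `module` is whatever find_train_app_by_framework returned.
    """
    if module is None:
        # Same exception type the other TrainApi methods document for this situation
        raise ValueError(f"No training app found for framework {framework!r}")
    return module
```

Usage with a real TrainApi instance: app = require_train_app(train.find_train_app_by_framework("YOLO"), "YOLO").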

finetune(task_id, project_id=None, agent_id=None, **kwargs)

Finetune best checkpoint from a previous training task.

Parameters:
task_id : int

ID of the training task whose best checkpoint will be finetuned.

project_id : int, optional

Project ID to train on. If not provided, the project of the task given by task_id is used.

agent_id : int, optional

Agent ID where the app task will run. If not provided, the most available agent in the current team is used.

Returns:

Task information dict for the created app task.

Return type:

Dict[str, Any]

Raises:

ValueError – If a suitable training app cannot be found for the detected framework.

Usage Example:
import os
from dotenv import load_dotenv

import supervisely as sly
from supervisely.api.nn.train_api import TrainApi

# Load secrets and create API object from .env file (recommended)
# Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
if sly.is_development():
    load_dotenv("local.env")
    load_dotenv(os.path.expanduser("~/supervisely.env"))

api = sly.Api.from_env()

agent_id = sly.env.agent_id()
project_id = sly.env.project_id()

train = TrainApi(api)
train.finetune(task_id=123, project_id=project_id, agent_id=agent_id)  # 123 is a placeholder task ID

rerun_from_task_id(task_id, project_id=None, agent_id=None, **kwargs)

Rerun a training application task from a previous task.

Parameters:
task_id : int

ID of the training task to rerun.

project_id : int, optional

Project ID to train on. If not provided, the project of the task given by task_id is used.

agent_id : int, optional

Agent ID where the app task will run. If not provided, the most available agent in the current team is used.

Returns:

Task information dict for the created app task.

Return type:

Dict[str, Any]

Raises:

ValueError – If a suitable training app cannot be found for the detected framework.

Usage Example:
import os
from dotenv import load_dotenv

import supervisely as sly
from supervisely.api.nn.train_api import TrainApi

# Load secrets and create API object from .env file (recommended)
# Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
if sly.is_development():
    load_dotenv("local.env")
    load_dotenv(os.path.expanduser("~/supervisely.env"))

api = sly.Api.from_env()

agent_id = sly.env.agent_id()
project_id = sly.env.project_id()

train = TrainApi(api)
train.rerun_from_task_id(task_id=123, project_id=project_id, agent_id=agent_id)  # 123 is a placeholder task ID

run(project_id, model, hyperparameters=None, experiment_name=None, classes=None, train_val_split=None, convert_class_shapes=True, enable_benchmark=True, enable_speedtest=False, cache_project=True, export_onnx=False, export_tensorrt=False, autostart=True, agent_id=None, **kwargs)

Start a training application task for a project.

You can read more about the training API in the Training API documentation.

Parameters:
agent_id : int, optional

Agent ID where the app task will run.

project_id : int

Project ID to train on.

model : str

Either a checkpoint path in Team Files (e.g. “/experiments/…/checkpoints/best.pth”), or a pretrained model name in format “framework/model_name” (e.g. “RT-DETRv2/RT-DETRv2-M”, “YOLO/YOLO26s-det”).

hyperparameters : str, optional

Hyperparameters as a YAML string. If None, the training app's default hyperparameters are used.

experiment_name : str, optional

Experiment name for the training app. Auto-generated if not provided.

classes : List[str], optional

Subset of class names to train on. Names not present in the project meta are ignored.

train_val_split : optional

Strategy for splitting the project into training and validation sets. Defaults to RandomSplit.

convert_class_shapes : bool, optional

Whether to automatically convert class shapes to the shapes the framework requires.

enable_benchmark : bool, optional

Run evaluation (Model Benchmark) after training.

enable_speedtest : bool, optional

Include a speed test in the post-training benchmark.

cache_project : bool, optional

Cache the project on the agent before training, so that subsequent runs do not need to download it again.

export_onnx : bool, optional

Enable export to ONNXRuntime (if supported by the training app/framework).

export_tensorrt : bool, optional

Enable export to TensorRT engine (if supported by the training app/framework).

autostart : bool, optional

If True, training starts automatically once all settings are applied. If False, training must be started manually by clicking the “Start Training” button in the training app UI.
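The rules for model and classes described above can be checked locally before calling run(). The following is a minimal sketch; classify_model_arg and effective_classes are hypothetical helpers, not SDK functions. It assumes that Team Files checkpoint paths are absolute (start with "/") and treats anything else as a pretrained "framework/model_name" string; the actual parsing inside the training app may differ.

```python
from typing import List, Tuple


def classify_model_arg(model: str) -> Tuple[str, str, str]:
    """Return (kind, framework, name) for a `model` argument.

    Assumption: checkpoint paths in Team Files start with "/"; anything else
    must look like "framework/model_name".
    """
    if model.startswith("/"):
        # Checkpoint path: this sketch does not try to infer the framework.
        return "checkpoint", "", model
    framework, sep, name = model.partition("/")
    if not (sep and framework and name):
        raise ValueError(f"expected 'framework/model_name', got {model!r}")
    return "pretrained", framework, name


def effective_classes(requested: List[str], project_classes: List[str]) -> List[str]:
    """Keep requested class names that exist in the project meta, in request order."""
    known = set(project_classes)
    return [name for name in requested if name in known]


print(classify_model_arg("RT-DETRv2/RT-DETRv2-M"))  # ('pretrained', 'RT-DETRv2', 'RT-DETRv2-M')
print(effective_classes(["car", "person", "bicycle"], ["person", "car", "truck"]))  # ['car', 'person']
```

In a real script, the project's class names can be read with sly.ProjectMeta.from_json(api.project.get_meta(project_id)) and then [c.name for c in meta.obj_classes].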

Returns:

Task information dict for the created app task.

Return type:

Dict[str, Any]

Raises:

ValueError – If a suitable training app cannot be found for the detected framework.

Usage Example:
import os
from dotenv import load_dotenv

import supervisely as sly
from supervisely.api.nn.train_api import TrainApi

# Load secrets and create API object from .env file (recommended)
# Learn more here: https://developer.supervisely.com/getting-started/basics-of-authentication
if sly.is_development():
    load_dotenv("local.env")
    load_dotenv(os.path.expanduser("~/supervisely.env"))

api = sly.Api.from_env()

agent_id = sly.env.agent_id()
project_id = sly.env.project_id()

train = TrainApi(api)
train.run(project_id=project_id, model="YOLO/YOLO11n-det", agent_id=agent_id)
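The example above relies on the training app's default hyperparameters. To override them, pass a YAML string through the hyperparameters argument. The following is a minimal sketch that renders a flat dict as YAML; to_yaml is a hypothetical helper, and the keys epochs, batch_size and lr are placeholders, because each training app defines its own hyperparameter schema.

```python
from typing import Dict, Union

Scalar = Union[int, float, str]


def to_yaml(params: Dict[str, Scalar]) -> str:
    """Render a flat mapping of simple scalars as YAML, one "key: value" per line.

    Strings are written unquoted, so keep them free of YAML special characters.
    """
    return "".join(f"{key}: {value}\n" for key, value in params.items())


# Placeholder keys: check the target training app for the names it actually expects.
hyperparameters = to_yaml({"epochs": 50, "batch_size": 16, "lr": 0.001})
print(hyperparameters, end="")
```

The result would then be passed as train.run(project_id=project_id, model="YOLO/YOLO11n-det", hyperparameters=hyperparameters, agent_id=agent_id). For nested or string-heavy settings, yaml.safe_dump from PyYAML is a safer choice than hand formatting.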