TrainApp

class TrainApp(framework_name, models, hyperparameters, app_options=None, work_dir=None)

Bases: object

High-level wrapper for building Supervisely training applications.

For more details, see the Training API documentation.

It connects:

  • GUI (model selector, hyperparameters editor, train/val split selector, class/tag selectors)

  • Data preparation (project download, optional conversion, splitting, collections creation)

  • Training lifecycle (prepare data → run user training code → validate & upload artifacts)

  • Optional extras: export (ONNX/TensorRT), model benchmark, TensorBoard

Parameters:
framework_name : str

Name of the ML framework used (stored in experiment metadata).

models : Union[str, List[Dict[str, Any]]]

Path to a .json file (or a Python list of dicts) with the model configurations shown in the model selector.

hyperparameters : str

Path to hyperparameters YAML file (.yaml/.yml). The GUI allows editing before training.

app_options : Optional[Union[str, Dict[str, Any]]]

Path to an options YAML file, or a dict with UI/behavior options.

work_dir : Optional[str]

Local working directory used to store downloaded data, model files, logs and output artifacts.

Usage Example:
from supervisely.nn.training.train_app import TrainApp

train_app = TrainApp(
    framework_name="My Framework",  # e.g. "YOLO"
    models="models.json",
    hyperparameters="hyperparameters.yaml",
    app_options="app_options.yaml",
)

@train_app.start
def train():
    # Access data via train_app.project_dir
    # Access hyperparameters via train_app.hyperparameters (dict)
    # Access model files via train_app.model_files (dict)
    # Train your model, save checkpoints locally, and return experiment_info.
    return {
        "model_name": train_app.model_name,
        "task_type": train_app.task_type,   # TaskType (string-like enum)
        "checkpoints": "/path/to/checkpoints_dir",  # or list of file paths
        "best_checkpoint": "checkpoint_best.pth",
        # optional: return config if model requires it
        # "model_files": {"config": "/path/to/config.yaml", ...},
    }

Methods

add_output_files

Copies files or directories to the output directory, which is uploaded to Team Files when training completes.

create_model_meta

Create a model meta for the trained model according to task type.

get_app_state

Returns the current state of the application.

load_app_state

Load the GUI state from app state dictionary.

register_inference_class

Registers an inference class that the training application uses to run model benchmarking.

start_in_thread

Configure the app to start training automatically on server startup (in a background thread).

start_tensorboard

Manually starts TensorBoard from the user's training code.

stop_tensorboard

Stops the TensorBoard server.

Attributes

auto_start

If True, training starts automatically once the GUI is loaded and the training server has started.

base_checkpoint

Returns the name of the base checkpoint.

base_checkpoint_link

Returns the link to the base checkpoint.

classes

Returns the names of the classes selected for training.

device

Returns the selected device for training.

device_ids

Returns the list of device IDs used for training in multi-GPU mode.

devices

Returns all devices used for training in multi-GPU mode.

export_onnx

Decorator for the user-defined function that exports the model to ONNX.

export_tensorrt

Decorator for the user-defined function that exports the model to TensorRT.

hyperparameters

Returns the selected hyperparameters for training in dict format.

hyperparameters_yaml

Returns the selected hyperparameters for training as a raw YAML string.

is_model_benchmark_enabled

Checks if model benchmarking is enabled based on application options and GUI settings.

is_multi_gpu

Returns True if multi-GPU mode is enabled.

model_info

Returns the selected row from the models table as a dict.

model_name

Returns the name of the model.

model_source

Returns whether the model is pretrained or custom.

num_classes

Returns the number of selected classes for training.

num_tags

Returns the number of selected tags for training.

progress_bar_main

Returns the main progress bar widget.

progress_bar_secondary

Returns the secondary progress bar widget.

project_id

Returns the ID of the project.

project_info

Returns ProjectInfo object, which contains information about the project.

project_meta

Returns the project metadata.

project_name

Returns the name of the project.

start

Decorator for the user-defined training function.

tags

Returns the selected tags for training.

task_type

Returns the task type of the model.

team_id

Returns the ID of the team.

workspace_id

Returns the ID of the workspace.

add_output_files(paths)

Copies files or directories to the output directory, which is uploaded to Team Files when training completes. If a path is a file, it is uploaded to the root of the artifacts directory. If a path is a directory, it is uploaded to the root of the artifacts directory under the same directory name, preserving its structure.

Parameters:
paths : List[str]

List of paths to files or directories to be copied to the output directory.

Returns:

None

Return type:

None
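The file/directory rule above can be illustrated with a small standard-library sketch. copy_to_output_dir is a hypothetical helper written for this page, not part of TrainApp; it only mimics the documented copy layout. In a real app you would call train_app.add_output_files(paths) and let TrainApp handle the upload.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List


def copy_to_output_dir(paths: List[str], output_dir: str) -> None:
    """Hypothetical helper mirroring the documented add_output_files layout."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    for p in map(Path, paths):
        if p.is_file():
            # A file goes straight to the root of the artifacts directory
            shutil.copy2(p, out / p.name)
        elif p.is_dir():
            # A directory keeps its name and layout: output_dir/<dir_name>/...
            shutil.copytree(p, out / p.name, dirs_exist_ok=True)
        else:
            raise FileNotFoundError(f"Path does not exist: {p}")


# Demo: copy one file and one nested directory into a temporary output dir
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp, "src")
    (src / "logs" / "run1").mkdir(parents=True)
    (src / "logs" / "run1" / "events.txt").write_text("step=1")
    (src / "notes.txt").write_text("hello")
    out = Path(tmp, "output")
    copy_to_output_dir([str(src / "notes.txt"), str(src / "logs")], str(out))
    print(sorted(p.relative_to(out).as_posix() for p in out.rglob("*")))
```

Note that dirs_exist_ok requires Python 3.8 or newer.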

create_model_meta(task_type)

Create a model meta for the trained model according to task type.

This method takes the input project’s ProjectMeta, removes classes that are not selected in the GUI, and then converts the meta to a CV-task-specific format (detection/segmentation).

Parameters:
task_type : str

CV task type of the trained model.

Returns:

Model meta to be saved as model_meta.json.

Return type:

ProjectMeta
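The class-filtering step can be sketched on plain JSON. This assumes a ProjectMeta-style JSON dict in which "classes" is a list of objects with a "title" key; filter_meta_classes is a hypothetical helper, and the detection/segmentation conversion that create_model_meta also performs is deliberately left out.

```python
from typing import Any, Dict, List


def filter_meta_classes(meta_json: Dict[str, Any], selected: List[str]) -> Dict[str, Any]:
    """Hypothetical sketch: keep only the classes selected in the GUI.

    The task-specific conversion step is intentionally omitted.
    """
    keep = set(selected)
    filtered = dict(meta_json)  # shallow copy, so the input dict is not modified
    filtered["classes"] = [c for c in meta_json.get("classes", []) if c["title"] in keep]
    return filtered


meta = {
    "classes": [
        {"title": "apple", "shape": "rectangle"},
        {"title": "pear", "shape": "polygon"},
    ],
    "tags": [],
}
print([c["title"] for c in filter_meta_classes(meta, ["apple"])["classes"]])  # ['apple']
```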

get_app_state(experiment_info=None)

Returns the current state of the application.

Parameters:
experiment_info : dict, optional

Experiment info. Defaults to None.

Returns:

Application state.

Return type:

dict

load_app_state(app_state)

Load the GUI state from app state dictionary.

Parameters:
app_state : Union[str, dict]

The state dictionary or path to the state file.

Example:

# Hyperparameters are passed as a raw YAML string
with open("hyperparameters.yaml") as f:
    hyperparameters = f.read()

app_state = {
    "train_val_split": {"method": "random", "split": "train", "percent": 90},
    "classes": ["apple"],
    # For Pretrained model
    "model": {
        "source": "Pretrained models",
        "model_name": "rtdetr_r50vd_coco_objects365",
    },
    # For Custom model
    # "model": {"source": "Custom models", "task_id": 555, "checkpoint": "checkpoint_10.pth"},
    "hyperparameters": hyperparameters,  # YAML string
    "options": {
        "convert_class_shapes": True,
        "model_benchmark": {"enable": True, "speed_test": True},
        "cache_project": True,
        "export": {"enable": True, "ONNXRuntime": True, "TensorRT": True},
    },
    "experiment_name": "My Experiment",
}
train_app.load_app_state(app_state)

register_inference_class(inference_class, inference_settings=None)

Registers an inference class that the training application uses to run model benchmarking.

Parameters:
inference_class

Inference class to register (must inherit from Inference).

inference_settings : dict

Settings for the inference class.

start_in_thread()

Configure the app to start training automatically on server startup (in a background thread).

This is useful for programmatic runs (e.g. training triggered by API or workflow execution) where you don’t want to wait for a manual UI click.

Returns:

None

Return type:

None

start_tensorboard(log_dir, port=None)

Manually starts TensorBoard from the user’s training code. TensorBoard is started automatically if ‘train_logger’ is set to ‘tensorboard’ in the app_options.yaml file.

Parameters:
log_dir : str

Directory containing the log files.

port : int, optional

Port number for TensorBoard. Defaults to None.

stop_tensorboard()

Stops the TensorBoard server.

property auto_start : bool

If True, training starts automatically once the GUI is loaded and the training server has started.

property base_checkpoint : str

Returns the name of the base checkpoint.

property base_checkpoint_link : str

Returns the link to the base checkpoint.

property classes : list[str]

Returns the names of the classes selected for training.

Returns:

List of selected class names.

Return type:

List[str]

property device : str

Returns the selected device for training.

Returns:

Device name.

Return type:

str

property device_ids : list[int]

Returns the list of device IDs used for training in multi-GPU mode.

Returns:

List of device IDs (e.g. [0, 1] for “cuda:0” and “cuda:1”).

Return type:

List[int]

property devices : list[str]

Returns all devices used for training in multi-GPU mode.

Returns:

List of device strings (e.g. ["cuda:0", "cuda:1"]).

Return type:

List[str]
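The examples for device_ids and devices suggest a direct mapping from "cuda:<id>" strings to integer IDs. The hypothetical helper below sketches that mapping under this assumption; TrainApp already exposes both properties, so you would not normally need it.

```python
from typing import List


def device_ids_from_devices(devices: List[str]) -> List[int]:
    """Hypothetical helper: map device strings like "cuda:1" to integer IDs."""
    ids = []
    for dev in devices:
        kind, sep, index = dev.partition(":")
        # Only "cuda:<non-negative integer>" is accepted in this sketch
        if kind != "cuda" or not sep or not index.isdigit():
            raise ValueError(f"Expected a 'cuda:<id>' device string, got {dev!r}")
        ids.append(int(index))
    return ids


print(device_ids_from_devices(["cuda:0", "cuda:1"]))  # [0, 1]
```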

property export_onnx

Decorator for the user-defined function that exports the model to ONNX. It wraps the user-defined export function and handles the surrounding preparation and finalization steps of the training process.

property export_tensorrt

Decorator for the user-defined function that exports the model to TensorRT. It wraps the user-defined export function and handles the surrounding preparation and finalization steps of the training process.

property hyperparameters : dict[str, Any]

Returns the selected hyperparameters for training in dict format.

Returns:

Hyperparameters in dict format.

Return type:

Dict[str, Any]

property hyperparameters_yaml : str

Returns the selected hyperparameters for training as a raw YAML string.

Returns:

Hyperparameters as a raw string.

Return type:

str

property is_model_benchmark_enabled : bool

Checks if model benchmarking is enabled based on application options and GUI settings.

Returns:

True if model benchmarking is enabled, False otherwise.

Return type:

bool

property is_multi_gpu : bool

Returns True if multi-GPU mode is enabled.

Returns:

True if multi-GPU is enabled.

Return type:

bool

property model_info : dict

Returns the selected row from the models table as a dict.

Returns:

Model configuration dict.

Return type:

dict

property model_name : str

Returns the name of the model.

Returns:

Model name.

Return type:

str

property model_source : str

Returns whether the model is pretrained or custom.

Returns:

Model source.

Return type:

str

property num_classes : int

Returns the number of selected classes for training.

Returns:

Number of selected classes.

Return type:

int

property num_tags : int

Returns the number of selected tags for training.

property progress_bar_main : supervisely.app.widgets.sly_tqdm.sly_tqdm.Progress

Returns the main progress bar widget.

Returns:

Main progress bar widget.

Return type:

Progress

property progress_bar_secondary : supervisely.app.widgets.sly_tqdm.sly_tqdm.Progress

Returns the secondary progress bar widget.

Returns:

Secondary progress bar widget.

Return type:

Progress

property project_id : int

Returns the ID of the project.

Returns:

Project ID.

Return type:

int

property project_info : supervisely.api.project_api.ProjectInfo

Returns ProjectInfo object, which contains information about the project.

Returns:

Project info.

Return type:

ProjectInfo

property project_meta : supervisely.project.project_meta.ProjectMeta

Returns the project metadata.

Returns:

Project metadata.

Return type:

ProjectMeta

property project_name : str

Returns the name of the project.

Returns:

Project name.

Return type:

str

property start

Decorator for the user-defined training function.

The decorated function is executed when the user clicks Start training in the GUI (or when start_in_thread() is used).

The function must return an experiment_info dict with the following required keys:

  • model_name: str (name of the model used for training)

  • task_type: str (usually TaskType)

  • checkpoints: list[str] of checkpoint file paths, or a path (str) to a directory containing .pt/.pth files

  • best_checkpoint: str (file name; must be present in checkpoints)

Optional keys:

  • model_files: dict[str, str] mapping additional files that the model requires for inference (e.g. a model config)

TrainApp handles data preparation and (after your function returns) validates outputs and uploads artifacts to Team Files.
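Before returning, you may want to check the dict against the rules listed above. check_experiment_info is a hypothetical pre-flight check written for this page, not the validation TrainApp itself runs. It assumes, as the list says, that best_checkpoint is a bare file name that must match one of the checkpoint files, and that a directory value is scanned for .pt/.pth files.

```python
import os
import tempfile
from typing import Any, Dict, List, Union

# Required keys, as listed for the decorated training function
REQUIRED_KEYS = ("model_name", "task_type", "checkpoints", "best_checkpoint")


def collect_checkpoints(checkpoints: Union[str, List[str]]) -> List[str]:
    """Normalize the "checkpoints" value to a list of file paths."""
    if isinstance(checkpoints, str):
        # A directory path: keep only .pt/.pth files inside it
        return sorted(
            os.path.join(checkpoints, name)
            for name in os.listdir(checkpoints)
            if name.endswith((".pt", ".pth"))
        )
    # Already a list of file paths: keep the caller's order
    return list(checkpoints)


def check_experiment_info(info: Dict[str, Any]) -> List[str]:
    """Hypothetical pre-flight check; returns the resolved checkpoint paths."""
    missing = [key for key in REQUIRED_KEYS if key not in info]
    if missing:
        raise KeyError(f"experiment_info is missing required keys: {missing}")
    paths = collect_checkpoints(info["checkpoints"])
    # best_checkpoint is a file name, so compare against base names
    if info["best_checkpoint"] not in {os.path.basename(p) for p in paths}:
        raise ValueError(f"best_checkpoint {info['best_checkpoint']!r} is not among the checkpoints")
    return paths


# Demo: a checkpoint directory with two weight files and one unrelated file
with tempfile.TemporaryDirectory() as ckpt_dir:
    for name in ("checkpoint_10.pth", "checkpoint_best.pth", "notes.txt"):
        open(os.path.join(ckpt_dir, name), "w").close()
    info = {
        "model_name": "my-model",
        "task_type": "object detection",  # placeholder; use train_app.task_type in a real app
        "checkpoints": ckpt_dir,
        "best_checkpoint": "checkpoint_best.pth",
    }
    print([os.path.basename(p) for p in check_experiment_info(info)])
    # ['checkpoint_10.pth', 'checkpoint_best.pth']
```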

Returns:

A decorator that registers a training function.

Return type:

Callable
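Because start is a property whose value is a decorator, it is applied as @train_app.start without parentheses, as in the usage example at the top of this page. The toy class below sketches that pattern only: ToyTrainApp is hypothetical and skips the data preparation, validation and upload steps that TrainApp performs around the user function. export_onnx and export_tensorrt are documented as decorator properties in the same way.

```python
from typing import Any, Callable, Dict, Optional

TrainFunc = Callable[[], Dict[str, Any]]


class ToyTrainApp:
    """Hypothetical stand-in showing how a property can act as a decorator."""

    def __init__(self) -> None:
        self._train_func: Optional[TrainFunc] = None

    @property
    def start(self) -> Callable[[TrainFunc], TrainFunc]:
        # Reading `app.start` returns this decorator; `@app.start` then applies it
        def decorator(func: TrainFunc) -> TrainFunc:
            self._train_func = func  # register the function for later execution
            return func

        return decorator

    def run(self) -> Dict[str, Any]:
        # Stand-in for the "Start training" click; the real app does much more here
        if self._train_func is None:
            raise RuntimeError("No training function registered")
        return self._train_func()


app = ToyTrainApp()


@app.start
def train():
    return {"model_name": "toy", "best_checkpoint": "best.pth"}


print(app.run()["model_name"])  # toy
```

Returning func unchanged from the decorator keeps train callable directly, which can be handy for testing the training code outside the app.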

property tags : list[str]

Returns the selected tags for training.

property task_type : supervisely.nn.task_type.TaskType

Returns the task type of the model.

Returns:

Task type.

Return type:

TaskType

property team_id : int

Returns the ID of the team.

Returns:

Team ID.

Return type:

int

property workspace_id : int

Returns the ID of the workspace.

Returns:

Workspace ID.

Return type:

int