Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union

import torch
from pyre_extensions import none_throws
from torchtnt.framework._unit_utils import (
_find_optimizers_for_module,
_step_requires_iterator,
)

from torchtnt.framework.state import State
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
Expand Down Expand Up @@ -312,6 +312,7 @@ def on_train_epoch_end(self, state: State) -> None:
def __init__(self) -> None:
super().__init__()
self.train_progress = Progress()
self.first_train_batch: TTrainData | None = None

def on_train_start(self, state: State) -> None:
"""Hook called before training starts.
Expand All @@ -329,6 +330,14 @@ def on_train_epoch_start(self, state: State) -> None:
"""
pass

@property
def first_train_batch(self) -> TTrainData:
return none_throws(self.first_train_batch)

@first_train_batch.setter
def first_train_batch(self, data: TTrainData) -> None:
self.first_train_batch = data

@abstractmethod
# pyre-fixme[3]: Return annotation cannot be `Any`.
def train_step(self, state: State, data: TTrainData) -> Any:
Expand Down
Loading