
Commit d4d4f28

Revert "Replace deprecated pytorch methods (#1814)" (#1841)
This reverts commit 3e4da5f.
1 parent 3e4da5f · commit d4d4f28

File tree: 147 files changed, +518 -520 lines

Note: large commits have some content hidden by default, so only a subset of the changed files is shown below.

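For context, the change being reverted had swapped icefall's CUDA-specific torch.cuda.amp calls for the device-generic torch.amp API. The sketch below contrasts the two forms that appear throughout the diffs; it is a minimal illustration assuming a CUDA build of PyTorch, and use_fp16 is a placeholder flag, not icefall's params.use_fp16.

import torch

use_fp16 = True  # placeholder for params.use_fp16 in the training scripts

# Form restored by this revert: CUDA-specific helpers, available in older
# PyTorch releases but marked deprecated in recent ones.
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
with torch.cuda.amp.autocast(enabled=use_fp16):
    pass  # forward pass / loss computation would go here

# Form removed by this revert: device-generic torch.amp helpers that take an
# explicit device argument (the GradScaler variant with a device argument
# needs a fairly recent PyTorch, roughly 2.3 or newer).
scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)
with torch.amp.autocast("cuda", enabled=use_fp16):
    pass  # forward pass / loss computation would go here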

egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py (+4 -4)

@@ -67,7 +67,7 @@
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -638,7 +638,7 @@ def train_one_epoch(
 params.batch_idx_train += 1
 batch_size = len(batch["supervisions"]["text"])

-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -843,7 +843,7 @@ def remove_short_and_long_utt(c: Cut):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16)
+scaler = GradScaler(enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless2/train.py (+4 -4)

@@ -60,7 +60,7 @@
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -688,7 +688,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -888,7 +888,7 @@ def run(rank, world_size, args):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16)
+scaler = GradScaler(enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -989,7 +989,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless3/model.py (+2 -2)

@@ -184,7 +184,7 @@ def forward(
 lm = simple_lm_proj(decoder_out)
 am = simple_am_proj(encoder_out)

-with torch.amp.autocast("cuda", enabled=False):
+with torch.cuda.amp.autocast(enabled=False):
 simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
 lm=lm.float(),
 am=am.float(),

@@ -219,7 +219,7 @@ def forward(
 # prior to do_rnnt_pruning (this is an optimization for speed).
 logits = joiner(am_pruned, lm_pruned, project_input=False)

-with torch.amp.autocast("cuda", enabled=False):
+with torch.cuda.amp.autocast(enabled=False):
 pruned_loss = k2.rnnt_loss_pruned(
 logits=logits.float(),
 symbols=y_padded,
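The model.py hunks above also show the autocast(enabled=False) idiom: inside an outer mixed-precision region, autocast is disabled locally and the inputs are cast to float32 before the numerically sensitive transducer losses. A minimal sketch of that pattern, using a placeholder computation instead of the k2 loss calls:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 16, device=device)
w = torch.randn(16, 4, device=device)

with torch.cuda.amp.autocast(enabled=True):
    logits = x @ w  # runs in float16 under autocast on CUDA

    # Drop out of autocast and cast to float32 for the sensitive part,
    # mirroring the .float() casts around the k2 losses in model.py.
    with torch.cuda.amp.autocast(enabled=False):
        loss = torch.logsumexp(logits.float(), dim=-1).mean()

print(loss.dtype)  # torch.float32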

egs/aishell/ASR/pruned_transducer_stateless3/train.py (+4 -4)

@@ -79,7 +79,7 @@
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -797,7 +797,7 @@ def train_one_epoch(
 aishell = is_aishell(batch["supervisions"]["cut"][0])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1096,7 +1096,7 @@ def run(rank, world_size, args):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16)
+scaler = GradScaler(enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 # warmup = 0.0 is so that the derivs for the pruned loss stay zero
 # (i.e. are not remembered by the decaying-average in adam), because
 # we want to avoid these params being subject to shrinkage in adam.
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py (+4 -4)

@@ -74,7 +74,7 @@
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -812,7 +812,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1107,7 +1107,7 @@ def remove_short_and_long_utt(c: Cut):
 # params=params,
 # )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless7/train.py (+4 -4)

@@ -70,7 +70,7 @@
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -809,7 +809,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1107,7 +1107,7 @@ def remove_short_and_long_utt(c: Cut):
 # params=params,
 # )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py (+4 -4)

@@ -64,7 +64,7 @@
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -802,7 +802,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1102,7 +1102,7 @@ def tokenize_and_encode_text(c: Cut):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py (+4 -4)

@@ -63,7 +63,7 @@
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer_for_ncnn_export_only import Zipformer

@@ -813,7 +813,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1105,7 +1105,7 @@ def remove_short_and_long_utt(c: Cut):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1205,7 +1205,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py (+4 -4)

@@ -63,7 +63,7 @@
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -812,7 +812,7 @@ def train_one_epoch(
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1104,7 +1104,7 @@ def remove_short_and_long_utt(c: Cut):
 # params=params,
 # )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,

egs/aishell/ASR/whisper/train.py (+4 -4)

@@ -62,7 +62,7 @@
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -514,7 +514,7 @@ def compute_validation_loss(
 tot_loss = MetricsTracker()

 for batch_idx, batch in enumerate(valid_dl):
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 tokenizer=tokenizer,

@@ -608,7 +608,7 @@ def train_one_epoch(
 )

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 tokenizer=tokenizer,

@@ -812,7 +812,7 @@ def run(rank, world_size, args):
 train_dl = aishell.train_dataloaders(aishell.train_cuts())
 valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

egs/aishell/ASR/zipformer/train.py (+4 -4)

@@ -71,7 +71,7 @@
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.amp import GradScaler
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2

@@ -910,7 +910,7 @@ def save_bad_model(suffix: str = ""):
 batch_size = len(batch["supervisions"]["text"])

 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, loss_info = compute_loss(
 params=params,
 model=model,

@@ -1201,7 +1201,7 @@ def remove_short_and_long_utt(c: Cut):
 params=params,
 )

-scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
+scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
 scaler.load_state_dict(checkpoints["grad_scaler"])

@@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom(
 for criterion, cuts in batches.items():
 batch = train_dl.dataset[cuts]
 try:
-with torch.amp.autocast("cuda", enabled=params.use_fp16):
+with torch.cuda.amp.autocast(enabled=params.use_fp16):
 loss, _ = compute_loss(
 params=params,
 model=model,
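All of the training scripts touched here follow the same GradScaler pattern around the optimizer step and round-trip the scaler state through checkpoints. Below is a minimal sketch of that loop, assuming the reverted torch.cuda.amp form; the model, data, and checkpoint dict are placeholders, not icefall code.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

use_fp16 = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16, init_scale=1.0)

# Round-tripping the scaler state, mirroring
# scaler.load_state_dict(checkpoints["grad_scaler"]) in the scripts above.
checkpoint = {"grad_scaler": scaler.state_dict()}
scaler.load_state_dict(checkpoint["grad_scaler"])

for _ in range(3):
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 4, device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration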
