Skip to content

Commit a22d110

Browse files
authored
Merge pull request #757 from lizhuoq/main
val criterion cal from cpu to gpu
2 parents 498e166 + bdb0678 commit a22d110

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

exp/exp_anomaly_detection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def vali(self, vali_data, vali_loader, criterion):
5151

5252
f_dim = -1 if self.args.features == 'MS' else 0
5353
outputs = outputs[:, :, f_dim:]
54-
pred = outputs.detach().cpu()
55-
true = batch_x.detach().cpu()
54+
pred = outputs.detach()
55+
true = batch_x.detach()
5656

5757
loss = criterion(pred, true)
58-
total_loss.append(loss)
58+
total_loss.append(loss.item())
5959
total_loss = np.average(total_loss)
6060
self.model.train()
6161
return total_loss

exp/exp_classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def vali(self, vali_data, vali_loader, criterion):
5757

5858
outputs = self.model(batch_x, padding_mask, None, None)
5959

60-
pred = outputs.detach().cpu()
61-
loss = criterion(pred, label.long().squeeze().cpu())
62-
total_loss.append(loss)
60+
pred = outputs.detach()
61+
loss = criterion(pred, label.long().squeeze())
62+
total_loss.append(loss.item())
6363

6464
preds.append(outputs.detach())
6565
trues.append(label)

exp/exp_imputation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ def vali(self, vali_data, vali_loader, criterion):
6565
batch_x = batch_x[:, :, f_dim:]
6666
mask = mask[:, :, f_dim:]
6767

68-
pred = outputs.detach().cpu()
69-
true = batch_x.detach().cpu()
70-
mask = mask.detach().cpu()
68+
pred = outputs.detach()
69+
true = batch_x.detach()
70+
mask = mask.detach()
7171

7272
loss = criterion(pred[mask == 0], true[mask == 0])
73-
total_loss.append(loss)
73+
total_loss.append(loss.item())
7474
total_loss = np.average(total_loss)
7575
self.model.train()
7676
return total_loss

exp/exp_long_term_forecasting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def vali(self, vali_data, vali_loader, criterion):
6363
outputs = outputs[:, -self.args.pred_len:, f_dim:]
6464
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
6565

66-
pred = outputs.detach().cpu()
67-
true = batch_y.detach().cpu()
66+
pred = outputs.detach()
67+
true = batch_y.detach()
6868

6969
loss = criterion(pred, true)
7070

71-
total_loss.append(loss)
71+
total_loss.append(loss.item())
7272
total_loss = np.average(total_loss)
7373
self.model.train()
7474
return total_loss

0 commit comments

Comments
 (0)