-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Open
Description
去掉.detach().cpu(),将criterion的计算放在GPU上:
原始:
def vali(self, vali_data, vali_loader, criterion):
total_loss = []
self.model.eval()
with torch.no_grad():
for i, (batch_x, _) in enumerate(vali_loader):
batch_x = batch_x.float().to(self.device)
outputs = self.model(batch_x, None, None, None)
f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, :, f_dim:]
pred = outputs.detach().cpu()
true = batch_x.detach().cpu()
loss = criterion(pred, true)
total_loss.append(loss)
total_loss = np.average(total_loss)
self.model.train()
return total_loss
改为:
def vali(self, vali_data, vali_loader, criterion):
.....
pred = outputs
true = batch_x
loss = criterion(pred, true)
total_loss.append(loss.cpu())
total_loss = np.average(total_loss)
self.model.train()
return total_loss
Metadata
Metadata
Assignees
Labels
No labels