Transfer skorch models from CUDA to CPU for inference #1096

Open
lukethomrichardson opened this issue Feb 20, 2025 · 7 comments

@lukethomrichardson

lukethomrichardson commented Feb 20, 2025

Hi all, I am training skorch models locally in a CUDA-enabled torch environment and, if possible, I would like to transfer the models entirely to CPU so that they can be registered and used for inference in a CPU-only environment. Is there a best method for accomplishing this?

I'm pretty new to skorch and deep learning, so I'm not sure if this is even possible, but if so, a skorch helper method for converting a model to CPU would be a nice-to-have feature.

Edit: I just noticed a very similar (old) issue that is still open at the time of posting (#553). The conversation there didn't seem to fully resolve. Let me know if I should post there or if reviving the topic here is preferable.

@BenjaminBossan
Collaborator

If you train a model on GPU, save it, then load it on a machine without GPU, it should already work and be automatically transferred to CPU. Please give this a try and tell us if you encounter problems.

The thread you cited is a bit different, as it is about changing the device within the same process.

@lukethomrichardson
Author

lukethomrichardson commented Mar 20, 2025

@BenjaminBossan Thanks for the input. For testing purposes, I am saving the model in a CUDA-enabled torch environment and attempting to load it in a CPU-only torch environment. Let me know if there is any other information I can provide.
skorch package: skorch==1.1.0
CUDA torch environment package: torch==2.5.1+cu124
CPU torch environment package: torch==2.5.1

Below is my code:

Saving

Train model in CUDA environment

import joblib
from sklearn.pipeline import make_pipeline

p = make_pipeline(*trainer.pre_X_preprocessor, *trainer.X_preprocessor, trainer.model)
joblib.dump(p, "skorch_cuda_to_cpu_model.joblib")

trainer.model is an instance of skorch NeuralNetClassifier. trainer.pre_X_preprocessor and trainer.X_preprocessor are sklearn pipelines.

Loading

Switch to CPU environment

import joblib
p = joblib.load("skorch_cuda_to_cpu_model.joblib")

Error Traceback

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 4
      1 import joblib
      2 # p = make_pipeline(*trainer.pre_X_preprocessor, *trainer.X_preprocessor, trainer.model)
      3 # joblib.dump(p, "skorch_cuda_to_cpu_model.joblib")
----> 4 p = joblib.load("skorch_cuda_to_cpu_model.joblib")
      5 p.predict(dfc)

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\joblib\numpy_pickle.py:658, in load(filename, mmap_mode)
    652             if isinstance(fobj, str):
    653                 # if the returned file object is a string, this means we
    654                 # try to load a pickle file generated with an version of
    655                 # Joblib so we load it with joblib compatibility function.
    656                 return load_compatibility(fobj)
--> 658             obj = _unpickle(fobj, filename, mmap_mode)
    659 return obj

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\joblib\numpy_pickle.py:577, in _unpickle(fobj, filename, mmap_mode)
    575 obj = None
    576 try:
--> 577     obj = unpickler.load()
    578     if unpickler.compat_mode:
    579         warnings.warn("The file '%s' has been generated with a "
    580                       "joblib version less than 0.10. "
    581                       "Please regenerate this pickle file."
    582                       % filename,
    583                       DeprecationWarning, stacklevel=3)

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\pickle.py:1213, in _Unpickler.load(self)
   1211             raise EOFError
   1212         assert isinstance(key, bytes_types)
-> 1213         dispatch[key[0]](self)
   1214 except _Stop as stopinst:
   1215     return stopinst.value

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\pickle.py:1590, in _Unpickler.load_reduce(self)
   1588 args = stack.pop()
   1589 func = stack[-1]
-> 1590 stack[-1] = func(*args)

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\storage.py:520, in _load_from_bytes(b)
    519 def _load_from_bytes(b):
--> 520     return torch.load(io.BytesIO(b), weights_only=False)

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:1384, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1382     except pickle.UnpicklingError as e:
   1383         raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1384 return _legacy_load(
   1385     opened_file, map_location, pickle_module, **pickle_load_args
   1386 )

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:1638, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1636 unpickler = UnpicklerWrapper(f, **pickle_load_args)
   1637 unpickler.persistent_load = persistent_load
-> 1638 result = unpickler.load()
   1640 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
   1642 if torch._guards.active_fake_mode() is None:

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:1566, in _legacy_load.<locals>.persistent_load(saved_id)
   1564     obj = cast(Storage, torch.UntypedStorage(nbytes))
   1565     obj._torch_load_uninitialized = True
-> 1566     obj = restore_location(obj, location)
   1567 # TODO: Once we decide to break serialization FC, we can
   1568 # stop wrapping with TypedStorage
   1569 typed_storage = torch.storage.TypedStorage(
   1570     wrap_storage=obj, dtype=dtype, _internal=True
   1571 )

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:601, in default_restore_location(storage, location)
    581 """
    582 Restores `storage` using a deserializer function registered for the `location`.
    583 
   (...)
    598        all matching ones return `None`.
    599 """
    600 for _, _, fn in _package_registry:
--> 601     result = fn(storage, location)
    602     if result is not None:
    603         return result

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:539, in _deserialize(backend_name, obj, location)
    537     backend_name = torch._C._get_privateuse1_backend_name()
    538 if location.startswith(backend_name):
--> 539     device = _validate_device(location, backend_name)
    540     return obj.to(device=device)

File c:\Users\luker\anaconda3\envs\aml_torch_cpu\lib\site-packages\torch\serialization.py:508, in _validate_device(location, backend_name)
    506     device_index = device.index if device.index else 0
    507 if hasattr(device_module, "is_available") and not device_module.is_available():
--> 508     raise RuntimeError(
    509         f"Attempting to deserialize object on a {backend_name.upper()} "
    510         f"device but torch.{backend_name}.is_available() is False. "
    511         "If you are running on a CPU-only machine, "
    512         "please use torch.load with map_location=torch.device('cpu') "
    513         "to map your storages to the CPU."
    514     )
    515 if hasattr(device_module, "device_count"):
    516     device_count = device_module.device_count()

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Edit: I also tested with standard pickle as shown in the documentation and ran into the same error.
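
For what it's worth, the map_location workaround that the error message points to can also be applied at the pickle level when loading a file saved with plain pickle. Below is a generic sketch, not anything provided by skorch or torch; the CpuUnpickler class is purely illustrative and relies on the fact that the traceback above shows the storages being restored through torch.storage._load_from_bytes:

import io
import pickle

import torch


class CpuUnpickler(pickle.Unpickler):
    """Redirect pickled torch storages to CPU while unpickling."""

    def find_class(self, module, name):
        # Tensors in this pickle are restored via torch.storage._load_from_bytes,
        # which calls torch.load without a map_location; intercept that call.
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        return super().find_class(module, name)


with open("skorch_cuda_to_cpu_model.pkl", "rb") as f:
    net = CpuUnpickler(f).load()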

@BenjaminBossan
Collaborator

Hmm, I can't reproduce this. Here is the script that I used:

import pickle
import sys

import numpy as np
import torch
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier


path = "/tmp/model.pkl"


class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X


def get_data():
    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)
    return X, y


def save():  # with cuda
    assert torch.cuda.is_available()

    X, y = get_data()

    net = NeuralNetClassifier(
        MyModule,
        max_epochs=10,
        lr=0.1,
        # Shuffle training data on each epoch
        iterator_train__shuffle=True,
        device="cuda",
    )
    net.fit(X, y)

    with open(path, "wb") as f:
        pickle.dump(net, f)


def load():  # without cuda
    assert not torch.cuda.is_available()

    with open(path, "rb") as f:
        net = pickle.load(f)

    X, y = get_data()
    y_pred = net.predict(X)
    print(f"accuracy: {(y==y_pred).mean()}")


if __name__ == "__main__":
    if sys.argv[1] == "save":
        save()
    elif sys.argv[1] == "load":
        load()
    else:
        raise ValueError

The script can be called as python script.py save to save a model that is trained with CUDA. Next, call CUDA_VISIBLE_DEVICES='' python script.py load (or transfer the file to a machine without a GPU) and it should load the model on CPU. For me this works, and I get these warnings, as expected:

/home/name/work/skorch/skorch/utils.py:569: DeviceWarning: Requested to load data to CUDA but no CUDA devices are available. Loading on device "cpu" instead.
  warnings.warn(
/home/name/work/skorch/skorch/net.py:2586: DeviceWarning: Setting self.device = cpu since the requested device (cuda) is not available.
  warnings.warn(

Could you please check whether you can reproduce this? If so, there must be something else going on in your script.

Note that you can also try the approach described here to save, for instance, only the model weights, so basically what you would do when using torch.save(state_dict, f) and state_dict = torch.load(f).
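
For example, a minimal sketch of that parameters-only route, assuming the MyModule class and get_data function from the script above (the file name weights.pt is just an illustration):

from skorch import NeuralNetClassifier

# On the CUDA machine, after net.fit(X, y): store only the learned weights.
net.save_params(f_params="weights.pt")

# On the CPU machine: rebuild the net around the same module definition,
# initialize it, then load the weights; skorch maps them to the available device.
net_cpu = NeuralNetClassifier(MyModule, device="cpu")
net_cpu.initialize()
net_cpu.load_params(f_params="weights.pt")
y_pred = net_cpu.predict(get_data()[0])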

@lukethomrichardson
Author

lukethomrichardson commented Mar 22, 2025

@BenjaminBossan Thanks for the test script -- that worked, and I got the warnings as expected. Below is a distilled test script that produces the error on my end. I'm guessing it has something to do with the torch modules I'm using? Other than those, I'm not sure what might be causing the issue.

import pickle
from skorch import NeuralNetClassifier
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys

class ResNet1D(nn.Module):
    """
    ResNet architecture for 1D spectral classification
    """

    def __init__(self, block, layers, num_classes=2, input_channels=1, reduction=16):
        super(ResNet1D, self).__init__()
        self.in_channels = 64

        # Initial convolution layer
        self.conv1 = nn.Conv1d(
            input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # ResNet layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, reduction=reduction
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, reduction=reduction
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, reduction=reduction
        )

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, channels, blocks, stride=1, reduction=16):
        downsample = None

        # Downsample if stride is not 1 or input/output channels differ
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(
                    self.in_channels,
                    channels * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm1d(channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, channels, stride, downsample, reduction))
        self.in_channels = channels * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_channels, channels, reduction=reduction))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

class SqueezeExcitationBlock(nn.Module):
    """
    Squeeze-and-Excitation (SE) Block for channel-wise attention

    Parameters:
    -----------
    channel : int
        Number of input channels
    reduction : int, optional (default=16)
        Reduction ratio for the bottleneck
    """

    def __init__(self, channel, reduction=16):
        super(SqueezeExcitationBlock, self).__init__()

        # Squeeze operation (Global Average Pooling)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Excitation operation (Channel attention mechanism)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Input shape: [batch_size, channels, length]
        batch_size, channels, _ = x.size()

        # Squeeze - Global Average Pooling
        y = self.avg_pool(x).view(batch_size, channels)

        # Excitation - Channel attention
        y = self.fc(y).view(batch_size, channels, 1)

        # Scale input features
        return x * y.expand_as(x)


class BasicResidualBlock(nn.Module):
    """
    Basic residual block for ResNet architecture
    Supports 1D convolutions for spectral data
    """

    expansion = 1

    def __init__(
        self, in_channels, out_channels, stride=1, downsample=None, reduction=16
    ):
        super(BasicResidualBlock, self).__init__()
        # First convolution layer
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)

        # Second convolution layer
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Squeeze-Excitation block
        self.se = SqueezeExcitationBlock(out_channels, reduction)

        # Downsampling layer for matching dimensions
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        # First conv path
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Apply Squeeze-Excitation
        out = self.se(out)

        # Downsampling if needed
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection
        out += identity
        out = F.relu(out)
        return out
    
file = "skorch_cuda_to_cpu_model.pkl"

from sklearn.datasets import make_classification
def get_data():
    X, y = make_classification(1000, 534, n_informative=500, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)
    return X, y

def save():
    assert torch.cuda.is_available()

    MyModule = ResNet1D(
        BasicResidualBlock,
        layers=[2,2,2,2],
        num_classes=2,
        input_channels=1,
        )
    
    net = NeuralNetClassifier(
        MyModule,
        max_epochs=5,
        lr=0.1,
        iterator_train__shuffle=True,
        device="cuda",
        criterion=torch.nn.CrossEntropyLoss,)
    
    X, y = get_data()
    X = X.reshape(X.shape[0], 1, X.shape[1])
    net.fit(X, y)
    
    with open(file, "wb") as f:
        pickle.dump(net, f)
    
        
def load():
    assert not torch.cuda.is_available()
    
    with open(file, "rb") as f:
        net = pickle.load(f)
    print(net.device)
    
    
if __name__ == "__main__":
    if sys.argv[1] == "save":
        save()
    elif sys.argv[1] == "load":
        load()
    else:
        raise ValueError

@BenjaminBossan
Collaborator

Thanks for the reproducer. I can confirm that this fails when loading the model on a CPU machine. I debugged a little and it appears that, when loading, torch already raises the error about missing CUDA support before the skorch code even runs, which means the skorch logic for loading on CPU is not applied in time.

I'm not sure why your code triggers this when my example doesn't. However, there are still workarounds. One is to move the whole model to CPU before pickling it, i.e. calling net.module.cpu() beforehand. If you have enough memory to do this, that would be the easiest solution. Otherwise, you can try the save_params/load_params approach I described above.
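
A compact sketch of the first workaround, assuming net is the fitted NeuralNetClassifier (after fitting, the initialized module is also available as net.module_):

import pickle

# Move the fitted module's parameters and buffers to CPU before pickling,
# so the pickled module no longer contains CUDA storages.
net.module_.cpu()

with open("skorch_cuda_to_cpu_model.pkl", "wb") as f:
    pickle.dump(net, f)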

@lukethomrichardson
Author

lukethomrichardson commented Mar 24, 2025

@BenjaminBossan That fixed the issue for the reproducer test script. I went on to test the workaround in my development environment, and I still encountered the error. After a bit of debugging, I found that the error seems to be caused by the combination of the LRScheduler callback and certain optimizers. I tested a few different LRScheduler policies, and that didn't seem to affect the issue. I tested SGD, Adam, Adamax, AdamW, Adafactor, Adagrad, and LBFGS, and the only model to load properly in a CPU torch environment was the one trained with SGD. Interestingly, the models loaded successfully in a CPU torch environment regardless of the optimizer if the LRScheduler callback was removed. Below is a script that produces the error with the Adam optimizer.
Assuming I can get by with the SGD optimizer without any trouble, my issue is tentatively resolved, but I just wanted to put this on your radar.

import pickle
from skorch import NeuralNetClassifier
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
from skorch.callbacks import LRScheduler, EarlyStopping

class ResNet1D(nn.Module):
    """
    ResNet architecture for 1D spectral classification
    """

    def __init__(self, block, layers, num_classes=2, input_channels=1, reduction=16):
        super(ResNet1D, self).__init__()
        # num_classes = 1 if num_classes == "binary" else num_classes
        self.in_channels = 64

        # Initial convolution layer
        self.conv1 = nn.Conv1d(
            input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # ResNet layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, reduction=reduction
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, reduction=reduction
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, reduction=reduction
        )

        # Global average pooling and fully connected layer
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, channels, blocks, stride=1, reduction=16):
        downsample = None

        # Downsample if stride is not 1 or input/output channels differ
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(
                    self.in_channels,
                    channels * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm1d(channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, channels, stride, downsample, reduction))
        self.in_channels = channels * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_channels, channels, reduction=reduction))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

class SqueezeExcitationBlock(nn.Module):
    """
    Squeeze-and-Excitation (SE) Block for channel-wise attention

    Parameters:
    -----------
    channel : int
        Number of input channels
    reduction : int, optional (default=16)
        Reduction ratio for the bottleneck
    """

    def __init__(self, channel, reduction=16):
        super(SqueezeExcitationBlock, self).__init__()

        # Squeeze operation (Global Average Pooling)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # Excitation operation (Channel attention mechanism)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Input shape: [batch_size, channels, length]
        batch_size, channels, _ = x.size()

        # Squeeze - Global Average Pooling
        y = self.avg_pool(x).view(batch_size, channels)

        # Excitation - Channel attention
        y = self.fc(y).view(batch_size, channels, 1)

        # Scale input features
        return x * y.expand_as(x)

class BasicResidualBlock(nn.Module):
    """
    Basic residual block for ResNet architecture
    Supports 1D convolutions for spectral data
    """

    expansion = 1

    def __init__(
        self, in_channels, out_channels, stride=1, downsample=None, reduction=16
    ):
        super(BasicResidualBlock, self).__init__()
        # First convolution layer
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)

        # Second convolution layer
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Squeeze-Excitation block
        self.se = SqueezeExcitationBlock(out_channels, reduction)

        # Downsampling layer for matching dimensions
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        # First conv path
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Apply Squeeze-Excitation
        out = self.se(out)

        # Downsampling if needed
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection
        out += identity
        out = F.relu(out)
        return out
    
file = "skorch_cuda_to_cpu_model.pkl"

from sklearn.datasets import make_classification
def get_data():
    X, y = make_classification(1000, 534, n_informative=500, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)
    return X, y

def save():
    assert torch.cuda.is_available()
    MyModule = ResNet1D(
        BasicResidualBlock,
        layers=[2,2,2,2],
        num_classes=2,
        input_channels=1,
        )
    lr_scheduler = LRScheduler(
        policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
        monitor='valid_loss',
        patience=10,
        factor=0.9,
        min_lr=1e-8,
    )
    net = NeuralNetClassifier(
        MyModule,
        max_epochs=5,
        lr=0.1,
        iterator_train__shuffle=True,
        device="cuda",
        criterion=torch.nn.CrossEntropyLoss,
        callbacks=[lr_scheduler],
        optimizer=torch.optim.Adam,
    )
    
    X, y = get_data()
    X = X.reshape(X.shape[0], 1, X.shape[1])
    net.fit(X, y)
    net.module.cpu()
    
    with open(file, "wb") as f:
        pickle.dump(net, f)
    
        
def load():
    assert not torch.cuda.is_available()
    
    with open(file, "rb") as f:
        net = pickle.load(f)
    print(net.device)
    
    
if __name__ == "__main__":
    if sys.argv[1] == "save":
        save()
    elif sys.argv[1] == "load":
        load()
    else:
        raise ValueError

@BenjaminBossan
Collaborator

Thanks for providing further information. Without digging deeper: when pickling, skorch checks for attributes with a CUDA dependency, pops them from the pickle state, and saves them in a way that allows us to load them later without CUDA. Optimizers such as Adam keep per-parameter state (running statistics of the gradients), which are tensors that could be on CUDA. This is fine on its own, since we treat the optimizer as a CUDA-dependent attribute. However, I suspect that the learning rate scheduler holds a reference to the optimizer. Therefore, when the whole net, and thus the learning rate scheduler, is pickled, we still retain a reference to the CUDA tensors from the optimizer state.
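
A hypothetical way to check this suspicion on a fitted net, assuming the LRScheduler callback exposes the instantiated torch scheduler as lr_scheduler_, which in turn keeps a reference to the optimizer:

from skorch.callbacks import LRScheduler

# If this prints True, pickling the net also drags in the optimizer (and its
# possibly-CUDA state) through the scheduler, even though the optimizer itself
# is handled separately by skorch.
for name, cb in net.callbacks_:
    if isinstance(cb, LRScheduler):
        print(cb.lr_scheduler_.optimizer is net.optimizer_)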

As for workarounds: as you mentioned, you could use an optimizer without internal state, such as SGD, though this could lead to lower performance. If you don't need to continue training after the fact, you don't actually need the learning rate scheduler and can remove it from net.callbacks_ before pickling. This should get rid of the reference (maybe call gc.collect() to be absolutely sure).

If, after training, you only need the model for inference, skorch also provides a method to get rid of all attributes that are not needed for inference: https://skorch.readthedocs.io/en/stable/net.html#skorch.net.NeuralNet.trim_for_prediction.
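
A minimal sketch combining both suggestions, assuming net is the fitted NeuralNetClassifier from the script above:

import gc
import pickle

from skorch.callbacks import LRScheduler

# Option 1: drop the scheduler callback so nothing in the pickled net keeps a
# reference to the optimizer state.
net.callbacks_ = [
    (name, cb) for name, cb in net.callbacks_ if not isinstance(cb, LRScheduler)
]
gc.collect()

# Option 2: if the net is only needed for inference, strip everything that is
# not required for prediction (callbacks, optimizer, etc.).
net.trim_for_prediction()

with open("skorch_cuda_to_cpu_model.pkl", "wb") as f:
    pickle.dump(net, f)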
