Look up and instantiate classes with style.
from dataclasses import dataclass

from class_resolver import ClassResolver


class Base:
    """Common base class shared by every class the resolver manages."""


@dataclass
class A(Base):
    name: str


@dataclass
class B(Base):
    name: str


# Build an index of the resolvable classes, keyed by name
resolver = ClassResolver([A, B], base=Base)

# Look a class up by its name
assert resolver.lookup('A') == A

# Instantiate from a dictionary of keyword arguments ...
assert resolver.make('A', {'name': 'hi'}) == A(name='hi')
# ... or directly from keyword arguments
assert resolver.make('A', name='hi') == A(name='hi')

# Something that is already an instance is simply passed through
assert resolver.make(A(name='hi')) == A(name='hi')
Assume you've implemented a simple multi-layer perceptron in PyTorch:
from itertools import chain
from more_itertools import pairwise
from torch import nn
class MLP(nn.Sequential):
def __init__(self, dims: list[int]):
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
nn.ReLU(),
)
for in_features, out_features in pairwise(dims)
))
This MLP uses a hard-coded rectified linear unit as the non-linear activation
function between layers. We can generalize this MLP to use a variety of
non-linear activation functions by adding an argument to its __init__()
function, as in:
from itertools import chain
from more_itertools import pairwise
from torch import nn
class MLP(nn.Sequential):
def __init__(self, dims: list[int], activation: str = "relu"):
if activation == "relu":
activation = nn.ReLU()
elif activation == "tanh":
activation = nn.Tanh()
elif activation == "hardtanh":
activation = nn.Hardtanh()
else:
raise KeyError(f"Unsupported activation: {activation}")
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
The first issue with this implementation is that it relies on a hard-coded set of conditional statements and is therefore hard to extend. It can be improved by using a dictionary lookup:
from itertools import chain
from more_itertools import pairwise
from torch import nn
activation_lookup: dict[str, nn.Module] = {
"relu": nn.ReLU(),
"tanh": nn.Tanh(),
"hardtanh": nn.Hardtanh(),
}
class MLP(nn.Sequential):
def __init__(self, dims: list[int], activation: str = "relu"):
activation = activation_lookup[activation]
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
This approach is rigid because it requires pre-instantiation of the activations.
If we needed to vary the arguments to the nn.Hardtanh class, the previous
approach wouldn't work. We can change the implementation to look up the class
before instantiation, then optionally pass some arguments:
from itertools import chain
from more_itertools import pairwise
from torch import nn
activation_lookup: dict[str, type[nn.Module]] = {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"hardtanh": nn.Hardtanh,
}
class MLP(nn.Sequential):
def __init__(
self,
dims: list[int],
activation: str = "relu",
activation_kwargs: None | dict[str, any] = None,
):
activation_cls = activation_lookup[activation]
activation = activation_cls(**(activation_kwargs or {}))
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
This is pretty good, but it still has a few issues:
- you have to manually maintain the `activation_lookup` dictionary,
- you can't pass an instance or class through the `activation` keyword,
- you have to get the casing just right,
- the default is hard-coded as a string, which means this has to get copied (error-prone) in any place that creates an MLP,
- you have to re-write this logic for all of your classes.
Enter the `class_resolver` package, which takes care of all of these things
using the following:
from itertools import chain, pairwise
from typing import Any

from class_resolver import ClassResolver, Hint
from torch import nn

activation_resolver = ClassResolver(
    [nn.ReLU, nn.Tanh, nn.Hardtanh],
    base=nn.Module,
    default=nn.ReLU,
)


class MLP(nn.Sequential):
    """A multi-layer perceptron whose activation is resolved by ``activation_resolver``."""

    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,  # Hint = Union[None, str, nn.Module, type[nn.Module]]
        activation_kwargs: dict[str, Any] | None = None,
    ):
        """Build the network.

        :param dims: The layer widths, e.g. ``[10, 200, 40]``.
        :param activation: A name, class, or instance of the activation.
            ``None`` uses the resolver's default.
        :param activation_kwargs: Keyword arguments used when the resolver
            has to instantiate the activation.
        """
        # ``make`` is called once per layer inside the generator. nn.Sequential
        # expects modules as positional arguments, hence ``*``.
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))
Because this is such a common pattern, we've made it available through the
contrib module in `class_resolver.contrib.torch`:
from itertools import chain, pairwise
from typing import Any

from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from torch import nn


class MLP(nn.Sequential):
    """A multi-layer perceptron using the pre-built contrib ``activation_resolver``."""

    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,
        activation_kwargs: dict[str, Any] | None = None,
    ):
        """Build the network.

        :param dims: The layer widths, e.g. ``[10, 200, 40]``.
        :param activation: A name, class, or instance of the activation.
            ``None`` uses the resolver's default.
        :param activation_kwargs: Keyword arguments used when the resolver
            has to instantiate the activation.
        """
        # nn.Sequential expects modules as positional arguments, hence ``*``.
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))
Now, you can instantiate the MLP with any of the following:
# Assumes the MLP defined above and ``from torch import nn``
MLP(dims=[10, 200, 40])  # uses default, which is ReLU
MLP(dims=[10, 200, 40], activation="relu")  # uses lowercase
MLP(dims=[10, 200, 40], activation="ReLU")  # uses stylized
MLP(dims=[10, 200, 40], activation=nn.ReLU)  # uses class
MLP(dims=[10, 200, 40], activation=nn.ReLU())  # uses instance
MLP(dims=[10, 200, 40], activation="hardtanh", activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh, activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses class + kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh(0.0, 6.0))  # uses instance
In practice, it makes sense to stick to using the strings in combination with hyper-parameter optimization libraries like Optuna.
The most recent release can be installed from PyPI with uv:
$ uv pip install class_resolver
or with pip:
$ python3 -m pip install class_resolver
The most recent code and data can be installed directly from GitHub with uv:
$ uv pip install git+https://github.com/cthoyt/class-resolver.git
or with pip:
$ python3 -m pip install git+https://github.com/cthoyt/class-resolver.git
Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See CONTRIBUTING.md for more information on getting involved.
The code in this package is licensed under the MIT License.
This package was created with @audreyfeldroy's cookiecutter package using @cthoyt's cookiecutter-snekpack template.
See developer instructions
The final section of the README is for if you want to get involved by making a code contribution.
To install in development mode, use the following:
$ git clone https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ uv pip install -e .
Alternatively, install using pip:
$ python3 -m pip install -e .
This project uses cruft to keep boilerplate (i.e., configuration, contribution
guidelines, documentation configuration) up-to-date with the upstream
cookiecutter package. Install cruft with either `uv tool install cruft` or
`python3 -m pip install cruft`, then run:
$ cruft update
More info on Cruft's update command is available here.
After cloning the repository and installing tox with
`uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, the
unit tests in the `tests/` folder can be run reproducibly with:
$ tox -e py
Additionally, these tests are automatically re-run with each commit in a GitHub Action.
The documentation can be built locally using the following:
$ git clone https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ tox -e docs
$ open docs/build/html/index.html
The documentation automatically installs the package as well as the `docs`
extra specified in the `pyproject.toml`. Sphinx plugins like `texext` can be
added there. Additionally, they need to be added to the `extensions` list in
`docs/source/conf.py`.
The documentation can be deployed to ReadTheDocs using this guide. The
`.readthedocs.yml` YAML file contains all the configuration you'll need. You
can also set up continuous integration on GitHub to check not only that Sphinx
can build the documentation in an isolated environment (i.e., with
`tox -e docs-test`) but also that ReadTheDocs can build it too.
- Log in to ReadTheDocs with your GitHub account to install the integration at https://readthedocs.org/accounts/login/?next=/dashboard/
- Import your project by navigating to https://readthedocs.org/dashboard/import then clicking the plus icon next to your repository
- You can rename the repository on the next screen using a more stylized name (i.e., with spaces and capital letters)
- Click next, and you're good to go!
Zenodo is a long-term archival system that assigns a DOI to each release of your package.
- Log in to Zenodo via GitHub with this link: https://zenodo.org/oauth/login/github/?next=%2F. This brings you to a page that lists all of your organizations and asks you to approve installing the Zenodo app on GitHub. Click "grant" next to any organizations you want to enable the integration for, then click the big green "approve" button. This step only needs to be done once.
- Navigate to https://zenodo.org/account/settings/github/, which lists all of your GitHub repositories (both in your username and any organizations you enabled). Click the on/off toggle for any relevant repositories. When you make a new repository, you'll have to come back to this page.
After these steps, you're ready to go! After you make a "release" on GitHub (steps for this are below), you can navigate to https://zenodo.org/account/settings/github/repository/cthoyt/class-resolver to see the DOI for the release and link to the Zenodo record for it.
You only have to do the following steps once.
- Register for an account on the Python Package Index (PyPI)
- Navigate to https://pypi.org/manage/account and make sure you have verified your email address. A verification email might not have been sent by default, so you might have to click the "options" dropdown next to your address to get to the "re-send verification email" button
- Two-factor authentication has been required for PyPI since the end of 2023 (see this blog post from PyPI). This means you have to first issue account recovery codes, then set up two-factor authentication
- Issue an API token from https://pypi.org/manage/account/token
You have to do the following steps once per machine.
$ uv tool install keyring
$ keyring set https://upload.pypi.org/legacy/ __token__
$ keyring set https://test.pypi.org/legacy/ __token__
Note that this deprecates previous workflows using `.pypirc`.
After installing the package in development mode and installing tox with
`uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, run
the following from the console:
$ tox -e finish
This script does the following:
- Uses bump-my-version to switch the version number in the `pyproject.toml`,
  `CITATION.cff`, `src/class_resolver/version.py`, and `docs/source/conf.py`
  to not have the `-dev` suffix
- Packages the code in both a tar archive and a wheel using `uv build`
- Uploads to PyPI using `uv publish`
- Pushes to GitHub. You'll need to make a release going with the commit where the version was bumped.
- Bumps the version to the next patch. If you made big changes and want to bump
  the version by minor, you can use `tox -e bumpversion -- minor` after.
- Navigate to https://github.com/cthoyt/class-resolver/releases/new to draft a new release
- Click the "Choose a Tag" dropdown and select the tag corresponding to the release you just made
- Click the "Generate Release Notes" button to get a quick outline of recent changes. Modify the title and description as you see fit
- Click the big green "Publish Release" button
This will trigger Zenodo to assign a DOI to your release as well.