Skip to content

Commit

Permalink
fix unet config & instiantion
Browse files Browse the repository at this point in the history
  • Loading branch information
Sylvain Le Groux committed Feb 11, 2025
1 parent 3b030a1 commit e8c6fda
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 127 deletions.
22 changes: 22 additions & 0 deletions config/model/image/tinyunetx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
_target_: nimrod.models.unet.TinyUnetX
_partial_: true

nnet:
_target_: nimrod.models.unet.TinyUnet
n_features: [3, 32, 64, 128, 256, 512, 1024] # channel/feature expansion
pre_activation: true
# activation:
# _target_: torch.nn.ReLU
# _partial_: true
# leaky: 0.0

activation:
_target_: torch.nn.LeakyReLU
_partial_: true
negative_slope: 0.1


weight_initialization: true

# optimizer
# schedule
48 changes: 28 additions & 20 deletions nbs/models.unet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,29 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Seed set to 42\n",
"/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.\n",
" warnings.warn(\n",
"Seed set to 42\n",
"Seed set to 42\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
Expand All @@ -32,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -60,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -86,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -186,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -320,7 +340,7 @@
}
],
"source": [
"cfg = OmegaConf.load('../config/model/image/unetx.yaml')\n",
"cfg = OmegaConf.load('../config/model/image/tinyunetx.yaml')\n",
"model = instantiate(cfg.nnet)\n",
"x = torch.randn(1, 3, 64, 64)\n",
"model(x).shape\n"
Expand All @@ -339,21 +359,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
432 changes: 325 additions & 107 deletions tutorials/super_resolution.ipynb

Large diffs are not rendered by default.

0 comments on commit e8c6fda

Please sign in to comment.