Skip to content

Commit

Permalink
Merge pull request #122 from ayasyrev/dev
Browse files Browse the repository at this point in the history
4.2
  • Loading branch information
ayasyrev authored Jan 8, 2024
2 parents c9f17d5 + d71e24f commit 32d49d9
Show file tree
Hide file tree
Showing 17 changed files with 44 additions and 72 deletions.
2 changes: 1 addition & 1 deletion noxfile_cov.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import nox


@nox.session(python=["3.10"])
@nox.session(python=["3.11"])
def cov_tests(session: nox.Session) -> None:
args = session.posargs or ["--cov"]
session.install(".", "pytest", "pytest-cov", "coverage[toml]")
Expand Down
5 changes: 1 addition & 4 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
pytest
pytest-cov
coverage[toml]
flake8
nox
pytest-cov
5 changes: 5 additions & 0 deletions requirements_test_extra.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
coverage[toml]
black
flake8
nox
isort
7 changes: 5 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@ long_description_content_type = text/markdown
url = https://github.com/ayasyrev/model_constructor
license = apache2
classifiers =
Programming Language :: Python :: 3
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
License :: OSI Approved :: Apache Software License
Operating System :: OS Independent

[options]
package_dir =
= src
packages = find:
python_requires = >=3.7
python_requires = >=3.8, <3.12

[options.packages.find]
where = src
Expand Down
50 changes: 0 additions & 50 deletions setup_.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/model_constructor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .convmixer import ConvMixer
from .model_constructor import ModelConstructor, ModelCfg
from .model_constructor import ModelCfg, ModelConstructor
from .version import __version__
3 changes: 1 addition & 2 deletions src/model_constructor/activations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# forked from https://github.com/rwightman/pytorch-image-models/timm/models/layers/activations.py
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import Mish

from torch.nn import functional as F

__all__ = [
"mish",
Expand Down
8 changes: 6 additions & 2 deletions src/model_constructor/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,18 @@ def print_set_fields(self) -> None:
else:
print("Nothing changed")

def print_changed_fields(self, show_default: bool = False, separator: str = " | ") -> None:
def print_changed_fields(
self, show_default: bool = False, separator: str = " | "
) -> None:
"""Print fields changed at init."""
if self.changed_fields:
default_value = ""
print("Changed fields:")
for field in self.changed_fields:
if show_default:
default_value = f"{separator}{self._get_str(self.model_fields[field].default)}"
default_value = (
f"{separator}{self._get_str(self.model_fields[field].default)}"
)
print(f"{field}: {self._get_str_value(field)}{default_value}")
else:
print("Nothing changed")
4 changes: 2 additions & 2 deletions src/model_constructor/model_constructor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

from pydantic import field_validator
from pydantic_core.core_schema import FieldValidationInfo
Expand Down Expand Up @@ -32,7 +32,7 @@
}


nnModule = Union[Type[nn.Module], Callable[[], nn.Module]]
nnModule = Union[Type[nn.Module], Callable[[Any], nn.Module]]


class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
Expand Down
1 change: 1 addition & 0 deletions src/model_constructor/mxresnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Type

from torch import nn

from .xresnet import XResNet
Expand Down
2 changes: 1 addition & 1 deletion src/model_constructor/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.1"
__version__ = "0.4.2_dev"
9 changes: 7 additions & 2 deletions src/model_constructor/xresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn

from .blocks import BottleneckBlock
from .helpers import ListStrMod, nn_seq, ModSeq
from .helpers import ListStrMod, ModSeq, nn_seq
from .model_constructor import ModelCfg, ModelConstructor

__all__ = [
Expand Down Expand Up @@ -50,6 +50,11 @@ class XResNet34(XResNet):
layers: List[int] = [3, 4, 6, 3]


class XResNet50(XResNet34):
class XResNet26(XResNet):
block: Type[nn.Module] = BottleneckBlock
block_sizes: List[int] = [256, 512, 1024, 2048]
expansion: int = 4


class XResNet50(XResNet26):
layers: List[int] = [3, 4, 6, 3]
10 changes: 7 additions & 3 deletions src/model_constructor/yaresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from model_constructor.helpers import ModSeq, nn_seq

from .layers import ConvBnAct, get_act
from .model_constructor import ListStrMod, ModelConstructor, ModelCfg
from .model_constructor import ListStrMod, ModelCfg, ModelConstructor
from .xresnet import xresnet_stem


__all__ = [
"YaBasicBlock",
"YaBottleneckBlock",
Expand Down Expand Up @@ -216,6 +215,11 @@ class YaResNet34(YaResNet):
layers: List[int] = [3, 4, 6, 3]


class YaResNet50(YaResNet34):
class YaResNet26(YaResNet):
block: Type[nn.Module] = YaBottleneckBlock
block_sizes: List[int] = [256, 512, 1024, 2048]
expansion: int = 4


class YaResNet50(YaResNet26):
layers: List[int] = [3, 4, 6, 3]
2 changes: 1 addition & 1 deletion tests/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
from torch import nn

from model_constructor.layers import SEModule, SimpleSelfAttention
from model_constructor.blocks import BasicBlock, BottleneckBlock
from model_constructor.layers import SEModule, SimpleSelfAttention
from model_constructor.yaresnet import YaBasicBlock, YaBottleneckBlock

from .parameters import ids_fn
Expand Down
4 changes: 3 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def test_cfg_repr_print(capsys: CaptureFixture[str]):
cfg.print_set_fields()
out = capsys.readouterr().out
assert out == "Nothing changed\n"
assert "name" in cfg.model_fields_set # pylint: disable=E1135:unsupported-membership-test
assert (
"name" in cfg.model_fields_set
) # pylint: disable=E1135:unsupported-membership-test
cfg = Cfg2(int_value=0)
cfg.print_set_fields()
out = capsys.readouterr().out
Expand Down
1 change: 1 addition & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Type

import pytest
import torch
from torch import nn
Expand Down
1 change: 1 addition & 0 deletions tests/test_models_universal_blocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Type

import pytest
import torch
from torch import nn
Expand Down

0 comments on commit 32d49d9

Please sign in to comment.