Commit 252ad94

Commit message: start
1 parent e2aba9c

6 files changed (+128 -39 lines)


files_to_sync.txt

Whitespace-only changes.

minitorch/module.py

+17 -16

@@ -4,11 +4,11 @@


 class Module:
-    """
-    Modules form a tree that store parameters and other
+    """Modules form a tree that store parameters and other
     submodules. They make up the basis of neural network stacks.

-    Attributes:
+    Attributes
+    ----------
     _modules : Storage of the child modules
     _parameters : Storage of the module's parameters
     training : Whether the module is in training mode or evaluation mode
@@ -25,46 +25,48 @@ def __init__(self) -> None:
         self.training = True

     def modules(self) -> Sequence[Module]:
-        "Return the direct child modules of this module."
+        """Return the direct child modules of this module."""
         m: Dict[str, Module] = self.__dict__["_modules"]
         return list(m.values())

     def train(self) -> None:
-        "Set the mode of this module and all descendent modules to `train`."
+        """Set the mode of this module and all descendent modules to `train`."""
         # TODO: Implement for Task 0.4.
         raise NotImplementedError("Need to implement for Task 0.4")

     def eval(self) -> None:
-        "Set the mode of this module and all descendent modules to `eval`."
+        """Set the mode of this module and all descendent modules to `eval`."""
         # TODO: Implement for Task 0.4.
         raise NotImplementedError("Need to implement for Task 0.4")

     def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
-        """
-        Collect all the parameters of this module and its descendents.
+        """Collect all the parameters of this module and its descendents.

-
-        Returns:
+        Returns
+        -------
         The name and `Parameter` of each ancestor parameter.
+
         """
         # TODO: Implement for Task 0.4.
         raise NotImplementedError("Need to implement for Task 0.4")

     def parameters(self) -> Sequence[Parameter]:
-        "Enumerate over all the parameters of this module and its descendents."
+        """Enumerate over all the parameters of this module and its descendents."""
         # TODO: Implement for Task 0.4.
         raise NotImplementedError("Need to implement for Task 0.4")

     def add_parameter(self, k: str, v: Any) -> Parameter:
-        """
-        Manually add a parameter. Useful helper for scalar parameters.
+        """Manually add a parameter. Useful helper for scalar parameters.

         Args:
+        ----
            k: Local name of the parameter.
            v: Value for the parameter.

         Returns:
+        -------
            Newly created parameter.
+
         """
         val = Parameter(v, k)
         self.__dict__["_parameters"][k] = val
@@ -118,8 +120,7 @@ def _addindent(s_: str, numSpaces: int) -> str:


 class Parameter:
-    """
-    A Parameter is a special container stored in a `Module`.
+    """A Parameter is a special container stored in a `Module`.

     It is designed to hold a `Variable`, but we allow it to hold
     any value for testing.
@@ -134,7 +135,7 @@ def __init__(self, x: Any, name: Optional[str] = None) -> None:
         self.value.name = self.name

     def update(self, x: Any) -> None:
-        "Update the parameter value."
+        """Update the parameter value."""
         self.value = x
         if hasattr(x, "requires_grad_"):
             self.value.requires_grad_(True)
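The `train`, `eval`, and `named_parameters` docstrings above describe a recursive walk over the module tree that Task 0.4 leaves as `NotImplementedError`. A minimal sketch of one possible implementation, using the `_modules` and `_parameters` dictionaries visible in the diff; the dot-separated naming of child parameters is an assumption here, and the course's tests are the authoritative spec:

def train(self) -> None:
    """Set the mode of this module and all descendent modules to `train`."""
    self.training = True
    for m in self.modules():
        m.train()

def eval(self) -> None:
    """Set the mode of this module and all descendent modules to `eval`."""
    self.training = False
    for m in self.modules():
        m.eval()

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
    """Collect all the parameters of this module and its descendents."""
    # Direct parameters first, then recurse into children, prefixing
    # each child's parameter names with the child's local name.
    out = list(self.__dict__["_parameters"].items())
    for name, mod in self.__dict__["_modules"].items():
        out += [(f"{name}.{k}", p) for k, p in mod.named_parameters()]
    return out

Prefixing recursively is what would let `dict(mod.named_parameters())` in `tests/test_module.py` look parameters up by a qualified name.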

minitorch/operators.py

+1 -3

@@ -1,6 +1,4 @@
-"""
-Collection of the core mathematical operators used throughout the code base.
-"""
+"""Collection of the core mathematical operators used throughout the code base."""

 import math

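Only the module docstring changes here, but the operators this file exports are exercised below in `tests/test_operators.py` (`assert_close(mul(x, y), x * y)`, `assert lt(a - 1.0, a) == 1.0`, and so on). A minimal sketch of how those scalar operators are conventionally written, with behavior inferred from the test assertions rather than from this commit:

def mul(x: float, y: float) -> float:
    """Multiply two numbers."""
    return x * y

def add(x: float, y: float) -> float:
    """Add two numbers."""
    return x + y

def neg(x: float) -> float:
    """Negate a number."""
    return -x

def lt(x: float, y: float) -> float:
    """Return 1.0 if x is less than y, else 0.0."""
    return 1.0 if x < y else 0.0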

pyproject.toml

+99 -5

@@ -7,8 +7,16 @@ name = "minitorch"
 version = "0.5"

 [tool.pyright]
-include = ["minitorch","tests"]
-exclude = ["**/docs", "**/project", "**/mt_diagrams", "**/assignments"]
+include = ["**/minitorch"]
+exclude = [
+    "**/docs",
+    "**/docs/module1/**",
+    "**/assignments",
+    "**/project",
+    "**/mt_diagrams",
+    "**/.*",
+    "*chainrule.py*",
+]
 venvPath = "."
 venv = ".venv"
 reportUnknownMemberType = "none"
@@ -21,6 +29,8 @@ reportUnusedExpression = "none"
 reportUnknownLambdaType = "none"
 reportIncompatibleMethodOverride = "none"
 reportPrivateUsage = "none"
+reportMissingParameterType = "error"
+

 [tool.pytest.ini_options]
 markers = [
@@ -50,7 +60,91 @@ markers = [
     "task4_3",
     "task4_4",
 ]
+[tool.ruff]
+
+exclude = [
+    ".git",
+    "__pycache__",
+    "**/docs/slides/*",
+    "old,build",
+    "dist",
+    "**/project/**/*",
+    "**/mt_diagrams/*",
+    "**/minitorch/testing.py",
+    "**/docs/**/*",
+]
+
+ignore = [
+    "ANN101",
+    "ANN401",
+    "N801",
+    "E203",
+    "E266",
+    "E501",
+    "E741",
+    "N803",
+    "N802",
+    "N806",
+    "D400",
+    "D401",
+    "D105",
+    "D415",
+    "D205",
+    "D100",
+    "D101",
+    "D107",
+    "D213",
+    "ANN204",
+    "ANN102",
+]
+select = ["D", "E", "F", "N", "ANN"]
+fixable = [
+    "A",
+    "B",
+    "C",
+    "D",
+    "E",
+    "F",
+    "G",
+    "I",
+    "N",
+    "Q",
+    "S",
+    "T",
+    "W",
+    "ANN",
+    "ARG",
+    "BLE",
+    "COM",
+    "DJ",
+    "DTZ",
+    "EM",
+    "ERA",
+    "EXE",
+    "FBT",
+    "ICN",
+    "INP",
+    "ISC",
+    "NPY",
+    "PD",
+    "PGH",
+    "PIE",
+    "PL",
+    "PT",
+    "PTH",
+    "PYI",
+    "RET",
+    "RSE",
+    "RUF",
+    "SIM",
+    "SLF",
+    "TCH",
+    "TID",
+    "TRY",
+    "UP",
+    "YTT",
+]
+unfixable = []

-[tool.ruff.lint]
-ignore = ["N801", "E203", "E266", "E501", "E741", "N803", "N802", "N806"]
-exclude = [".git","__pycache__","docs/slides/*","old,build","dist"]
+[tool.ruff.extend-per-file-ignores]
+"tests/**/*.py" = ["D"]

tests/test_module.py

+3 -3

@@ -44,7 +44,7 @@ def __init__(self) -> None:

 @pytest.mark.task0_4
 def test_stacked_demo() -> None:
-    "Check that each of the properties match"
+    """Check that each of the properties match"""
     mod = ModuleA1()
     np = dict(mod.named_parameters())

@@ -95,7 +95,7 @@ def __init__(self) -> None:
 @pytest.mark.task0_4
 @given(med_ints, med_ints)
 def test_module(size_a: int, size_b: int) -> None:
-    "Check the properties of a single module"
+    """Check the properties of a single module"""
     module = Module2()
     module.eval()
     assert not module.training
@@ -116,7 +116,7 @@ def test_module(size_a: int, size_b: int) -> None:
 @pytest.mark.task0_4
 @given(med_ints, med_ints, small_floats)
 def test_stacked_module(size_a: int, size_b: int, val: float) -> None:
-    "Check the properties of a stacked module"
+    """Check the properties of a stacked module"""
     module = Module1(size_a, size_b, val)
     module.eval()
     assert not module.training

tests/test_operators.py

+8 -12

@@ -5,6 +5,7 @@
 from hypothesis.strategies import lists

 from minitorch import MathTest
+import minitorch
 from minitorch.operators import (
     add,
     addLists,
@@ -32,7 +33,7 @@
 @pytest.mark.task0_1
 @given(small_floats, small_floats)
 def test_same_as_python(x: float, y: float) -> None:
-    "Check that the main operators all return the same value of the python version"
+    """Check that the main operators all return the same value of the python version"""
     assert_close(mul(x, y), x * y)
     assert_close(add(x, y), x + y)
     assert_close(neg(x), -x)
@@ -68,7 +69,7 @@ def test_id(a: float) -> None:
 @pytest.mark.task0_1
 @given(small_floats)
 def test_lt(a: float) -> None:
-    "Check that a - 1.0 is always less than a"
+    """Check that a - 1.0 is always less than a"""
     assert lt(a - 1.0, a) == 1.0
     assert lt(a, a - 1.0) == 0.0

@@ -113,15 +114,14 @@ def test_sigmoid(a: float) -> None:
 @pytest.mark.task0_2
 @given(small_floats, small_floats, small_floats)
 def test_transitive(a: float, b: float, c: float) -> None:
-    "Test the transitive property of less-than (a < b and b < c implies a < c)"
+    """Test the transitive property of less-than (a < b and b < c implies a < c)"""
     # TODO: Implement for Task 0.2.
     raise NotImplementedError("Need to implement for Task 0.2")


 @pytest.mark.task0_2
 def test_symmetric() -> None:
-    """
-    Write a test that ensures that :func:`minitorch.operators.mul` is symmetric, i.e.
+    """Write a test that ensures that :func:`minitorch.operators.mul` is symmetric, i.e.
     gives the same value regardless of the order of its input.
     """
     # TODO: Implement for Task 0.2.
@@ -130,8 +130,7 @@ def test_symmetric() -> None:

 @pytest.mark.task0_2
 def test_distribute() -> None:
-    r"""
-    Write a test that ensures that your operators distribute, i.e.
+    r"""Write a test that ensures that your operators distribute, i.e.
     :math:`z \times (x + y) = z \times x + z \times y`
     """
     # TODO: Implement for Task 0.2.
@@ -140,9 +139,7 @@ def test_distribute() -> None:

 @pytest.mark.task0_2
 def test_other() -> None:
-    """
-    Write a test that ensures some other property holds for your functions.
-    """
+    """Write a test that ensures some other property holds for your functions."""
     # TODO: Implement for Task 0.2.
     raise NotImplementedError("Need to implement for Task 0.2")

@@ -168,8 +165,7 @@ def test_zip_with(a: float, b: float, c: float, d: float) -> None:
     lists(small_floats, min_size=5, max_size=5),
 )
 def test_sum_distribute(ls1: List[float], ls2: List[float]) -> None:
-    """
-    Write a test that ensures that the sum of `ls1` plus the sum of `ls2`
+    """Write a test that ensures that the sum of `ls1` plus the sum of `ls2`
     is the same as the sum of each element of `ls1` plus each element of `ls2`.
     """
     # TODO: Implement for Task 0.3.
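The Task 0.2 and 0.3 bodies above are left as TODOs for the student. As one possible shape (not the reference solution), the symmetry test could reuse the `small_floats` strategy and `assert_close` helper already imported by this file, letting Hypothesis supply the inputs rather than keeping the stub's zero-argument signature:

@pytest.mark.task0_2
@given(small_floats, small_floats)
def test_symmetric(x: float, y: float) -> None:
    """Check that mul gives the same value regardless of argument order."""
    assert_close(mul(x, y), mul(y, x))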
