Skip to content

Commit 6ac9aa4

Browse files
committed
fix #4 type hints added
1 parent 6b4f908 commit 6ac9aa4

File tree

4 files changed

+23
-7
lines changed

4 files changed

+23
-7
lines changed

CHANGES.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# 0.1.2 / 2019-09-26
2+
3+
* add type hints
14

25
# 0.1.1 / 2019-03-23
36

panelctmc/panel_to_datalist.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11

22
from yearfrac import yearfrac
33
import numpy as np
4+
import datetime
45

56

6-
def panel_to_datalist(data, lastdate=None):
7+
def panel_to_datalist(data: np.ndarray,
8+
lastdate: datetime.datetime = None
9+
) -> list:
710
"""Transforms array/list to ctmc's internal data format
811
912
Parameters:

panelctmc/panelctmc_class.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
from sklearn.base import BaseEstimator
33
from .panelctmc_func import panelctmc
44
from ctmc import simulate
5+
import datetime
6+
import numpy as np
57

68

79
class PanelCtmc(BaseEstimator):
810
"""Continous Time Markov Chain for Panel Data, sklearn API class"""
911

10-
def __init__(self, mapping=None, lastdate=None,
11-
transintv=1.0, toltime=1e-8, debug=False):
12+
def __init__(self, mapping: list = None,
13+
lastdate: datetime.datetime = None,
14+
transintv: float = 1.0,
15+
toltime: float = 1e-8,
16+
debug: bool = False):
1217
self.mapping = mapping
1318
self.lastdate = lastdate
1419
self.transintv = transintv
@@ -20,7 +25,7 @@ def __init__(self, mapping=None, lastdate=None,
2025
self.statetime = None
2126
self.datalist = None
2227

23-
def fit(self, X, y=None):
28+
def fit(self, X: np.ndarray, y=None):
2429
(
2530
self.transmat,
2631
self.genmat,
@@ -35,5 +40,5 @@ def fit(self, X, y=None):
3540
debug=self.debug)
3641
return self
3742

38-
def predict(self, X, steps=1):
43+
def predict(self, X: np.ndarray, steps: int = 1) -> np.ndarray:
3944
return simulate(X, self.transmat, steps)

panelctmc/panelctmc_func.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
from ctmc import ctmc, datacorrection
77

88

9-
def panelctmc(paneldata, mapping, lastdate=None,
10-
transintv=1.0, toltime=1e-8, debug=True):
9+
def panelctmc(paneldata: np.ndarray,
10+
mapping: list,
11+
lastdate: datetime = None,
12+
transintv: float = 1.0,
13+
toltime: float = 1e-8,
14+
debug: bool = True) -> (np.ndarray, np.ndarray, np.ndarray,
15+
np.ndarray, list):
1116
# check if numpy array
1217
if not isinstance(paneldata, np.ndarray):
1318
raise Exception("'paneldata' is not a numpy array")

0 commit comments

Comments
 (0)