 import numpy as np
 import torch
 from torch import nn
+from typing import Tuple
+from einops.layers.torch import Rearrange

 from gluonts.core.component import validated
 from gluonts.model import Input, InputSpec
@@ -32,6 +33,83 @@ def forward(self, x):
         return self.fn(x) + x


+class MLPPatchMap(nn.Module):
+    """
+    Module implementing an MLP for the reverse mapping of the patch-tensor.
+
+    Parameters
+    ----------
+    patch_size : Tuple[int, int]
+        Patch size.
+    context_length : int
+        Context length.
+    prediction_length : int
+        Number of time points to predict.
+    input_size : int
+        Input size.
+
+    Returns
+    -------
+    x : torch.Tensor
+        Mapped tensor of shape [batch, channels, prediction_length, input_size].
+    """
+
+    def __init__(
+        self,
+        patch_size: Tuple[int, int],
+        context_length: int,
+        prediction_length: int,
+        input_size: int,
+    ):
+        super().__init__()
+        # Number of patches along the time and variate dimensions.
+        p1 = context_length // patch_size[0]
+        p2 = input_size // patch_size[1]
+        self.fc = nn.Sequential(
+            Rearrange("b c w h -> b c (w h)"),
+            nn.Linear(p1 * p2, prediction_length * input_size),
+            Rearrange("b c (w h) -> b c w h", w=prediction_length, h=input_size),
+        )
+
+    def forward(self, x):
+        return self.fc(x)
+
+
+def RevMapLayer(
+    layer_type: str,
+    pooling_type: str,
+    dim: int,
+    patch_size: Tuple[int, int],
+    context_length: int,
+    prediction_length: int,
+    input_size: int,
+):
+    """
+    Returns the mapping layer for the reverse mapping of the patch-tensor to [b nf h ns].
+
+    :argument
+        layer_type: str = "pooling" or "mlp"
+        pooling_type: str = "max" or "mean" (only used when layer_type is "pooling")
+        dim: int = dimension of the embeddings (currently unused)
+        patch_size: Tuple[int, int] = patch size
+        prediction_length: int = prediction length
+        context_length: int = context length
+        input_size: int = input size
+
+    :returns
+        nn.Module = mapping layer
+    """
+    if layer_type == "pooling":
+        if pooling_type == "max":
+            return nn.AdaptiveMaxPool2d((prediction_length, input_size))
+        elif pooling_type == "mean":
+            return nn.AdaptiveAvgPool2d((prediction_length, input_size))
+        else:
+            raise ValueError("Invalid pooling type: {}".format(pooling_type))
+    elif layer_type == "mlp":
+        return MLPPatchMap(patch_size, context_length, prediction_length, input_size)
+    else:
+        raise ValueError("Invalid layer type: {}".format(layer_type))
+
+
 class TsTModel(nn.Module):
     """
     Module implementing TsT for forecasting.
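For reviewers, a quick shape check of the two reverse-mapping variants added above (with `MLPPatchMap` as defined in this diff). The sizes are made up for illustration: `context_length=24`, `input_size=6`, `patch_size=(4, 2)`, `dim=32`, `prediction_length=12`, so the patch grid is `p1=6` by `p2=3`:

```python
import torch
from torch import nn

x = torch.randn(2, 32, 6, 3)  # [B, C, p1, p2]: batch, embedding dim, patch grid

# layer_type="pooling", pooling_type="max" resolves to an adaptive pool:
pool = nn.AdaptiveMaxPool2d((12, 6))
# layer_type="mlp" resolves to the new module:
mlp = MLPPatchMap(patch_size=(4, 2), context_length=24, prediction_length=12, input_size=6)

print(pool(x).shape)  # torch.Size([2, 32, 12, 6])
print(mlp(x).shape)   # torch.Size([2, 32, 12, 6]) -> [B, C, prediction_length, input_size]
```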
@@ -66,7 +144,8 @@ def __init__(
         dropout: float,
         activation: str,
         norm_first: bool,
-        max_pool: bool = False,
+        patch_reverse_mapping_layer: str = "pooling",
+        pooling_type: str = "max",
         num_feat_dynamic_real: int = 0,
         num_feat_static_real: int = 0,
         num_feat_static_cat: int = 0,
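Note that this replaces the boolean `max_pool` flag (default `False`, i.e. mean pooling) with two string arguments whose defaults select max pooling, so the default reverse mapping changes from mean to max pooling. A sketch of how the old settings map onto the new ones (other constructor kwargs elided):

```python
# max_pool=True  (old) ->
kwargs_max = dict(patch_reverse_mapping_layer="pooling", pooling_type="max")
# max_pool=False (old default) ->
kwargs_mean = dict(patch_reverse_mapping_layer="pooling", pooling_type="mean")
# newly available MLP variant:
kwargs_mlp = dict(patch_reverse_mapping_layer="mlp")
```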
@@ -120,10 +199,14 @@ def __init__(
         encoder_norm = nn.LayerNorm(dim, eps=layer_norm_eps)
         self.encoder = nn.TransformerEncoder(encoder_layer, depth, encoder_norm)

-        self.pool = (
-            nn.AdaptiveMaxPool2d((self.prediction_length, self.input_size))
-            if max_pool
-            else nn.AdaptiveAvgPool2d((self.prediction_length, self.input_size))
+        self.rev_map_layer = RevMapLayer(
+            layer_type=patch_reverse_mapping_layer,
+            pooling_type=pooling_type,
+            dim=dim,
+            patch_size=patch_size,
+            prediction_length=self.prediction_length,
+            context_length=self.context_length,
+            input_size=self.input_size,
         )

         self.args_proj = self.distr_output.get_args_proj(
@@ -198,7 +281,7 @@ def forward(
         embed_pos = self.positional_encoding(x.size())
         enc_out = self.encoder(x + embed_pos)

-        nn_out = self.pool(enc_out.permute(0, 2, 1).reshape(B, C, H, W))
+        nn_out = self.rev_map_layer(enc_out.permute(0, 2, 1).reshape(B, C, H, W))

         # [B, F, C, D] -> [B, F, P, D]
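For context on the shapes at this call site: the encoder output is a token sequence that gets reshaped back into the patch grid before the reverse mapping. A standalone sketch with the same made-up sizes as above (`B=2`, `C=dim=32`, `H=6`, `W=3`):

```python
import torch
from torch import nn

B, C, H, W = 2, 32, 6, 3
enc_out = torch.randn(B, H * W, C)  # [B, num_patches, dim]

# [B, num_patches, dim] -> [B, dim, num_patches] -> [B, C, H, W]
grid = enc_out.permute(0, 2, 1).reshape(B, C, H, W)

rev_map_layer = nn.AdaptiveMaxPool2d((12, 6))  # the default "pooling"/"max" variant
nn_out = rev_map_layer(grid)
print(nn_out.shape)  # torch.Size([2, 32, 12, 6]) == [B, C, prediction_length, input_size]
```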