Commit de54312

Author: Alexander März
Added MLP as output mapping layer

1 parent: 9bf8cc8

5 files changed: +98 -10 lines

LagTST-hyperparameter_tuning.ipynb (+1)

@@ -76,6 +76,7 @@
 "        d_model=params[\"d_model\"],\n",
 "        dim_feedforward=params[\"dim\"],\n",
 "        batch_size=params[\"batch_size\"],\n",
+"        patch_reverse_mapping_layer=\"mlp\",\n",
 "        num_batches_per_epoch=100,\n",
 "        trainer_kwargs=dict(accelerator=\"gpu\", max_epochs=30),\n",
 "    )\n",

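For orientation, the `params` dictionary referenced in the diff above is supplied by the notebook's tuning loop. A minimal illustrative sketch, where the keys are taken from the diff and the values are invented for the example:

    # Hypothetical trial configuration; keys match the notebook diff above,
    # values are placeholders a tuning trial might propose.
    params = {
        "d_model": 32,      # embedding dimension
        "dim": 64,          # feedforward dimension, passed as dim_feedforward
        "batch_size": 128,  # training batch size
    }

With the added line, every tuning trial now trains the estimator with the MLP reverse-mapping layer introduced by this commit.
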
MlpTSMixer-hyperparameter_tuning.ipynb (+1)

@@ -81,6 +81,7 @@
 "        dim=params[\"dim\"],\n",
 "        batch_size=params[\"batch_size\"],\n",
 "        num_batches_per_epoch=100,\n",
+"        patch_reverse_mapping_layer=\"mlp\",\n",
 "        trainer_kwargs=dict(accelerator=\"cuda\", max_epochs=30),\n",
 "    )\n",
 "    predictor = estimator.train(\n",

MlpTSMixer/estimator.py (+1 -1)

@@ -113,7 +113,7 @@ def __init__(
         expansion_factor_token: float = 0.5,
         expansion_factor: int = 4,
         ablation: bool = False,
-        patch_reverse_mapping_layer: str = "pooling",
+        patch_reverse_mapping_layer: str = "mlp",
         pooling_type: str = "max",
         scaling: Optional[str] = "mean",
         num_feat_dynamic_real: int = 0,
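
The substantive change here is the default value: newly constructed MlpTSMixer estimators now use the MLP reverse mapping unless the caller opts out. A minimal sketch of restoring the previous behavior; the `MlpTSMixerEstimator` class name is inferred from the file path and the remaining constructor arguments are omitted, so treat both as assumptions:

    # Kwargs that restore the pre-change default (parameter names from the diff).
    pooling_kwargs = {
        "patch_reverse_mapping_layer": "pooling",  # the former default
        "pooling_type": "max",                     # "max" or "mean"
    }
    # estimator = MlpTSMixerEstimator(..., **pooling_kwargs)  # hypothetical call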

TsT/estimator.py (+6 -3)

@@ -114,7 +114,8 @@ def __init__(
         dropout=0.1,
         activation="relu",
         norm_first: bool = False,
-        max_pool: bool = False,
+        patch_reverse_mapping_layer: str = "mlp",
+        pooling_type: str = "max",
         scaling: Optional[str] = "mean",
         num_feat_dynamic_real: int = 0,
         num_feat_static_cat: int = 0,
@@ -165,7 +166,8 @@ def __init__(
         self.dropout = dropout
         self.activation = activation
         self.norm_first = norm_first
-        self.max_pool = max_pool
+        self.patch_reverse_mapping_layer = patch_reverse_mapping_layer
+        self.pooling_type = pooling_type
         self.lr = lr
         self.weight_decay = weight_decay
         self.distr_output = distr_output
@@ -262,7 +264,8 @@ def create_lightning_module(self) -> pl.LightningModule:
             "scaling": self.scaling,
             "distr_output": self.distr_output,
             "num_parallel_samples": self.num_parallel_samples,
-            "max_pool": self.max_pool,
+            "patch_reverse_mapping_layer": self.patch_reverse_mapping_layer,
+            "pooling_type": self.pooling_type,
         },
     )
 
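
The boolean `max_pool` flag is replaced by the pair `patch_reverse_mapping_layer` / `pooling_type`, both stored on the estimator and forwarded to the lightning module. A small migration sketch, assuming old call sites passed `max_pool` explicitly:

    def migrate_max_pool(max_pool: bool) -> dict:
        """Translate the removed `max_pool` flag into the new TsT kwargs (sketch)."""
        return {
            "patch_reverse_mapping_layer": "pooling",
            "pooling_type": "max" if max_pool else "mean",  # old True/False semantics
        }

Note that the new default is patch_reverse_mapping_layer="mlp", so omitting both arguments no longer reproduces the old default of mean pooling.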

TsT/module.py (+89 -6)

@@ -16,6 +16,7 @@
 import numpy as np
 import torch
 from torch import nn
+from einops.layers.torch import Rearrange
 
 from gluonts.core.component import validated
 from gluonts.model import Input, InputSpec
@@ -32,6 +33,83 @@ def forward(self, x):
         return self.fn(x) + x
 
 
+class MLPPatchMap(nn.Module):
+    """
+    Module implementing an MLP for the reverse mapping of the patch-tensor.
+
+    Parameters
+    ----------
+    patch_size : Tuple[int, int]
+        Patch size.
+    context_length : int
+        Context length.
+    prediction_length : int
+        Number of time points to predict.
+    input_size : int
+        Input size.
+
+    Returns
+    -------
+    x : torch.Tensor
+    """
+
+    def __init__(
+        self,
+        patch_size: Tuple[int, int],
+        context_length: int,
+        prediction_length: int,
+        input_size: int,
+    ):
+        super().__init__()
+        p1 = int(context_length / patch_size[0])
+        p2 = int(input_size / patch_size[1])
+        self.fc = nn.Sequential(
+            Rearrange("b c w h -> b c (w h)"),
+            nn.Linear(p1 * p2, prediction_length * input_size),
+            Rearrange("b c (w h) -> b c w h", w=prediction_length, h=input_size),
+        )
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+def RevMapLayer(
+    layer_type: str,
+    pooling_type: str,
+    dim: int,
+    patch_size: Tuple[int, int],
+    context_length: int,
+    prediction_length: int,
+    input_size: int,
+):
+    """
+    Returns the layer for the reverse mapping of the patch-tensor to [b nf h ns].
+
+    :argument
+        layer_type: str = "pooling" or "mlp"
+        pooling_type: str = "max" or "mean"
+        dim: int = dimension of the embeddings
+        patch_size: Tuple[int, int] = patch size
+        context_length: int = context length
+        prediction_length: int = prediction length
+        input_size: int = input size
+
+    :returns
+        nn.Module = mapping layer
+    """
+    if layer_type == "pooling":
+        if pooling_type == "max":
+            return nn.AdaptiveMaxPool2d((prediction_length, input_size))
+        elif pooling_type == "mean":
+            return nn.AdaptiveAvgPool2d((prediction_length, input_size))
+        else:
+            raise ValueError("Invalid pooling type: {}".format(pooling_type))
+    elif layer_type == "mlp":
+        return MLPPatchMap(patch_size, context_length, prediction_length, input_size)
+    else:
+        raise ValueError("Invalid layer type: {}".format(layer_type))
+
+
 class TsTModel(nn.Module):
     """
     Module implementing TsT for forecasting.
@@ -66,7 +144,8 @@ def __init__(
         dropout: float,
         activation: str,
         norm_first: bool,
-        max_pool: bool = False,
+        patch_reverse_mapping_layer: str = "pooling",
+        pooling_type: str = "max",
         num_feat_dynamic_real: int = 0,
         num_feat_static_real: int = 0,
         num_feat_static_cat: int = 0,
@@ -120,10 +199,14 @@ def __init__(
         encoder_norm = nn.LayerNorm(dim, eps=layer_norm_eps)
         self.encoder = nn.TransformerEncoder(encoder_layer, depth, encoder_norm)
 
-        self.pool = (
-            nn.AdaptiveMaxPool2d((self.prediction_length, self.input_size))
-            if max_pool
-            else nn.AdaptiveAvgPool2d((self.prediction_length, self.input_size))
+        self.rev_map_layer = RevMapLayer(
+            layer_type=patch_reverse_mapping_layer,
+            pooling_type=pooling_type,
+            dim=dim,
+            patch_size=patch_size,
+            prediction_length=self.prediction_length,
+            context_length=self.context_length,
+            input_size=self.input_size,
         )
 
         self.args_proj = self.distr_output.get_args_proj(
@@ -198,7 +281,7 @@ def forward(
         embed_pos = self.positional_encoding(x.size())
         enc_out = self.encoder(x + embed_pos)
 
-        nn_out = self.pool(enc_out.permute(0, 2, 1).reshape(B, C, H, W))
+        nn_out = self.rev_map_layer(enc_out.permute(0, 2, 1).reshape(B, C, H, W))
 
         # [B, F, C, D] -> [B, F, P, D]
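
To see what the new layer does to tensor shapes, here is a minimal sketch exercising MLPPatchMap and the pooling branch of RevMapLayer; the import path and the concrete sizes are assumptions made for the example:

    import torch

    from TsT.module import MLPPatchMap, RevMapLayer  # import path assumed

    patch_size = (8, 1)  # assumed (time, feature) patch size
    context_length, prediction_length, input_size = 96, 24, 4  # assumed sizes

    # Patch grid as produced by the encoder: [batch, dim, context/p1, input/p2]
    x = torch.randn(32, 16, context_length // patch_size[0], input_size // patch_size[1])

    mlp = MLPPatchMap(patch_size, context_length, prediction_length, input_size)
    print(mlp(x).shape)  # torch.Size([32, 16, 24, 4]): flatten, Linear, un-flatten

    pool = RevMapLayer(
        layer_type="pooling", pooling_type="mean", dim=16, patch_size=patch_size,
        context_length=context_length, prediction_length=prediction_length,
        input_size=input_size,
    )
    print(pool(x).shape)  # torch.Size([32, 16, 24, 4]): adaptive pooling to target size

Both branches map the patch grid [B, C, H, W] to [B, C, prediction_length, input_size]; the MLP learns this mapping per channel, while adaptive pooling simply resamples the grid.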
