Skip to content

Commit 9cf5cad

Browse files
authored
Extend chapter on feature importance results / robustness checks / outlook👨‍💻 (#405)
1 parent dd69da6 commit 9cf5cad

34 files changed

+1559
-449
lines changed

notebooks/4.0c-mb-feature-importances.ipynb

+435-37
Large diffs are not rendered by default.

notebooks/6.0a-mb-results-fttransformer.ipynb

+49-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
"execution_count": null,
66
"metadata": {
77
"colab_type": "text",
8-
"id": "view-in-github"
8+
"id": "view-in-github",
9+
"tags": []
910
},
1011
"outputs": [],
1112
"source": [
@@ -34,7 +35,9 @@
3435
{
3536
"cell_type": "code",
3637
"execution_count": null,
37-
"metadata": {},
38+
"metadata": {
39+
"tags": []
40+
},
3841
"outputs": [],
3942
"source": [
4043
"# set globally here\n",
@@ -54,7 +57,9 @@
5457
{
5558
"cell_type": "code",
5659
"execution_count": null,
57-
"metadata": {},
60+
"metadata": {
61+
"tags": []
62+
},
5863
"outputs": [],
5964
"source": [
6065
"# key used for files and artefacts\n",
@@ -65,7 +70,9 @@
6570
{
6671
"cell_type": "code",
6772
"execution_count": null,
68-
"metadata": {},
73+
"metadata": {
74+
"tags": []
75+
},
6976
"outputs": [],
7077
"source": [
7178
"# set project name. Required to access files and artefacts\n",
@@ -83,7 +90,8 @@
8390
"height": 208
8491
},
8592
"id": "ah1dofx3TdDj",
86-
"outputId": "0bd418dd-6b5d-4fa8-9142-89b22d255e2f"
93+
"outputId": "0bd418dd-6b5d-4fa8-9142-89b22d255e2f",
94+
"tags": []
8795
},
8896
"outputs": [],
8997
"source": [
@@ -98,7 +106,8 @@
98106
"cell_type": "code",
99107
"execution_count": null,
100108
"metadata": {
101-
"id": "WmXtH-PEqyQE"
109+
"id": "WmXtH-PEqyQE",
110+
"tags": []
102111
},
103112
"outputs": [],
104113
"source": [
@@ -118,14 +127,15 @@
118127
{
119128
"cell_type": "code",
120129
"execution_count": null,
121-
"metadata": {},
130+
"metadata": {
131+
"tags": []
132+
},
122133
"outputs": [],
123134
"source": [
124135
"X_test.head()"
125136
]
126137
},
127138
{
128-
"attachments": {},
129139
"cell_type": "markdown",
130140
"metadata": {
131141
"id": "zMIOV1jA_ImH"
@@ -134,6 +144,37 @@
134144
"## FT-Transformer"
135145
]
136146
},
147+
{
148+
"cell_type": "code",
149+
"execution_count": null,
150+
"metadata": {
151+
"tags": []
152+
},
153+
"outputs": [],
154+
"source": [
155+
"def count_parameters(model):\n",
156+
" \"\"\"\n",
157+
" Count number of parameters, that require gradient-update in model.\n",
158+
" \n",
159+
" Found here: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9\n",
160+
" \"\"\"\n",
161+
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
162+
"\n",
163+
"for feature_str, model in tqdm(models):\n",
164+
"\n",
165+
" model_name = model.split(\"/\")[-1].split(\":\")[0]\n",
166+
"\n",
167+
" artifact = run.use_artifact(model)\n",
168+
" model_dir = artifact.download()\n",
169+
" \n",
170+
" with open(Path(model_dir, model_name), 'rb') as f:\n",
171+
" model = pickle.load(f)\n",
172+
" \n",
173+
" print(feature_str)\n",
174+
" print(count_parameters(model.clf))\n",
175+
" "
176+
]
177+
},
137178
{
138179
"cell_type": "code",
139180
"execution_count": null,
@@ -196,7 +237,6 @@
196237
]
197238
},
198239
{
199-
"attachments": {},
200240
"cell_type": "markdown",
201241
"metadata": {},
202242
"source": [

notebooks/6.0e-mb-viz-universal.ipynb

+36-28
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
]
7373
},
7474
{
75-
"attachments": {},
7675
"cell_type": "markdown",
7776
"metadata": {
7877
"id": "kNTG2a_kf5gS"
@@ -124,7 +123,6 @@
124123
]
125124
},
126125
{
127-
"attachments": {},
128126
"cell_type": "markdown",
129127
"metadata": {
130128
"id": "vVE2JK9Af5gW"
@@ -165,7 +163,6 @@
165163
]
166164
},
167165
{
168-
"attachments": {},
169166
"cell_type": "markdown",
170167
"metadata": {
171168
"id": "h9mAHJU1f5gX"
@@ -217,7 +214,6 @@
217214
]
218215
},
219216
{
220-
"attachments": {},
221217
"cell_type": "markdown",
222218
"metadata": {
223219
"id": "rdBVk3fyf5gZ"
@@ -363,7 +359,6 @@
363359
]
364360
},
365361
{
366-
"attachments": {},
367362
"cell_type": "markdown",
368363
"metadata": {
369364
"id": "SyA46Ie6f5gc"
@@ -470,7 +465,6 @@
470465
]
471466
},
472467
{
473-
"attachments": {},
474468
"cell_type": "markdown",
475469
"metadata": {
476470
"id": "KLKHwCjOf5gg"
@@ -594,7 +588,6 @@
594588
]
595589
},
596590
{
597-
"attachments": {},
598591
"cell_type": "markdown",
599592
"metadata": {
600593
"id": "jGL-HbYlf5gi"
@@ -1125,7 +1118,6 @@
11251118
]
11261119
},
11271120
{
1128-
"attachments": {},
11291121
"cell_type": "markdown",
11301122
"metadata": {
11311123
"id": "roRmlg_nf5gl"
@@ -1476,7 +1468,6 @@
14761468
]
14771469
},
14781470
{
1479-
"attachments": {},
14801471
"cell_type": "markdown",
14811472
"metadata": {
14821473
"id": "r5ZnoZIG26K_"
@@ -1574,7 +1565,6 @@
15741565
]
15751566
},
15761567
{
1577-
"attachments": {},
15781568
"cell_type": "markdown",
15791569
"metadata": {
15801570
"id": "hc7pkVqe4qNw"
@@ -1612,6 +1602,17 @@
16121602
"fi = pd.concat([fi_classical, fi_gbm, fi_transformer], axis=1)"
16131603
]
16141604
},
1605+
{
1606+
"cell_type": "code",
1607+
"execution_count": null,
1608+
"metadata": {
1609+
"tags": []
1610+
},
1611+
"outputs": [],
1612+
"source": [
1613+
"fi"
1614+
]
1615+
},
16151616
{
16161617
"cell_type": "code",
16171618
"execution_count": null,
@@ -1626,37 +1627,37 @@
16261627
"ind = np.arange(len(fi))\n",
16271628
"width = 0.25\n",
16281629
"\n",
1629-
"axes[0].barh(ind, fi[\"quote(best)->quote(ex) values\"], width, xerr=fi[\"quote(best)->quote(ex) std\"], label=\"Classical\")\n",
1630-
"axes[0].barh(ind+width, fi[\"gbm(classical) values\"], width, xerr=fi[\"gbm(classical) std\"], label=\"GBRT\")\n",
1631-
"axes[0].barh(ind+width + width, fi[\"fttransformer(classical) values\"], width, xerr=fi[\"fttransformer(classical) std\"], label=\"Transformer\")\n",
1630+
"axes[0].barh(ind, fi[\"quote(best)->quote(ex)->rev_tick(all) values\"].abs(), width, xerr=fi[\"quote(best)->quote(ex)->rev_tick(all) std\"], label=\"Classical\")\n",
1631+
"axes[0].barh(ind+width, fi[\"gbm(classical) values\"].abs(), width, xerr=fi[\"gbm(classical) std\"], label=\"GBRT\")\n",
1632+
"axes[0].barh(ind+width + width, fi[\"fttransformer(classical) values\"].abs(), width, xerr=fi[\"fttransformer(classical) std\"], label=\"Transformer\")\n",
16321633
"axes[0].axvline(0, color='black', linestyle='--', linewidth=0.5)\n",
1633-
"axes[0].set_xlim([-0.15,0.15])\n",
1634+
"axes[0].set_xlim([0,0.15])\n",
16341635
"\n",
1635-
"axes[1].barh(ind, fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) values\"], width, xerr=fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) std\"], label=\"Classical\")\n",
1636-
"axes[1].barh(ind+width, fi[\"gbm(classical-size) values\"], width, xerr=fi[\"gbm(classical-size) std\"], label=\"GBRT\")\n",
1637-
"axes[1].barh(ind+width + width, fi[\"fttransformer(classical-size) values\"], width, xerr=fi[\"fttransformer(classical-size) std\"], label=\"Transformer\")\n",
1636+
"axes[1].barh(ind, fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) values\"].abs(), width, xerr=fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) std\"], label=\"Classical\")\n",
1637+
"axes[1].barh(ind+width, fi[\"gbm(classical-size) values\"].abs(), width, xerr=fi[\"gbm(classical-size) std\"], label=\"GBRT\")\n",
1638+
"axes[1].barh(ind+width + width, fi[\"fttransformer(classical-size) values\"].abs(), width, xerr=fi[\"fttransformer(classical-size) std\"], label=\"Transformer\")\n",
16381639
"axes[1].axvline(0, color='black', linestyle='--', linewidth=0.5)\n",
1639-
"axes[1].set_xlim([-0.15,0.15])\n",
1640+
"axes[1].set_xlim([0,0.15])\n",
16401641
"\n",
1641-
"axes[2].barh(ind, fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) values\"], width, xerr=fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) std\"], label=\"Classical\")\n",
1642-
"axes[2].barh(ind+width, fi[\"gbm(ml) values\"], width, xerr=fi[\"gbm(ml) std\"], label=\"GBRT\")\n",
1643-
"axes[2].barh(ind+width + width, fi[\"fttransformer(ml) values\"], width, xerr=fi[\"fttransformer(ml) std\"], label=\"Transformer\")\n",
1642+
"axes[2].barh(ind, fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) values\"].abs(), width, xerr=fi[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all) std\"], label=\"Classical\")\n",
1643+
"axes[2].barh(ind+width, fi[\"gbm(ml) values\"].abs(), width, xerr=fi[\"gbm(ml) std\"], label=\"GBRT\")\n",
1644+
"axes[2].barh(ind+width + width, fi[\"fttransformer(ml) values\"].abs(), width, xerr=fi[\"fttransformer(ml) std\"], label=\"Transformer\")\n",
16441645
"axes[2].axvline(0, color='black', linestyle='--', linewidth=0.5)\n",
1645-
"axes[2].set_xlim([-0.15,0.15])\n",
1646+
"axes[2].set_xlim([0.0,0.15])\n",
16461647
"\n",
16471648
"\n",
16481649
"# set y-labels\n",
16491650
"labels = ['Price Lead All', 'Price Lag All', 'Price Lead Ex', 'Price Lag Ex', 'Quotes NBBO', 'Quotes Ex', 'Trade Price', \"Quotes Size\", 'Trade Size', 'Strike Price', 'Time To Maturity', 'Option Type', 'Root', 'Moneyness', \"Day Volume\", 'Issue Type']\n",
16501651
"axes[0].set(yticks=ind + width, yticklabels=labels, ylim=[2*width - 1, len(fi)])\n",
16511652
"\n",
16521653
"# set x-labels\n",
1653-
"axes[0].set_xlabel(\"SAGE Value\")\n",
1654-
"axes[1].set_xlabel(\"SAGE Value\")\n",
1655-
"axes[2].set_xlabel(\"SAGE Value\")\n",
1654+
"axes[0].set_xlabel(r\"\\textbar SAGE Value\\textbar\")\n",
1655+
"axes[1].set_xlabel(r\"\\textbar SAGE Value\\textbar\")\n",
1656+
"axes[2].set_xlabel(r\"\\textbar SAGE Value\\textbar\")\n",
16561657
"\n",
16571658
"# set y-labels\n",
16581659
"axes[0].set_title(\"Set Classical\")\n",
1659-
"axes[1].set_title(\"Set Classical-Size\")\n",
1660+
"axes[1].set_title(\"Set Size\")\n",
16601661
"axes[2].set_title(\"Set Options\")\n",
16611662
"\n",
16621663
"handles, labels = axes[0].get_legend_handles_labels()\n",
@@ -1666,14 +1667,21 @@
16661667
"\n",
16671668
"plt.savefig(f\"../reports/Graphs/sage-importances.pdf\", bbox_inches=\"tight\")"
16681669
]
1670+
},
1671+
{
1672+
"cell_type": "code",
1673+
"execution_count": null,
1674+
"metadata": {},
1675+
"outputs": [],
1676+
"source": []
16691677
}
16701678
],
16711679
"metadata": {
16721680
"colab": {
16731681
"provenance": []
16741682
},
16751683
"kernelspec": {
1676-
"display_name": "Python 3",
1684+
"display_name": "Python 3 (ipykernel)",
16771685
"language": "python",
16781686
"name": "python3"
16791687
},
@@ -1687,7 +1695,7 @@
16871695
"name": "python",
16881696
"nbconvert_exporter": "python",
16891697
"pygments_lexer": "ipython3",
1690-
"version": "3.9.4"
1698+
"version": "3.9.7"
16911699
}
16921700
},
16931701
"nbformat": 4,

0 commit comments

Comments
 (0)