1
- # Copyright (c) 2021-2022 , NVIDIA CORPORATION. All rights reserved.
1
+ # Copyright (c) 2021-2024 , NVIDIA CORPORATION. All rights reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
13
13
# limitations under the License.
14
14
15
15
import os
16
+ import mlflow
16
17
import pandas as pd
17
18
18
19
from omegaconf import OmegaConf
19
20
from hydra .experimental .callback import Callback
20
21
21
22
from loggers .log_helper import jsonlog_2_df
23
+ from mlflow .entities import Metric , Param
22
24
23
25
class MergeLogs (Callback ):
24
26
def on_multirun_end (self , config , ** kwargs ):
25
27
OmegaConf .resolve (config )
26
28
27
- ALLOWED_KEYS = ['timestamp' , 'elapsed_time' , 'step' , 'loss' , 'val_loss' , 'MAE' , 'MSE' , 'RMSE' , 'P50' , 'P90' ]
29
+ ALLOWED_KEYS = ['timestamp' , 'elapsed_time' , 'step' , 'loss' , 'val_loss' , 'MAE' , 'MSE' , 'RMSE' , 'P50' , 'P90' , 'SMAPE' , 'TDI' ]
28
30
29
31
dfs = []
30
32
for p , sub_dirs , files in os .walk (config .hydra .sweep .dir ):
31
33
if 'log.json' in files :
32
34
path = os .path .join (p , 'log.json' )
33
35
df = jsonlog_2_df (path , ALLOWED_KEYS )
34
36
dfs .append (df )
35
-
36
37
# Transpose dataframes
37
38
plots = {}
38
39
for c in dfs [0 ].columns :
@@ -49,3 +50,15 @@ def on_multirun_end(self, config, **kwargs):
49
50
timestamps = (timestamps * 1000 ).astype (int )
50
51
if not timestamps .is_monotonic :
51
52
raise ValueError ('Timestamps are not monotonic' )
53
+
54
+ metrics = [Metric ('_' .join ((k ,name )), v , timestamp , step )
55
+ for k , df in plots .items ()
56
+ for timestamp , (step , series ) in zip (timestamps , df .iterrows ())
57
+ for name , v in series .items ()
58
+ ]
59
+ client = mlflow .tracking .MlflowClient (tracking_uri = config .trainer .config .mlflow_store )
60
+ exp = client .get_experiment_by_name (config .trainer .config .get ('experiment_name' , '' ))
61
+ run = client .create_run (exp .experiment_id if exp else '0' )
62
+ for i in range (0 , len (metrics ), 1000 ):
63
+ client .log_batch (run .info .run_id , metrics = metrics [i :i + 1000 ])
64
+ client .set_terminated (run .info .run_id )
0 commit comments