 import sys
-import tensorflow as tf
+
 import tensorflow.compat.v1 as tfv1

 from .flags import FLAGS
-from .logging import log_info, log_error, log_warn
+from .logging import log_error, log_info, log_warn


 def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -19,47 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
     if lr_var and ('learning_rate' not in vars_in_ckpt or
-                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var

-    if FLAGS.load_cudnn:
-        # Initialize training from a CuDNN RNN checkpoint
-        # Identify the variables which we cannot load, and set them
-        # for initialization
-        missing_vars = set()
-        for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                log_warn('CUDNN variable not found: %s' % (v.op.name))
-                missing_vars.add(v)
+    # After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
+    # are missing because they were not used. This might also occur when loading a cudnn checkpoint
+    # Therefore we have to initialize them again to continue training on such checkpoints
+    print_msg = False
+    for v in load_vars:
+        if v.op.name not in vars_in_ckpt:
+            if 'Adam' in v.name:
                 init_vars.add(v)
+                print_msg = True
+    if print_msg:
+        msg = "Some Adam tensors are missing, they will be initialized automatically."
+        log_info(msg)
+    load_vars -= init_vars

-        load_vars -= init_vars
-
-        # Check that the only missing variables (i.e. those to be initialised)
-        # are the Adam moment tensors, if they aren't then we have an issue
-        missing_var_names = [v.op.name for v in missing_vars]
-        if any('Adam' not in v for v in missing_var_names):
-            log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                      'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(missing_var_names))
-            sys.exit(1)
-
-    if FLAGS.load_frozen_graph:
-        # After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
-        # existing anymore because they were not used
-        # Therefore we have to initialize them again to continue training on such checkpoints
+    if FLAGS.load_cudnn:
+        # Check all required tensors are included in the cudnn checkpoint we want to load
         for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                if 'Adam' in v.name:
-                    init_vars.add(v)
-                else:
-                    msg = "Tried to load a frozen checkpoint but there was a missing " \
-                          "variable other than the Adam tensors: {}"
-                    log_error(msg.format(v))
-                    sys.exit(1)
-        load_vars -= init_vars
+            if v.op.name not in vars_in_ckpt and 'Adam' not in v.op.name:
+                msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
+                      ' variable other than an Adam moment tensor: {}'
+                log_error(msg.format(v.op.name))
+                sys.exit(1)

     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
@@ -74,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                      'dropping only 5 layers.')
             FLAGS.drop_source_layers = 5

-        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
         # Initialize all variables needed for DS, but not loaded from ckpt
         for v in load_vars:
             if any(layer in v.op.name for layer in dropped_layers):
@@ -90,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
         session.run(v.initializer)


+def drop_freeze_number_to_layers(drop_freeze_number, mode):
+    """ Convert number of layers to drop or freeze into layer names """
+
+    if drop_freeze_number >= 6:
+        log_warn('The checkpoint only has 6 layers, but you are trying '
+                 'to drop or freeze all of them or more. Continuing with 5 layers.')
+        drop_freeze_number = 5
+
+    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
+    if mode == "drop":
+        layer_keys = layer_keys[-1 * int(drop_freeze_number):]
+    elif mode == "freeze":
+        layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
+    else:
+        raise ValueError("mode must be 'drop' or 'freeze', not '{}'".format(mode))
+    return layer_keys
+
+
 def _checkpoint_path_or_none(checkpoint_filename):
     checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
     if not checkpoint:
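
Note on the Adam handling in this diff: the optimizer moment tensors (variables with 'Adam' in their name) are the only variables that may legitimately be absent from a checkpoint, e.g. after training with freeze_source_layers or when loading a CuDNN checkpoint, so they are re-initialized while everything else must be restorable. A minimal standalone sketch of that partition, not the repo's actual helper; the names split_loadable_vars, graph_vars and ckpt_path are illustrative:

import sys

import tensorflow.compat.v1 as tfv1


def split_loadable_vars(graph_vars, ckpt_path):
    # Variable names actually stored in the checkpoint file.
    names_in_ckpt = {name for name, _shape in tfv1.train.list_variables(ckpt_path)}
    load_vars, init_vars = set(), set()
    for v in graph_vars:
        if v.op.name in names_in_ckpt:
            load_vars.add(v)    # restorable from the checkpoint
        elif 'Adam' in v.op.name:
            init_vars.add(v)    # missing Adam moment tensor: re-initialize and keep training
        else:
            # Any other missing variable means the checkpoint does not match the graph.
            sys.exit('Missing non-Adam variable: {}'.format(v.op.name))
    return load_vars, init_vars

Calling it with tfv1.global_variables() and a checkpoint path gives roughly the split the diff computes inline.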
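Similarly, a self-contained check of the slicing used by the new drop_freeze_number_to_layers helper (layer names copied from the diff, n = 2 is just an example value):

layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
n = 2
# "drop": the last n layers are re-initialized instead of loaded from the checkpoint.
assert layer_keys[-1 * n:] == ["layer_5", "layer_6"]
# "freeze": the remaining source layers, presumably the ones kept fixed during fine-tuning.
assert layer_keys[:-1 * n] == ["layer_1", "layer_2", "layer_3", "lstm"]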