@@ -100,6 +100,19 @@ class Params:
     positional_embedding_type: str = "rotary"
     ffn_type: str = "swiglu"
 
+@dataclass
+class MambaParams:
+    d_model: int = None
+    n_layer: int = None
+    vocab_size: int = None
+    seq_len: int = None
+    ssm_cfg: dict = None
+    rms_norm: bool = None
+    residual_in_fp32: bool = None
+    fused_add_norm: bool = None
+    pad_vocab_size_multiple: int = None
+    tie_embeddings: bool = None
+    weight_tying: bool = None
 
 
 def get_pos_embed(args: Params):
     head_dim = args.dim // args.n_heads
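The new MambaParams is meant to be consumed by attribute access rather than dict lookups, which is what the later hunks switch to. A minimal usage sketch, assuming the MambaParams defined above (values are illustrative, not from this PR):

    p = MambaParams(d_model=768, n_layer=24, vocab_size=50277, seq_len=2048,
                    ssm_cfg={}, rms_norm=True, residual_in_fp32=True,
                    fused_add_norm=True, pad_vocab_size_multiple=8,
                    tie_embeddings=False, weight_tying=False)
    assert p.seq_len == 2048  # attribute access replaces params["seq_len"]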
@@ -440,12 +453,19 @@ def create_params(args):
     # If a parameter is not in the model config, we use the args parameter
 
     if "mamba" in args.model:
-        return {
-            "d_model": cfg["d_model"],
-            "n_layer": cfg["n_layer"],
-            "vocab_size": cfg["vocab_size"],
-            "seq_len": cfg["seq_len"],
-        }
+        return MambaParams(
+            d_model=cfg["d_model"],
+            n_layer=cfg["n_layer"],
+            vocab_size=cfg["vocab_size"],
+            seq_len=cfg["seq_len"],
+            ssm_cfg={},
+            rms_norm=cfg["rms_norm"],
+            residual_in_fp32=cfg["residual_in_fp32"],
+            fused_add_norm=cfg["fused_add_norm"],
+            pad_vocab_size_multiple=cfg["pad_vocab_size_multiple"],
+            tie_embeddings=cfg.get("weight_tying", False),
+            weight_tying=cfg.get("weight_tying", False),
+        )
     else:
         return Params(
             dim=cfg["hidden_dim"],
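For reference, the mamba branch above now requires these keys in the model config; a sketch of a matching config dict (keys are the ones read in this hunk, values are examples only):

    cfg = {
        "d_model": 2560,
        "n_layer": 64,
        "vocab_size": 50277,
        "seq_len": 2048,
        "rms_norm": True,
        "residual_in_fp32": True,
        "fused_add_norm": True,
        "pad_vocab_size_multiple": 8,
        "weight_tying": False,  # optional: tie_embeddings and weight_tying both fall back to False
    }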
@@ -482,10 +502,10 @@ def __init__(self, params):
         )
 
         super().__init__()
-        self.seq_len = params.pop("seq_len")
-        self.vocab_size = params["vocab_size"]
+        self.vocab_size = params.vocab_size
+        self.seq_len = params.seq_len
 
-        self.model = MambaLMHeadModel(**params)
+        self.model = MambaLMHeadModel(params)
 
     def reset_parameters(self):
         return
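The constructor now passes the dataclass through unchanged, on the assumption that MambaLMHeadModel accepts a single config-like object and reads fields such as d_model and vocab_size by attribute; the wrapper only keeps vocab_size and seq_len for its own bookkeeping. A hedged sketch of the resulting call pattern (names other than MambaLMHeadModel are assumptions, not from this PR):

    params = create_params(args)               # MambaParams after this change, not a dict
    wrapper = Mamba(params)                    # hypothetical wrapper class from this file
    assert wrapper.seq_len == params.seq_len   # stored via attribute access, no pop() needed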