
Commit ebf2182

update Mamba compatibility
1 parent c7034d8 commit ebf2182

3 files changed (+37, -17 lines)


open_lm/model.py (+29, -9)

@@ -100,6 +100,19 @@ class Params:
     positional_embedding_type: str = "rotary"
     ffn_type: str = "swiglu"

+@dataclass
+class MambaParams:
+    d_model: int = None
+    n_layer: int = None
+    vocab_size: int = None
+    seq_len: int = None
+    ssm_cfg: dict = None
+    rms_norm: bool = None
+    residual_in_fp32: bool = None
+    fused_add_norm: bool = None
+    pad_vocab_size_multiple: int = None
+    tie_embeddings: bool = None
+    weight_tying: bool = None

 def get_pos_embed(args: Params):
     head_dim = args.dim // args.n_heads
@@ -440,12 +453,19 @@ def create_params(args):
     # If a parameter is not in the model config, we use the args parameter

     if "mamba" in args.model:
-        return {
-            "d_model": cfg["d_model"],
-            "n_layer": cfg["n_layer"],
-            "vocab_size": cfg["vocab_size"],
-            "seq_len": cfg["seq_len"],
-        }
+        return MambaParams(
+            d_model=cfg["d_model"],
+            n_layer=cfg["n_layer"],
+            vocab_size=cfg["vocab_size"],
+            seq_len=cfg["seq_len"],
+            ssm_cfg={},
+            rms_norm=cfg["rms_norm"],
+            residual_in_fp32=cfg["residual_in_fp32"],
+            fused_add_norm=cfg["fused_add_norm"],
+            pad_vocab_size_multiple=cfg["pad_vocab_size_multiple"],
+            tie_embeddings=cfg.get("weight_tying", False),
+            weight_tying=cfg.get("weight_tying", False),
+        )
     else:
         return Params(
             dim=cfg["hidden_dim"],
@@ -482,10 +502,10 @@ def __init__(self, params):
         )

         super().__init__()
-        self.seq_len = params.pop("seq_len")
-        self.vocab_size = params["vocab_size"]
+        self.vocab_size = params.vocab_size
+        self.seq_len = params.seq_len

-        self.model = MambaLMHeadModel(**params)
+        self.model = MambaLMHeadModel(params)

     def reset_parameters(self):
         return
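
Taken together, create_params now returns a MambaParams dataclass instead of a plain dict, and the Mamba wrapper passes that object straight to MambaLMHeadModel rather than unpacking keyword arguments. Below is a minimal sketch of how the pieces fit, assuming mamba_ssm is installed and using the repo's mamba_7b.json config; it mirrors the diff above and is not itself part of the commit.

import json

from open_lm.model import MambaParams  # dataclass added in this commit

# Load the stock 7B config shipped with the repo.
with open("open_lm/model_configs/mamba_7b.json") as f:
    cfg = json.load(f)

# Mirror of the new create_params branch for "mamba" models.
params = MambaParams(
    d_model=cfg["d_model"],
    n_layer=cfg["n_layer"],
    vocab_size=cfg["vocab_size"],
    seq_len=cfg["seq_len"],
    ssm_cfg={},
    rms_norm=cfg["rms_norm"],
    residual_in_fp32=cfg["residual_in_fp32"],
    fused_add_norm=cfg["fused_add_norm"],
    pad_vocab_size_multiple=cfg["pad_vocab_size_multiple"],
    tie_embeddings=cfg.get("weight_tying", False),
    weight_tying=cfg.get("weight_tying", False),
)

# The wrapper now hands the whole dataclass to mamba_ssm instead of
# unpacking it as keyword arguments:
# self.model = MambaLMHeadModel(params)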

open_lm/model_configs/mamba_7b.json (+7, -2)

@@ -2,5 +2,10 @@
     "d_model": 4096,
     "n_layer": 64,
     "vocab_size": 50432,
-    "seq_len": 2048
-}
+    "seq_len": 2048,
+    "ssm_cfg": {},
+    "rms_norm": true,
+    "residual_in_fp32": true,
+    "fused_add_norm": true,
+    "pad_vocab_size_multiple": 8
+}
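
The new JSON fields mirror the MambaParams fields above: ssm_cfg stays empty and the boolean flags are forwarded to mamba_ssm unchanged. As a rough illustration of pad_vocab_size_multiple (mamba_ssm rounds the embedding table up to a multiple of this value), here is a small sketch; pad_vocab is a hypothetical helper written for this example, not part of open_lm.

def pad_vocab(vocab_size: int, multiple: int) -> int:
    # Round vocab_size up to the next multiple, as done when building the
    # embedding and output layers; a no-op when it already divides evenly.
    if multiple and vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

assert pad_vocab(50432, 8) == 50432  # the 7B config is already a multiple of 8
assert pad_vocab(50257, 8) == 50264  # a GPT-2-sized vocab would be padded up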

open_lm/utils/transformers/hf_model.py (+1, -6)

@@ -105,7 +105,6 @@ def forward(
             use_cache=use_cache,
             attention_mask=attention_mask,
         )
-
         loss = None
         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
@@ -115,11 +114,7 @@ def forward(
             shift_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

-        output = CausalLMOutputWithPast(
-            logits=logits,
-            past_key_values=past_key_values,
-            loss=loss
-        )
+        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss)
         return output

     def prepare_inputs_for_generation(
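
For reference, the shift_logits / shift_labels lines in the context above compute the standard next-token cross-entropy loss. The standalone sketch below uses made-up shapes to show what that block does; it is an illustration, not code copied from hf_model.py.

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 8, 50432          # illustrative sizes only
logits = torch.randn(batch, seq_len, vocab)  # model output
labels = torch.randint(0, vocab, (batch, seq_len))

# Predict token t+1 from positions 0..t-1, then flatten for the loss.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)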
