Skip to content

Commit 07661b4

Browse files
committed
add mamba dataclass args
1 parent 9b958b2 commit 07661b4

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

open_lm/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class Params:
100100
positional_embedding_type: str = "rotary"
101101
ffn_type: str = "swiglu"
102102

103+
103104
@dataclass
104105
class MambaParams:
105106
d_model: int = 2560
@@ -114,6 +115,7 @@ class MambaParams:
114115
tie_embeddings: bool = True
115116
weight_tying: bool = False
116117

118+
117119
def get_pos_embed(args: Params):
118120
head_dim = args.dim // args.n_heads
119121
if args.positional_embedding_type == "rotary":

open_lm/utils/transformers/hf_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ def __init__(
4040

4141
def set_params(self, params: Params):
4242
self.tie_word_embeddings = params.weight_tying
43-
for field in fields(Params):
43+
for field in fields(params):
4444
setattr(self, field.name, getattr(params, field.name))

0 commit comments

Comments
 (0)