
Commit 9b958b2

MambaParams default values
1 parent ebf2182 commit 9b958b2

1 file changed: +12 −12 lines changed

open_lm/model.py

@@ -3,7 +3,7 @@
 import re
 from copy import deepcopy
 from pathlib import Path
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Callable
 
 import torch
@@ -102,17 +102,17 @@ class Params:
 
 @dataclass
 class MambaParams:
-    d_model: int = None
-    n_layer: int = None
-    vocab_size: int = None
-    seq_len: int = None
-    ssm_cfg: dict = None
-    rms_norm: bool = None
-    residual_in_fp32: bool = None
-    fused_add_norm: bool = None
-    pad_vocab_size_multiple: int = None
-    tie_embeddings: bool = None
-    weight_tying: bool = None
+    d_model: int = 2560
+    n_layer: int = 64
+    vocab_size: int = 50277
+    seq_len: int = 2048
+    ssm_cfg: dict = field(default_factory=dict)
+    rms_norm: bool = True
+    residual_in_fp32: bool = True
+    fused_add_norm: bool = True
+    pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
+    weight_tying: bool = False
 
 def get_pos_embed(args: Params):
     head_dim = args.dim // args.n_heads