Commit 9b958b2 (parent: ebf2182)
File tree: 1 file changed (+12, −12 lines)
@@ -3,7 +3,7 @@
 import re
 from copy import deepcopy
 from pathlib import Path
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Callable
 
 import torch
@@ -102,17 +102,17 @@ class Params:
 
 @dataclass
 class MambaParams:
-    d_model: int = None
-    n_layer: int = None
-    vocab_size: int = None
-    seq_len: int = None
-    ssm_cfg: dict = None
-    rms_norm: bool = None
-    residual_in_fp32: bool = None
-    fused_add_norm: bool = None
-    pad_vocab_size_multiple: int = None
-    tie_embeddings: bool = None
-    weight_tying: bool = None
+    d_model: int = 2560
+    n_layer: int = 64
+    vocab_size: int = 50277
+    seq_len: int = 2048
+    ssm_cfg: dict = field(default_factory=dict)
+    rms_norm: bool = True
+    residual_in_fp32: bool = True
+    fused_add_norm: bool = True
+    pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
+    weight_tying: bool = False
 
 def get_pos_embed(args: Params):
     head_dim = args.dim // args.n_heads
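For reference, a minimal runnable sketch (not part of the commit; the instance names below are illustrative) of why ssm_cfg needs field(default_factory=dict) rather than a plain mutable default, and how the new concrete defaults behave:

    from dataclasses import dataclass, field

    @dataclass
    class MambaParams:
        d_model: int = 2560
        n_layer: int = 64
        vocab_size: int = 50277
        seq_len: int = 2048
        # A mutable default such as `ssm_cfg: dict = {}` raises ValueError
        # when the dataclass is created; dataclasses require default_factory
        # for mutable types so each instance gets its own dict.
        ssm_cfg: dict = field(default_factory=dict)
        rms_norm: bool = True
        residual_in_fp32: bool = True
        fused_add_norm: bool = True
        pad_vocab_size_multiple: int = 8
        tie_embeddings: bool = True
        weight_tying: bool = False

    # With real defaults, a bare construction is now meaningful,
    # and any subset of fields can be overridden.
    params = MambaParams()                        # all defaults
    small = MambaParams(d_model=768, n_layer=24)  # hypothetical override
    assert small.ssm_cfg == {} and small.ssm_cfg is not params.ssm_cfg

The numeric defaults (d_model=2560, n_layer=64, vocab_size=50277) appear to correspond to the mamba-2.8b configuration, so MambaParams() now describes a concrete model rather than an all-None placeholder.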