Skip to content

Commit 682d365

Browse files
committed
add custom attn impl
1 parent 9a8f29c commit 682d365

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

open_diloco/train_fsdp.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
 class Config(BaseConfig):
     llama_config: str | ModelConfig = "open_diloco/configs/config_1b.json"
     torch_compile: bool = True
-    attn_implementation: str = "sdpa"
+    attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa"
     # Data
     dataset_name_or_path: str = "allenai/c4"
     seq_length: int = 1024
@@ -184,11 +184,13 @@ def tokenize_function(data):
 def get_model(config: Config) -> GPT:
     # Load model
     if isinstance(config.llama_config, ModelConfig):
-        return GPT(config.llama_config)
+        llama_config = config.llama_config
     else:
         with open(config.llama_config) as f:
             llama_config = ModelConfig(**json.load(f))
-    return GPT(llama_config)
+
+    llama_config.attention_impl = config.attention_impl
+    return GPT(llama_config)


 def train(config: Config):

0 commit comments

Comments (0)