-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_trusts.sh
73 lines (61 loc) · 2.07 KB
/
train_trusts.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/bin/bash
# Train Llama models with HadamardTrustQuantizer-quantized weights and
# activations through torchrun, one run per quantizer "trust" value.
#
# Exit immediately if a command exits with a non-zero status. pipefail makes a
# pipeline fail when any stage fails, not only when its last stage fails.
set -eo pipefail
# Set common environment variables
export VOCAB_SIZE=32000 # 50304
export BATCH_SIZE=64
export ACC_STEPS=8
export SEQUENCE_LENGTH=512
export DATASET="c4" # "slimpajama"
# # 30M
# export N_LAYER=6
# export N_EMBD=640
# export N_HEAD=5
# export LR=0.0012
# export TOKENS=3000000000 # 3B
# export MODEL_SIZE_PREFIX="30M"
# 50M
export N_LAYER=7
export N_EMBD=768
export N_HEAD=6
export LR=0.0012
export TOKENS=5000000000 # 5B
export MODEL_SIZE_PREFIX="50M"
# Quantization configuration
export W_QUANT="HadamardTrustQuantizer"
export A_QUANT="HadamardTrustQuantizer"
export BITS=1
# Number of iterations = total token budget / tokens per iteration. The formula
# assumes one iteration consumes BATCH_SIZE * ACC_STEPS * SEQUENCE_LENGTH tokens.
# Integer division drops any partial final iteration.
# NOTE(review): if BATCH_SIZE is per GPU in src/main.py, the real number of
# tokens seen is NUM_GPUS times larger than TOKENS. Confirm this.
export ITERATIONS=$((TOKENS / (BATCH_SIZE * ACC_STEPS * SEQUENCE_LENGTH)))
# Warm up over the first 10% of iterations.
export WARMUP_STEPS=$((ITERATIONS / 10))
# NOTE(review): CLIP_SCALE_VALUES is never read in this script, and it lists
# 1.05 twice. It looks like a leftover from another sweep; confirm before
# removing it.
CLIP_SCALE_VALUES=(1.05 1.05 1.15 1.35 1.45 1.50)

#######################################
# Print the number of GPUs this job should launch workers on.
# Globals:   CUDA_VISIBLE_DEVICES (read, if set)
# Outputs:   the GPU count on stdout
# Returns:   non-zero if nvidia-smi is on PATH but fails
#######################################
detect_num_gpus() {
  # nvidia-smi lists every GPU on the node regardless of CUDA_VISIBLE_DEVICES.
  # When the visible set is restricted (even to nothing, via an empty value),
  # count its comma-separated entries instead.
  if [[ -n "${CUDA_VISIBLE_DEVICES+x}" ]]; then
    local -a visible=()
    IFS=',' read -r -a visible <<< "${CUDA_VISIBLE_DEVICES}"
    printf '%s\n' "${#visible[@]}"
    return 0
  fi
  if command -v nvidia-smi >/dev/null 2>&1; then
    # One CSV line per GPU. Under pipefail, an nvidia-smi failure becomes this
    # function's return status.
    nvidia-smi --query-gpu=name --format=csv,noheader | wc -l
    return
  fi
  # No nvidia-smi on PATH: report zero GPUs.
  printf '0\n'
}

if ! NUM_GPUS=$(detect_num_gpus); then
  printf 'error: nvidia-smi failed; cannot determine the number of GPUs\n' >&2
  exit 1
fi
# Loop through trust values, launching one distributed training run per value.
#
# Trust values come from the script's positional arguments or, failing that,
# from a whitespace-separated TRUST_VALUES environment variable:
#   ./train_trusts.sh 1.0 1.5 2.0
#   TRUST_VALUES="1.0 1.5 2.0" ./train_trusts.sh
if (( $# > 0 )); then
  TRUST_VALUES=("$@")
elif (( ${#TRUST_VALUES[@]} <= 1 )); then
  # Bash cannot export arrays, so a value inherited from the environment arrives
  # as a single string; split it on whitespace. A multi-element array assigned
  # earlier in this script is left untouched.
  read -r -a TRUST_VALUES <<< "${TRUST_VALUES:-}"
fi
# Fail loudly instead of iterating over an empty list, which would train
# nothing and still exit 0.
if (( ${#TRUST_VALUES[@]} == 0 )); then
  printf 'error: no trust values given; pass them as arguments or set TRUST_VALUES="v1 v2 ..."\n' >&2
  exit 1
fi

# The values are spliced verbatim into the JSON kwargs below. Reject anything
# that is not a JSON number up front, before any (long-running) job starts.
json_number_re='^-?(0|[1-9][0-9]*)([.][0-9]+)?([eE][+-]?[0-9]+)?$'
for TRUST in "${TRUST_VALUES[@]}"; do
  if [[ ! "$TRUST" =~ $json_number_re ]]; then
    printf 'error: trust value %q is not a valid JSON number\n' "$TRUST" >&2
    exit 1
  fi
done

# With zero GPUs there is nothing for torchrun to launch workers on.
if (( ${NUM_GPUS:-0} <= 0 )); then
  printf 'error: NUM_GPUS=%q; no GPUs available to train on\n' "${NUM_GPUS:-}" >&2
  exit 1
fi

for TRUST in "${TRUST_VALUES[@]}"; do
  printf 'Running with TRUST=%s\n' "${TRUST}"
  # Update quantization kwargs with current trust value
  export W_QUANT_KWARGS="{\"bits\": ${BITS}, \"trust\": ${TRUST}}"
  export A_QUANT_KWARGS="{\"bits\": ${BITS}, \"trust\": ${TRUST}}"
  WANDB_PREFIX="UNTIED-${MODEL_SIZE_PREFIX}-${W_QUANT}@${BITS}:${A_QUANT}@${BITS}-${DATASET}-TRUST-${TRUST}"
  # With `set -e` in effect, a failed run aborts the rest of the sweep.
  torchrun --nproc_per_node="${NUM_GPUS}" ./src/main.py \
    --distributed-backend nccl \
    --dataset "${DATASET}" \
    --model llama \
    --compile \
    --latest-ckpt-interval 10000 \
    --acc-steps "${ACC_STEPS}" \
    --batch-size "${BATCH_SIZE}" \
    --wandb \
    --wandb-project "llm-baselines" \
    --wandb-run-prefix "${WANDB_PREFIX}" \
    --n-layer "${N_LAYER}" \
    --n-embd "${N_EMBD}" \
    --n-head "${N_HEAD}" \
    --warmup-steps "${WARMUP_STEPS}" \
    --iterations "${ITERATIONS}" \
    --lr "${LR}" \
    --w-quant "${W_QUANT}" \
    --w-quant-kwargs "${W_QUANT_KWARGS}" \
    --a-quant "${A_QUANT}" \
    --a-quant-kwargs "${A_QUANT_KWARGS}"
done