This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Update PaLM by integrating Gemini #49

Status: Open. Wants to merge 1 commit into base: main.
8 changes: 7 additions & 1 deletion .gitignore
@@ -141,4 +141,10 @@ docs/.build

# pytorch checkpoint
*.pt
*.ckpt
*.ckpt

# token
token/

# dataset
wiki_dataset/
5 changes: 4 additions & 1 deletion configs/palm_30b_2d.py
@@ -4,6 +4,9 @@
BATCH_SIZE = 2
NUM_EPOCHS = 10
WARMUP_EPOCHS = 1
TPDEGREE = 2
USE_SHARD_INIT = False
placement = 'cpu'

parallel = dict(
tensor=dict(mode="2d", size=4),
@@ -27,4 +30,4 @@

clip_grad_norm = 1.0

LOG_PATH = "./palm_30b_2d/"
LOG_PATH = "./palm_30b_2d_new/"
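
For context, a small standalone sketch (not part of the diff; the values are the ones added above) of how these new keys are consumed by the updated train.py below: ShardSpec([-1], [n]) asks for each parameter's last dimension to be sharded across n tensor-parallel ranks, and placement selects the Gemini placement policy.

from colossalai.tensor import ProcessGroup, ShardSpec

TPDEGREE = 2
USE_SHARD_INIT = False
placement = 'cpu'  # Gemini placement policy; the helper in train.py defaults to 'auto'

# One process group with a tensor-parallel degree of TPDEGREE.
pg = ProcessGroup(tp_degree=TPDEGREE)

# Optionally shard each parameter's last dimension across the TP ranks at init time.
dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None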
14 changes: 9 additions & 5 deletions model/parallel_palm.py
@@ -5,7 +5,7 @@
from colossalai.core import global_context as gpc
from colossalai.nn import CheckpointModule
from einops import rearrange
from torch import dtype, einsum
from torch import dtype, einsum, matmul

from model.palm_utils import RotaryEmbedding, SwiGLU, apply_rotary_pos_emb
from model.parallel_utils import (
@@ -132,11 +132,13 @@ def _forward(self, x):

# calculate similarity
if self.multi_query:
sim = einsum("b h s d, b j d -> b h s j", q, k)
#sim = einsum("b h s d, b j d -> b h s j", q, k)
sim = matmul(q, k.transpose(1,2))
else:
# s and n here both refer to the sequence length
# n is used only because the two sequence axes need distinct einsum labels
sim = einsum("b h s d, b h n d -> b h s n", q, k)
#sim = einsum("b h s d, b h n d -> b h s n", q, k)
sim = matmul(q, k.transpose(2,3))

# apply causal mask
causal_mask = self.get_mask(seq_length, device)
Expand All @@ -148,9 +150,11 @@ def _forward(self, x):

# aggregate values
if self.multi_query:
Contributor commented: Then the if check is no longer needed; the two branches have exactly the same logic.

attn_out = einsum("b h i j, b j d -> b h i d", attn, v)
#attn_out = einsum("b h i j, b j d -> b h i d", attn, v)
attn_out = matmul(attn, v)
else:
attn_out = einsum("b h s n, b h n d -> b h s d", attn, v)
#attn_out = einsum("b h s n, b h n d -> b h s d", attn, v)
attn_out = matmul(attn, v)

# merge heads
attn_out = rearrange(attn_out, "b h s d -> b s (h d)")
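
For reference, a standalone sanity check (not part of the PR; shapes are hypothetical, with b deliberately different from h) that the matmul forms above match the einsum contractions. In the multi-query branch the shared key/value tensors appear to need an explicit head axis (an unsqueeze) for matmul broadcasting to line up, something the einsum handled implicitly; with that axis in place, the reviewer's point above that both value-aggregation branches collapse into the same call holds.

import torch

# Hypothetical sizes; b is deliberately different from h so broadcasting mistakes surface.
b, h, s, j, d = 2, 8, 16, 16, 64

# Standard multi-head case: k carries its own head dimension.
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, j, d)
sim_einsum = torch.einsum("b h s d, b h n d -> b h s n", q, k)
sim_matmul = torch.matmul(q, k.transpose(2, 3))                      # (b, h, s, d) @ (b, h, d, n)
assert torch.allclose(sim_einsum, sim_matmul, atol=1e-4)

# Multi-query case: k is shared across heads, so it needs an explicit head axis
# for matmul broadcasting; the einsum resolved this implicitly.
k_mq = torch.randn(b, j, d)
sim_einsum_mq = torch.einsum("b h s d, b j d -> b h s j", q, k_mq)
sim_matmul_mq = torch.matmul(q, k_mq.transpose(1, 2).unsqueeze(1))   # (b, h, s, d) @ (b, 1, d, j)
assert torch.allclose(sim_einsum_mq, sim_matmul_mq, atol=1e-4)

# Value aggregation follows the same pattern, which is why the reviewer notes
# the two branches end up as the same matmul call.
attn = torch.softmax(sim_matmul_mq, dim=-1)
v = torch.randn(b, j, d)
out_einsum = torch.einsum("b h i j, b j d -> b h i d", attn, v)
out_matmul = torch.matmul(attn, v.unsqueeze(1))                      # broadcast v over heads
assert torch.allclose(out_einsum, out_matmul, atol=1e-4)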
3 changes: 3 additions & 0 deletions run.sh
@@ -0,0 +1,3 @@
export TOKENIZER=./token
export DATA=./wiki_dataset
env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --from_torch --config ./configs/palm_30b_2d.py
198 changes: 116 additions & 82 deletions train.py
@@ -1,6 +1,9 @@
from asyncio.log import logger
import contextlib
import os
from packaging import version
from functools import partial
from time import time

import colossalai
import torch
@@ -9,8 +12,12 @@
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.context import ParallelMode
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

from data import build_data
from model import build_loss, build_model
@@ -24,6 +31,35 @@ def limit_cuda_memory(size_in_GB: int):
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
logger = get_dist_logger()
logger.info("Using {} GB of GPU memory".format(size_in_GB))


def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)



# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
    else:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model

def train_palm():
assert torch.cuda.is_available()
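
To make the throughput helper added above concrete, a small worked example (the numbers are hypothetical, not from the PR). The constant 8 roughly corresponds to the FLOPs per parameter per token for a forward plus backward pass with activation recomputation.

from functools import partial

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # Same formula as above: ~8 FLOPs per parameter per token, reported in TFLOPS.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# Hypothetical numbers: a 1.0 B parameter model, batch size 2, sequence length 2048.
numel, batch_size, seq_len = 1.0e9, 2, 2048
get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)

print(f"{get_tflops_func(1.0):.2f} TFLOPS at 1.0 s/step")  # ~32.77
print(f"{get_tflops_func(0.5):.2f} TFLOPS at 0.5 s/step")  # ~65.54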
@@ -55,47 +91,35 @@ def train_palm():
assert hasattr(gpc.config, "NUM_EPOCHS"), "Please provide NUM_EPOCHS in your configuration"

use_zero = hasattr(gpc.config, "zero")
ctx = contextlib.nullcontext()
#ctx = contextlib.nullcontext()
tflop = 0
default_pg = ProcessGroup(tp_degree=gpc.config.TPDEGREE)
Contributor commented: Stop using gpc.config; it is the configuration file of the old interface. tp_degree can be set through args.

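As an illustration of this suggestion (a hypothetical snippet, not part of the PR, assuming ColossalAI's get_default_parser helper): the tensor-parallel degree and placement policy could come from command-line arguments rather than gpc.config, for example:

import colossalai
from colossalai.tensor import ProcessGroup

# Hypothetical CLI handling; get_default_parser() provides the standard launch
# arguments (--config, --host, --port, ...), and the two flags below are new.
parser = colossalai.get_default_parser()
parser.add_argument("--tp_degree", type=int, default=2, help="tensor parallel degree")
parser.add_argument("--placement", type=str, default="cpu", help="Gemini placement policy")
args = parser.parse_args()

default_pg = ProcessGroup(tp_degree=args.tp_degree)
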
default_dist_spec = ShardSpec([-1], [gpc.config.TPDEGREE]) if gpc.config.USE_SHARD_INIT else None
if use_zero:
ctx = ZeroInitContext(
target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True,
)
# ctx = ZeroInitContext(
# target_device=torch.cuda.current_device(),
# shard_strategy=gpc.config.zero.model_config.shard_strategy,
# shard_param=True,
# )
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
model = build_model()
model = AutoregressiveWrapper(model)



logger = get_dist_logger()
if hasattr(gpc.config, "LOG_PATH"):
log_path = gpc.config.LOG_PATH
logger.log_to_file(log_path)
# if hasattr(gpc.config, "LOG_PATH"):
# log_path = gpc.config.LOG_PATH
# logger.log_to_file(log_path)

with ctx:
model = build_model()
model = AutoregressiveWrapper(model)
# with ctx:
# model = build_model()
# model = AutoregressiveWrapper(model)

seq_len=gpc.config.SEQ_LENGTH
batch_size=gpc.config.BATCH_SIZE

# numel is a model elem in a DP process.
numel = 0
if use_zero:
numel = ctx.model_numel_tensor.item()
else:
numel = calc_local_model_size(model)

tflop = numel * batch_size * seq_len \
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

if numel < 1e9:
msg = f"{numel / 1e6:.3f} M"
else:
msg = f"{numel / 1e9:.3f} B"

model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3

logger.info("Model is built.", ranks=[0])
logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])

criterion = build_loss()
logger.info("Loss is built.", ranks=[0])

@@ -110,66 +134,76 @@

# We use a fast CPU Adam here
# If we set cpu_offload=True in optimizer_config
use_cpu_adam = (
hasattr(gpc.config, "zero")
and hasattr(gpc.config.zero, "model_config")
and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
)
optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
# use_cpu_adam = (
# hasattr(gpc.config, "zero")
# and hasattr(gpc.config.zero, "model_config")
# and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
# )
# optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
# optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
pg = default_pg
# Tensor Parallelism (TP)
#tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, gpc.config.placement)

# build optimizer
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# total_steps = gpc.config.NUM_EPOCHS * len(train_dataloader)
# warmup_steps = getattr(gpc.config, "WARMUP_EPOCHS", 0) * len(train_dataloader)
# lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)

logger.info("Optimizer is built.", ranks=[0])

engine, train_dataloader, _, _ = colossalai.initialize(
model=model,
optimizer=optimizer,
criterion=criterion,
# lr_scheduler=lr_scheduler,
train_dataloader=train_dataloader,
)
#logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

def batch_data_process_func(batch_data):
data = batch_data["input_ids"]
labels = batch_data["labels"]
return data, labels

engine.schedule.batch_data_process_func = batch_data_process_func

timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)

hook_list = [
hooks.LogMetricByEpochHook(logger=logger),
hooks.LogMetricByStepHook(),
hooks.LossHook(),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step = tflop),
# hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.LogMemoryByEpochHook(logger),
# hooks.SaveCheckpointHook(checkpoint_dir="./palm.ckpt", model=model),
]

logger.info("Training start.", ranks=[0])
trainer.fit(
train_dataloader=train_dataloader,
epochs=gpc.config.NUM_EPOCHS,
max_steps=20,
hooks=hook_list,
return_output_label=False,
display_progress=True,
)
# numel is the number of model elements held by one DP process.
numel = 0
if use_zero:
#numel = ctx.model_numel_tensor.item()
numel = sum([p.numel() for p in model.parameters()])
else:
numel = calc_local_model_size(model)

tflop = numel * batch_size * seq_len \
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

opt_state = engine.optimizer.state_dict()
if isinstance(engine.optimizer, colossalai.amp.naive_amp.NaiveAMPOptimizer):
opt_state = opt_state['optimizer']
os_mem = calc_mem(opt_state)
logger.info(f"{engine.optimizer.__class__.__name__} state memory usage = {os_mem / 1024**2:.3f} MB", ranks=[0])
get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)

gpc.destroy()
logger.info("Training complete.", ranks=[0])
if numel < 1e9:
msg = f"{numel / 1e6:.3f} M"
else:
msg = f"{numel / 1e9:.3f} B"

model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3

logger.info("Model is built.", ranks=[0])
logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])

torch.cuda.synchronize()
model.train()
for n in range(gpc.config.NUM_EPOCHS):
for step, batch in enumerate(train_dataloader):
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda()
optimizer.zero_grad()
start = time()
outputs = model(input_ids)
loss = criterion(outputs, labels)
#logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Forward '), ranks=[0])

optimizer.backward(loss)

#logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Backward '), ranks=[0])
optimizer.step()
#logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{gpc.config.NUM_EPOCHS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])

torch.cuda.synchronize()


if __name__ == "__main__":
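
Taken together, a condensed sketch of the flow this PR moves train.py to (toy model and hypothetical hyper-parameters; it assumes a torchrun launch and the same ColossalAI 0.1.10+ APIs the diff imports):

import colossalai
import torch
import torch.nn as nn
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Hypothetical minimal launch under torchrun; the repo's train.py drives its own
# launch via the --from_torch flag and config file shown in run.sh.
colossalai.launch_from_torch(config={})

pg = ProcessGroup(tp_degree=1)                      # hypothetical: no tensor parallelism
with ColoInitContext(device='cpu', default_pg=pg):  # parameters are built as ColoTensors
    model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

# Gemini + ZeRO DDP wrapping, the same call gemini_zero_dpp makes on versions > 0.1.10.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy='cpu',
                  pin_memory=True,
                  search_range_mb=32)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

criterion = nn.MSELoss()
x = torch.randn(4, 512).cuda()
y = torch.randn(4, 512).cuda()

model.train()
optimizer.zero_grad()
loss = criterion(model(x), y)
optimizer.backward(loss)   # the Gemini optimizer drives the backward pass itself
optimizer.step()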