This repository has been archived by the owner on Oct 16, 2023. It is now read-only.
update palm by combining gemini #49
Open
ZijianYY wants to merge 1 commit into hpcaitech:main from ZijianYY:colossalai (base: main)
Changes from all commits
@@ -141,4 +141,10 @@ docs/.build

# pytorch checkpoint
*.pt
*.ckpt

# token
token/

# dataset
wiki_dataset/
@@ -0,0 +1,3 @@
export TOKENIZER=./token
export DATA=./wiki_dataset
env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --from_torch --config ./configs/palm_30b_2d.py
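The two exported variables point the launch at a local tokenizer directory and dataset directory. The diff does not show how they are consumed, so the following is only a minimal sketch, assuming the Python side reads them through os.environ under these exact names; the fallback defaults are illustrative, not taken from this repository.

    import os

    # Resolve the paths exported by the launch script above.
    # The defaults are hypothetical fallbacks for standalone testing.
    tokenizer_path = os.environ.get("TOKENIZER", "./token")
    data_path = os.environ.get("DATA", "./wiki_dataset")

    print(f"tokenizer: {tokenizer_path}, dataset: {data_path}")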
@@ -1,6 +1,9 @@
from asyncio.log import logger
import contextlib
import os
from packaging import version
from functools import partial
from time import time

import colossalai
import torch

@@ -9,8 +12,12 @@
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.context import ParallelMode
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

from data import build_data
from model import build_loss, build_model

@@ -24,6 +31,35 @@ def limit_cuda_memory(size_in_GB: int):
    colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
    logger = get_dist_logger()
    logger.info("Using {} GB of GPU memory".format(size_in_GB))


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse("0.1.9") <= version.parse(cai_version) <= version.parse("0.1.10"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        # Build the chunk manager first; the Gemini manager takes it as an argument.
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
    else:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model

def train_palm():
    assert torch.cuda.is_available()

@@ -55,47 +91,35 @@ def train_palm():
    assert hasattr(gpc.config, "NUM_EPOCHS"), "Please provide NUM_EPOCHS in your configuration"

    use_zero = hasattr(gpc.config, "zero")
    ctx = contextlib.nullcontext()
    #ctx = contextlib.nullcontext()
    tflop = 0
    default_pg = ProcessGroup(tp_degree=gpc.config.TPDEGREE)
Review comment: Stop using gpc.config here; it is the configuration file of the old interface. tp_degree can be set through args instead.
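Following the reviewer's suggestion, here is a minimal sketch of taking the tensor-parallel degree from the command line instead of gpc.config.TPDEGREE. It is not part of this PR; the --tp_degree flag name and its default are hypothetical.

    import argparse

    from colossalai.tensor import ProcessGroup

    def parse_args():
        parser = argparse.ArgumentParser()
        # Hypothetical flag; the PR as written still reads gpc.config.TPDEGREE.
        parser.add_argument("--tp_degree", type=int, default=1,
                            help="tensor parallel degree")
        return parser.parse_args()

    args = parse_args()
    default_pg = ProcessGroup(tp_degree=args.tp_degree)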
    default_dist_spec = ShardSpec([-1], [gpc.config.TPDEGREE]) if gpc.config.USE_SHARD_INIT else None
    if use_zero:
        ctx = ZeroInitContext(
            target_device=torch.cuda.current_device(),
            shard_strategy=gpc.config.zero.model_config.shard_strategy,
            shard_param=True,
        )
    # ctx = ZeroInitContext(
    #     target_device=torch.cuda.current_device(),
    #     shard_strategy=gpc.config.zero.model_config.shard_strategy,
    #     shard_param=True,
    # )
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
    with ctx:
        model = build_model()
        model = AutoregressiveWrapper(model)

    logger = get_dist_logger()
    if hasattr(gpc.config, "LOG_PATH"):
        log_path = gpc.config.LOG_PATH
        logger.log_to_file(log_path)
    # if hasattr(gpc.config, "LOG_PATH"):
    #     log_path = gpc.config.LOG_PATH
    #     logger.log_to_file(log_path)

    with ctx:
        model = build_model()
        model = AutoregressiveWrapper(model)
    # with ctx:
    #     model = build_model()
    #     model = AutoregressiveWrapper(model)

    seq_len = gpc.config.SEQ_LENGTH
    batch_size = gpc.config.BATCH_SIZE

    # numel is a model elem in a DP process.
    numel = 0
    if use_zero:
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    tflop = numel * batch_size * seq_len \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

    if numel < 1e9:
        msg = f"{numel / 1e6:.3f} M"
    else:
        msg = f"{numel / 1e9:.3f} B"

    model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3

    logger.info("Model is built.", ranks=[0])
    logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])

    criterion = build_loss()
    logger.info("Loss is built.", ranks=[0])

@@ -110,66 +134,76 @@ def train_palm():

    # We use a fast CPU Adam here
    # If we set cpu_offload=True in optimizer_config
    use_cpu_adam = (
        hasattr(gpc.config, "zero")
        and hasattr(gpc.config.zero, "model_config")
        and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
    )
    optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
    optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
    # use_cpu_adam = (
    #     hasattr(gpc.config, "zero")
    #     and hasattr(gpc.config.zero, "model_config")
    #     and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
    # )
    # optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
    # optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
    pg = default_pg
    # Tensor Parallelism (TP)
    #tensor_parallelize(model, pg)
    # Gemini + ZeRO DP, Note it must be used after TP
    model = gemini_zero_dpp(model, pg, gpc.config.placement)

    # build optimizer
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

    # total_steps = gpc.config.NUM_EPOCHS * len(train_dataloader)
    # warmup_steps = getattr(gpc.config, "WARMUP_EPOCHS", 0) * len(train_dataloader)
    # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)

    logger.info("Optimizer is built.", ranks=[0])

    engine, train_dataloader, _, _ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        # lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
    )
    #logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    def batch_data_process_func(batch_data):
        data = batch_data["input_ids"]
        labels = batch_data["labels"]
        return data, labels

    engine.schedule.batch_data_process_func = batch_data_process_func

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LogMetricByEpochHook(logger=logger),
        hooks.LogMetricByStepHook(),
        hooks.LossHook(),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        # hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        hooks.LogMemoryByEpochHook(logger),
        # hooks.SaveCheckpointHook(checkpoint_dir="./palm.ckpt", model=model),
    ]

    logger.info("Training start.", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        max_steps=20,
        hooks=hook_list,
        return_output_label=False,
        display_progress=True,
    )
    #numel is a model elem in a DP process.
    numel = 0
    if use_zero:
        #numel = ctx.model_numel_tensor.item()
        numel = sum([p.numel() for p in model.parameters()])
    else:
        numel = calc_local_model_size(model)

    tflop = numel * batch_size * seq_len \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

    opt_state = engine.optimizer.state_dict()
    if isinstance(engine.optimizer, colossalai.amp.naive_amp.NaiveAMPOptimizer):
        opt_state = opt_state['optimizer']
    os_mem = calc_mem(opt_state)
    logger.info(f"{engine.optimizer.__class__.__name__} state memory usage = {os_mem / 1024**2:.3f} MB", ranks=[0])
    get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)

    gpc.destroy()
    logger.info("Training complete.", ranks=[0])
    if numel < 1e9:
        msg = f"{numel / 1e6:.3f} M"
    else:
        msg = f"{numel / 1e9:.3f} B"

    model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3

    logger.info("Model is built.", ranks=[0])
    logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])

    torch.cuda.synchronize()
    model.train()
    for n in range(gpc.config.NUM_EPOCHS):
        for step, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].cuda()
            labels = batch["labels"].cuda()
            optimizer.zero_grad()
            start = time()
            outputs = model(input_ids)
            loss = criterion(outputs, labels)
            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Forward '), ranks=[0])

            optimizer.backward(loss)

            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Backward '), ranks=[0])
            optimizer.step()
            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Optimizer step '), ranks=[0])
            step_time = time() - start
            logger.info(
                f'[{n+1}/{gpc.config.NUM_EPOCHS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
                ranks=[0])

    torch.cuda.synchronize()


if __name__ == "__main__":
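The get_tflops helper added in this diff estimates per-GPU throughput as model_numel * batch_size * seq_len * 8 / 1e12 / step_time. Below is a minimal standalone sketch of the arithmetic; the parameter count, batch size, sequence length, and step time are made-up illustration values, not measurements from this PR.

    # Hypothetical numbers for illustration only.
    model_numel = 8.0e9      # 8 B parameters held by this DP process
    batch_size = 8
    seq_len = 2048
    step_time = 5.0          # seconds per training step

    def get_tflops(model_numel, batch_size, seq_len, step_time):
        # Same formula as in train.py: the constant 8 acts as the assumed
        # FLOPs per parameter per token, and the result is reported in TFLOPS.
        return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

    print(f"{get_tflops(model_numel, batch_size, seq_len, step_time):.1f} TFLOPS")
    # 8e9 * 8 * 2048 * 8 / 1e12 / 5 ≈ 209.7 TFLOPS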
Review comment: Then the if check is no longer needed; the two branches have exactly the same logic.
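This comment most likely refers to the new numel computation in train.py, where both the use_zero branch and the calc_local_model_size branch end up counting the parameters owned by the local process. Assuming that reading, a minimal sketch of the collapsed version follows; the equivalence of the two branches is the reviewer's claim, not something verified here.

    # Before: two branches that compute the same quantity.
    # if use_zero:
    #     numel = sum([p.numel() for p in model.parameters()])
    # else:
    #     numel = calc_local_model_size(model)

    # After: a single expression suffices.
    numel = sum(p.numel() for p in model.parameters())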