diff --git a/.gitignore b/.gitignore
index 9f6bcdc..0fef89a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,4 +141,10 @@ docs/.build
 
 # pytorch checkpoint
 *.pt
-*.ckpt
\ No newline at end of file
+*.ckpt
+
+# token
+token/
+
+# dataset
+wiki_dataset/
\ No newline at end of file
diff --git a/configs/palm_30b_2d.py b/configs/palm_30b_2d.py
index c61f09d..abd818e 100644
--- a/configs/palm_30b_2d.py
+++ b/configs/palm_30b_2d.py
@@ -4,6 +4,9 @@
 BATCH_SIZE = 2
 NUM_EPOCHS = 10
 WARMUP_EPOCHS = 1
+TPDEGREE = 2
+USE_SHARD_INIT = False
+placement = 'cpu'
 
 parallel = dict(
     tensor=dict(mode="2d", size=4),
@@ -27,4 +30,4 @@
 clip_grad_norm = 1.0
 
 
-LOG_PATH = "./palm_30b_2d/"
+LOG_PATH = "./palm_30b_2d_new/"
diff --git a/model/parallel_palm.py b/model/parallel_palm.py
index 6897ef4..14c15b7 100644
--- a/model/parallel_palm.py
+++ b/model/parallel_palm.py
@@ -5,7 +5,7 @@
 from colossalai.core import global_context as gpc
 from colossalai.nn import CheckpointModule
 from einops import rearrange
-from torch import dtype, einsum
+from torch import dtype, einsum, matmul
 
 from model.palm_utils import RotaryEmbedding, SwiGLU, apply_rotary_pos_emb
 from model.parallel_utils import (
@@ -132,11 +132,13 @@ def _forward(self, x):
 
         # calculate similarity
         if self.multi_query:
-            sim = einsum("b h s d, b j d -> b h s j", q, k)
+            #sim = einsum("b h s d, b j d -> b h s j", q, k)
+            sim = matmul(q, k.unsqueeze(1).transpose(2, 3))  # add a head axis to k so matmul broadcasts over (b, h)
         else:
            # s and n here refer to sequence length
            # n is used only because einsum cannot have 2 same notations
-            sim = einsum("b h s d, b h n d -> b h s n", q, k)
+            #sim = einsum("b h s d, b h n d -> b h s n", q, k)
+            sim = matmul(q, k.transpose(2, 3))
 
         # apply causal mask
         causal_mask = self.get_mask(seq_length, device)
@@ -148,9 +150,11 @@
 
         # aggregate values
         if self.multi_query:
-            attn_out = einsum("b h i j, b j d -> b h i d", attn, v)
+            #attn_out = einsum("b h i j, b j d -> b h i d", attn, v)
+            attn_out = matmul(attn, v.unsqueeze(1))  # v also needs a head axis in the multi-query case
        else:
-            attn_out = einsum("b h s n, b h n d -> b h s d", attn, v)
+            #attn_out = einsum("b h s n, b h n d -> b h s d", attn, v)
+            attn_out = matmul(attn, v)
 
         # merge heads
         attn_out = rearrange(attn_out, "b h s d -> b s (h d)")
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..e658873
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,3 @@
+export TOKENIZER=./token
+export DATA=./wiki_dataset
+env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --from_torch --config ./configs/palm_30b_2d.py
\ No newline at end of file
diff --git a/train.py b/train.py
index b9e51d5..61d6330 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,9 @@
 from asyncio.log import logger
 import contextlib
 import os
+from packaging import version
+from functools import partial
+from time import time
 
 import colossalai
 import torch
@@ -9,8 +12,12 @@
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.context import ParallelMode
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 
 from data import build_data
 from model import build_loss, build_model
@@ -24,6 +31,35 @@ def limit_cuda_memory(size_in_GB: int):
     colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
     logger = get_dist_logger()
     logger.info("Using {} GB of GPU memory".format(size_in_GB))
+
+
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+    cai_version = colossalai.__version__
+    if version.parse(cai_version) > version.parse("0.1.10"):
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model,
+                          device=get_current_device(),
+                          placement_policy=placement_policy,
+                          pin_memory=True,
+                          search_range_mb=32)
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+        from colossalai.gemini import ChunkManager, GeminiManager
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+        chunk_manager = ChunkManager(chunk_size,
+                                     pg,
+                                     enable_distributed_storage=True,
+                                     init_device=GeminiManager.get_default_device(placement_policy))
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        model = ZeroDDP(model, gemini_manager)
+    else:
+        raise NotImplementedError(f"CAI version {cai_version} is not supported")
+    return model
 
 
 def train_palm():
     assert torch.cuda.is_available()
@@ -55,47 +91,35 @@ def train_palm():
     assert hasattr(gpc.config, "NUM_EPOCHS"), "Please provide NUM_EPOCHS in your configuration"
 
     use_zero = hasattr(gpc.config, "zero")
-    ctx = contextlib.nullcontext()
+    #ctx = contextlib.nullcontext()
     tflop = 0
+    default_pg = ProcessGroup(tp_degree=gpc.config.TPDEGREE)
+    default_dist_spec = ShardSpec([-1], [gpc.config.TPDEGREE]) if gpc.config.USE_SHARD_INIT else None
 
     if use_zero:
-        ctx = ZeroInitContext(
-            target_device=torch.cuda.current_device(),
-            shard_strategy=gpc.config.zero.model_config.shard_strategy,
-            shard_param=True,
-        )
+        # ctx = ZeroInitContext(
+        #     target_device=torch.cuda.current_device(),
+        #     shard_strategy=gpc.config.zero.model_config.shard_strategy,
+        #     shard_param=True,
+        # )
+        ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
+        with ctx:
+            model = build_model()
+            model = AutoregressiveWrapper(model)
+
+
     logger = get_dist_logger()
 
-    if hasattr(gpc.config, "LOG_PATH"):
-        log_path = gpc.config.LOG_PATH
-        logger.log_to_file(log_path)
+    # if hasattr(gpc.config, "LOG_PATH"):
+    #     log_path = gpc.config.LOG_PATH
+    #     logger.log_to_file(log_path)
 
-    with ctx:
-        model = build_model()
-        model = AutoregressiveWrapper(model)
+    # with ctx:
+    #     model = build_model()
+    #     model = AutoregressiveWrapper(model)
 
     seq_len=gpc.config.SEQ_LENGTH
     batch_size=gpc.config.BATCH_SIZE
 
-    # numel is a model elem in a DP process.
-    numel = 0
-    if use_zero:
-        numel = ctx.model_numel_tensor.item()
-    else:
-        numel = calc_local_model_size(model)
-
-    tflop = numel * batch_size * seq_len \
-        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
-
-    if numel < 1e9:
-        msg = f"{numel / 1e6:.3f} M"
-    else:
-        msg = f"{numel / 1e9:.3f} B"
-
-    model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3
-
-    logger.info("Model is built.", ranks=[0])
-    logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])
-
     criterion = build_loss()
     logger.info("Loss is built.", ranks=[0])
@@ -110,13 +134,21 @@
 
     # We use a fast CPU Adam here
     # If we set cpu_offload=True in optimizer_config
-    use_cpu_adam = (
-        hasattr(gpc.config, "zero")
-        and hasattr(gpc.config.zero, "model_config")
-        and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
-    )
-    optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
-    optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
+    # use_cpu_adam = (
+    #     hasattr(gpc.config, "zero")
+    #     and hasattr(gpc.config.zero, "model_config")
+    #     and getattr(gpc.config.zero.model_config, "tensor_placement_policy") != "cuda"
+    # )
+    # optimizer = HybridAdam if use_cpu_adam else torch.optim.AdamW
+    # optimizer = optimizer(model.parameters(), lr=0.001, weight_decay=1e-2)
+    pg = default_pg
+    # Tensor Parallelism (TP)
+    #tensor_parallelize(model, pg)
+    # Gemini + ZeRO DP, Note it must be used after TP
+    model = gemini_zero_dpp(model, pg, gpc.config.placement)
+
+    # build optimizer
+    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
 
     # total_steps = gpc.config.NUM_EPOCHS * len(train_dataloader)
     # warmup_steps = getattr(gpc.config, "WARMUP_EPOCHS", 0) * len(train_dataloader)
@@ -124,52 +156,54 @@
 
     logger.info("Optimizer is built.", ranks=[0])
 
-    engine, train_dataloader, _, _ = colossalai.initialize(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        # lr_scheduler=lr_scheduler,
-        train_dataloader=train_dataloader,
-    )
+    #logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
 
-    def batch_data_process_func(batch_data):
-        data = batch_data["input_ids"]
-        labels = batch_data["labels"]
-        return data, labels
-
-    engine.schedule.batch_data_process_func = batch_data_process_func
-
-    timer = MultiTimer()
-    trainer = Trainer(engine=engine, logger=logger, timer=timer)
-
-    hook_list = [
-        hooks.LogMetricByEpochHook(logger=logger),
-        hooks.LogMetricByStepHook(),
-        hooks.LossHook(),
-        hooks.ThroughputHook(ignored_steps=10, tflop_per_step = tflop),
-        # hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
-        hooks.LogMemoryByEpochHook(logger),
-        # hooks.SaveCheckpointHook(checkpoint_dir="./palm.ckpt", model=model),
-    ]
-
-    logger.info("Training start.", ranks=[0])
-    trainer.fit(
-        train_dataloader=train_dataloader,
-        epochs=gpc.config.NUM_EPOCHS,
-        max_steps=20,
-        hooks=hook_list,
-        return_output_label=False,
-        display_progress=True,
-    )
+    #numel is a model elem in a DP process.
+    numel = 0
+    if use_zero:
+        #numel = ctx.model_numel_tensor.item()
+        numel = sum([p.numel() for p in model.parameters()])
+    else:
+        numel = calc_local_model_size(model)
+
+    tflop = numel * batch_size * seq_len \
+        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
 
-    opt_state = engine.optimizer.state_dict()
-    if isinstance(engine.optimizer, colossalai.amp.naive_amp.NaiveAMPOptimizer):
-        opt_state = opt_state['optimizer']
-    os_mem = calc_mem(opt_state)
-    logger.info(f"{engine.optimizer.__class__.__name__} state memory usage = {os_mem / 1024**2:.3f} MB", ranks=[0])
+    get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)
 
-    gpc.destroy()
-    logger.info("Training complete.", ranks=[0])
+    if numel < 1e9:
+        msg = f"{numel / 1e6:.3f} M"
+    else:
+        msg = f"{numel / 1e9:.3f} B"
+
+    model_mem = torch.cuda.max_memory_allocated(get_current_device()) / 1024**3
+
+    logger.info("Model is built.", ranks=[0])
+    logger.info(f"Parameter size = {msg} | Model memory = {model_mem:.3f} GB.", ranks=[0])
+
+    torch.cuda.synchronize()
+    model.train()
+    for n in range(gpc.config.NUM_EPOCHS):
+        for step, batch in enumerate(train_dataloader):
+            input_ids = batch["input_ids"].cuda()
+            labels = batch["labels"].cuda()
+            optimizer.zero_grad()
+            start = time()
+            outputs = model(input_ids)
+            loss = criterion(outputs, labels)
+            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Forward '), ranks=[0])
+
+            optimizer.backward(loss)
+
+            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Backward '), ranks=[0])
+            optimizer.step()
+            #logger.info(get_mem_info(prefix=f'[{n+1}/{gpc.config.NUM_EPOCHS}] Optimizer step '), ranks=[0])
+            step_time = time() - start
+            logger.info(
+                f'[{n+1}/{gpc.config.NUM_EPOCHS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
+                ranks=[0])
+
+    torch.cuda.synchronize()
 
 
 if __name__ == "__main__":
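For reference, the einsum-to-matmul rewrite in model/parallel_palm.py can be sanity-checked in isolation. The sketch below is a standalone check (the tensor sizes are arbitrary test values, not taken from the PaLM config): it verifies that the matmul forms reproduce the original einsum results, and that the multi-query path needs an explicit head axis on k and v (the unsqueeze(1) used in the hunk above) before matmul can broadcast over the batch and head dimensions.

```python
# Standalone equivalence check for the einsum -> matmul rewrite.
# b/h/s/n/d are arbitrary test sizes, not values from the model config.
import torch

b, h, s, n, d = 2, 4, 5, 5, 8

# standard multi-head path: q is (b, h, s, d), k/v are (b, h, n, d)
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

sim_einsum = torch.einsum("b h s d, b h n d -> b h s n", q, k)
sim_matmul = torch.matmul(q, k.transpose(2, 3))
assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)

attn = sim_matmul.softmax(dim=-1)
out_einsum = torch.einsum("b h s n, b h n d -> b h s d", attn, v)
out_matmul = torch.matmul(attn, v)
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)

# multi-query path: k/v carry no head dimension, so a head axis must be
# added before matmul can broadcast over (b, h).
k_mq = torch.randn(b, n, d)
v_mq = torch.randn(b, n, d)

sim_mq_einsum = torch.einsum("b h s d, b j d -> b h s j", q, k_mq)
sim_mq_matmul = torch.matmul(q, k_mq.unsqueeze(1).transpose(2, 3))
assert torch.allclose(sim_mq_einsum, sim_mq_matmul, atol=1e-5)

attn_mq = sim_mq_matmul.softmax(dim=-1)
out_mq_einsum = torch.einsum("b h i j, b j d -> b h i d", attn_mq, v_mq)
out_mq_matmul = torch.matmul(attn_mq, v_mq.unsqueeze(1))
assert torch.allclose(out_mq_einsum, out_mq_matmul, atol=1e-5)
print("einsum and matmul formulations agree")
```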
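The per-step throughput logged by the new training loop comes from get_tflops, with the parameter count, batch size, and sequence length bound once via functools.partial so that only the measured step time is passed inside the loop. A small worked example with made-up numbers (the parameter count and step times below are illustrative, not values from this run):

```python
from functools import partial

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # same formula as get_tflops in train.py (the constant 8 is taken from there);
    # result is tera-FLOPs per second for one training step
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# hypothetical numbers for illustration only
numel = 1.0e9      # 1 B parameters
batch_size = 2
seq_len = 2048

# bind everything except the step time, mirroring the partial() call in train.py
get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)

print(f"{get_tflops_func(1.0):.1f} TFLOPS")   # ~32.8 for a 1.0 s step
print(f"{get_tflops_func(0.5):.1f} TFLOPS")   # ~65.5 for a 0.5 s step
```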