Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Checkpoint removal 2 #250

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions torchbiggraph/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def model_state_dict_private_to_public(
) -> Dict[str, ModelParameter]:
public_state_dict: Dict[str, ModelParameter] = {}
for private_name, tensor in private_state_dict.items():
if 'emb' in private_name:
continue
if not isinstance(tensor, torch.Tensor):
raise RuntimeError(
"Isn't the state dict supposed to be "
Expand Down
7 changes: 6 additions & 1 deletion torchbiggraph/graph_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ def load_chunk_of_edges(
shared: bool = False,
) -> EdgeList:
file_path = self.get_edges_file(lhs_p, rhs_p)
bin_path = self.path / f"edges_{lhs_p}_{rhs_p}_{chunk_idx}.pt"
if bin_path.is_file():
return torch.load(open(bin_path, mode='rb'))
try:
with h5py.File(file_path, "r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
Expand Down Expand Up @@ -453,9 +456,11 @@ def load_chunk_of_edges(
)
else:
weight = None
return EdgeList(
ret = EdgeList(
EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel, weight
)
torch.save(ret, open(bin_path, mode='wb'))
return ret
except OSError as err:
# h5py refuses to make it easy to figure out what went wrong. The errno
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
Expand Down
73 changes: 61 additions & 12 deletions torchbiggraph/train_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,9 @@ def train(self) -> None:
eval_stats_after,
eval_stats_chunk_avg,
)

first = True
for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
epoch_start = time.perf_counter()
logger.info(
f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
Expand Down Expand Up @@ -600,8 +601,17 @@ def train(self) -> None:
bucket_logger = BucketLogger(logger, bucket=cur_b)
self.bucket_logger = bucket_logger

io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
io_bytes = 0
if first:
start = time.perf_counter()
io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
end = time.perf_counter()
logger.debug(f"Loading embedings took {(end - start):.2f} seconds")
first = False
start = time.perf_counter()
self.model.set_all_embeddings(holder, cur_b)
end = time.perf_counter()
logger.debug(f"Setting all embeddings took {(end - start):.2f} seconds")

current_index = (
(iteration_manager.iteration_idx + 1) * total_buckets
Expand All @@ -610,13 +620,16 @@ def train(self) -> None:
)

bucket_logger.debug("Loading edges")
start = time.perf_counter()
edges = edge_storage.load_chunk_of_edges(
cur_b.lhs,
cur_b.rhs,
edge_chunk_idx,
iteration_manager.num_edge_chunks,
shared=True,
)
end = time.perf_counter()
logger.debug(f"Loading edges took {(end - start):.2f} seconds")
num_edges = len(edges)

# this might be off in the case of tensorlist or extra edge fields
Expand Down Expand Up @@ -686,7 +699,9 @@ def train(self) -> None:
f"io: {io_time:.2f} s for {io_bytes:,} bytes ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
)

self.model.clear_all_embeddings()
if total_buckets > 1:
logger.info("Clearing all embeddings")
self.model.clear_all_embeddings()

cur_stats = BucketStats(
lhs_partition=cur_b.lhs,
Expand All @@ -698,16 +713,32 @@ def train(self) -> None:
)

# release the final bucket
self._swap_partitioned_embeddings(cur_b, None, cur_stats)

final: bool = (epoch_idx + 1 == iteration_manager.num_epochs) \
and (edge_path_idx + 1 == iteration_manager.num_edge_paths) \
and (edge_chunk_idx + 1 == iteration_manager.num_edge_chunks)
to_write: bool = (final == True) or ((epoch_idx + 1) % 5 == 0 and edge_chunk_idx == 0)
if to_write:
logger.debug("Nondestructively writing the embeddings")
start = time.perf_counter()
self._nondestructive_write_embedding(cur_b)
end = time.perf_counter()
logger.debug(f"Writing embeddings took {(end - start):.2f} seconds")

self._write_stats(cur_b, cur_stats)
# self._swap_partitioned_embeddings(cur_b, None, cur_stats, to_write)

# Distributed Processing: all machines can leave the barrier now.
self._barrier()

current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1

start = time.perf_counter()
self._maybe_write_checkpoint(
epoch_idx, edge_path_idx, edge_chunk_idx, current_index
)
end = time.perf_counter()
logger.debug(f"Writing checkpoint took {(end - start):.2f} seconds")
logger.debug(f"Epoch took {(end - start):.2f} seconds")

# now we're sure that all partition files exist,
# so be strict about loading them
Expand Down Expand Up @@ -770,11 +801,35 @@ def _load_embeddings(
optimizer.load_state_dict(optim_state)
return embs, optimizer

def _write_single_embedding(
    self,
    holder: EmbeddingHolder,
    entity: EntityName,
    part: Partition):
    """Send one partition's embeddings and optimizer state to the checkpoint manager.

    The embeddings are looked up in ``holder.partitioned_embeddings`` and the
    optimizer in ``self.trainer.partitioned_optimizers``, both under the key
    ``(entity, part)``. Neither entry is removed from its mapping; they are
    only read and forwarded to ``self.checkpoint_manager.write``.

    Args:
        holder: Object exposing the ``partitioned_embeddings`` mapping.
        entity: Entity type name, first half of the lookup key.
        part: Partition, second half of the lookup key.

    Raises:
        KeyError: If ``(entity, part)`` is missing from either mapping
            (assuming they are plain dict-like mappings).
    """
    key = (entity, part)
    table = holder.partitioned_embeddings[key]
    optim = self.trainer.partitioned_optimizers[key]
    # Pass a detached tensor and the optimizer's state dict to the writer.
    self.checkpoint_manager.write(entity, part, table.detach(), optim.state_dict())

def _nondestructive_write_embedding(self, bucket: Bucket):
    """Write every partitioned embedding used by ``bucket`` without evicting it.

    Pairs each type in ``self.holder.lhs_partitioned_types`` with
    ``bucket.lhs`` and each type in ``self.holder.rhs_partitioned_types`` with
    ``bucket.rhs``, de-duplicates the resulting ``(entity, partition)`` keys,
    and calls ``self._write_single_embedding`` once per key.

    Args:
        bucket: Bucket whose ``lhs`` and ``rhs`` partitions select what to write.
    """
    holder = self.holder
    # A set collapses keys that appear on both sides (same type, same partition).
    keys = {(name, bucket.lhs) for name in holder.lhs_partitioned_types}
    keys.update((name, bucket.rhs) for name in holder.rhs_partitioned_types)
    for entity, part in keys:
        self._write_single_embedding(holder, entity, part)

def _write_stats(self, bucket: Optional[Bucket], stats: Optional[BucketStats]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This naming is misleading... I think it does more than write stats in the distributed scheduler.

if bucket is not None:
if stats is not None:
self.bucket_scheduler.release_bucket(bucket, stats)

def _swap_partitioned_embeddings(
self,
old_b: Optional[Bucket],
new_b: Optional[Bucket],
old_stats: Optional[BucketStats],
write: bool = True,
) -> int:
io_bytes = 0
logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")
Expand All @@ -797,19 +852,13 @@ def _swap_partitioned_embeddings(
logger.info("Saving partitioned embeddings to checkpoint")
for entity, part in old_parts - new_parts:
logger.debug(f"Saving ({entity} {part})")
embs = holder.partitioned_embeddings.pop((entity, part))
optimizer = self.trainer.partitioned_optimizers.pop((entity, part))
self.checkpoint_manager.write(
entity, part, embs.detach(), optimizer.state_dict()
)
self._write_single_embedding(holder, entity, part)
self.embedding_storage_freelist[entity].add(embs.storage())
io_bytes += embs.numel() * embs.element_size() # ignore optim state
# these variables are holding large objects; let them be freed
del embs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do these lines work if you don't define embs and optimizer any more?

del optimizer

self.bucket_scheduler.release_bucket(old_b, old_stats)

if new_b is not None:
logger.info("Loading partitioned embeddings from checkpoint")
for entity, part in new_parts - old_parts:
Expand Down