From 7d94db1a03407d7633108bd0397eac82e02c07f3 Mon Sep 17 00:00:00 2001
From: Thomas Markovich
Date: Fri, 11 Feb 2022 22:46:11 -0500
Subject: [PATCH 1/3] Filtering the embs from the state dict

---
 torchbiggraph/checkpoint_manager.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchbiggraph/checkpoint_manager.py b/torchbiggraph/checkpoint_manager.py
index 40402bcf..bec3922d 100644
--- a/torchbiggraph/checkpoint_manager.py
+++ b/torchbiggraph/checkpoint_manager.py
@@ -100,6 +100,8 @@ def model_state_dict_private_to_public(
 ) -> Dict[str, ModelParameter]:
     public_state_dict: Dict[str, ModelParameter] = {}
     for private_name, tensor in private_state_dict.items():
+        if 'emb' in private_name:
+            continue
         if not isinstance(tensor, torch.Tensor):
             raise RuntimeError(
                 "Isn't the state dict supposed to be "

From 017862cb723d65a4e2d5f0347abea635d4e3eca0 Mon Sep 17 00:00:00 2001
From: Thomas Markovich
Date: Fri, 11 Feb 2022 22:46:52 -0500
Subject: [PATCH 2/3] Swapped the logic to write only every k steps

---
 torchbiggraph/train_cpu.py | 55 +++++++++++++++++++++++++++++---------
 1 file changed, 43 insertions(+), 12 deletions(-)

diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py
index b87f0349..82aae905 100644
--- a/torchbiggraph/train_cpu.py
+++ b/torchbiggraph/train_cpu.py
@@ -553,7 +553,7 @@ def train(self) -> None:
                 eval_stats_after,
                 eval_stats_chunk_avg,
             )
-
+        FIRST = True
         for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
             logger.info(
                 f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
@@ -600,7 +600,10 @@ def train(self) -> None:
             bucket_logger = BucketLogger(logger, bucket=cur_b)
             self.bucket_logger = bucket_logger

-            io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
+            io_bytes = 0
+            if FIRST:
+                io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
+                FIRST = False
             self.model.set_all_embeddings(holder, cur_b)

             current_index = (
@@ -686,7 +689,9 @@ def train(self) -> None:
                 f"io: {io_time:.2f} s for {io_bytes:,} bytes ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
             )

-            self.model.clear_all_embeddings()
+            if total_buckets > 1:
+                logging.info("Clearing all embeddings")
+                self.model.clear_all_embeddings()

             cur_stats = BucketStats(
                 lhs_partition=cur_b.lhs,
@@ -698,13 +703,21 @@ def train(self) -> None:
             )

             # release the final bucket
-            self._swap_partitioned_embeddings(cur_b, None, cur_stats)
+
+            final: bool = (epoch_idx + 1 == iteration_manager.num_epochs) \
+                and (edge_path_idx + 1 == iteration_manager.num_edge_paths) \
+                and (edge_chunk_idx + 1 == iteration_manager.num_edge_chunks)
+            to_write: bool = (final == True) or ((epoch_idx + 1) % 10 == 0 and edge_chunk_idx == 0)
+            if to_write:
+                logging.debug("Nondestructively writing the embeddings")
+                self._nondestructive_write_embedding(cur_b)
+            self._write_stats(cur_b, cur_stats)
+            # self._swap_partitioned_embeddings(cur_b, None, cur_stats, to_write)

             # Distributed Processing: all machines can leave the barrier now.
             self._barrier()

             current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1
-
             self._maybe_write_checkpoint(
                 epoch_idx, edge_path_idx, edge_chunk_idx, current_index
             )
@@ -770,11 +783,35 @@ def _load_embeddings(
         optimizer.load_state_dict(optim_state)
         return embs, optimizer

+    def _write_single_embedding(
+        self,
+        holder: EmbeddingHolder,
+        entity: EntityName,
+        part: Partition):
+        embs = holder.partitioned_embeddings[(entity, part)]
+        optimizer = self.trainer.partitioned_optimizers[(entity, part)]
+        self.checkpoint_manager.write(
+            entity, part, embs.detach(), optimizer.state_dict()
+        )
+
+    def _nondestructive_write_embedding(self, bucket: Bucket):
+        parts: Set[Tuple[EntityName, Partition]] = set()
+        parts.update((e, bucket.lhs) for e in self.holder.lhs_partitioned_types)
+        parts.update((e, bucket.rhs) for e in self.holder.rhs_partitioned_types)
+        for entity, part in parts:
+            self._write_single_embedding(self.holder, entity, part)
+
+    def _write_stats(self, bucket: Optional[Bucket], stats: Optional[BucketStats]):
+        if bucket is not None:
+            if stats is not None:
+                self.bucket_scheduler.release_bucket(bucket, stats)
+
     def _swap_partitioned_embeddings(
         self,
         old_b: Optional[Bucket],
         new_b: Optional[Bucket],
         old_stats: Optional[BucketStats],
+        write: bool = True,
     ) -> int:
         io_bytes = 0
         logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")
@@ -797,19 +834,13 @@ def _swap_partitioned_embeddings(
             logger.info("Saving partitioned embeddings to checkpoint")
             for entity, part in old_parts - new_parts:
                 logger.debug(f"Saving ({entity} {part})")
-                embs = holder.partitioned_embeddings.pop((entity, part))
-                optimizer = self.trainer.partitioned_optimizers.pop((entity, part))
-                self.checkpoint_manager.write(
-                    entity, part, embs.detach(), optimizer.state_dict()
-                )
+                self._write_single_embedding(holder, entity, part)
                 self.embedding_storage_freelist[entity].add(embs.storage())
                 io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                 # these variables are holding large objects; let them be freed
                 del embs
                 del optimizer

-            self.bucket_scheduler.release_bucket(old_b, old_stats)
-
         if new_b is not None:
             logger.info("Loading partitioned embeddings from checkpoint")
             for entity, part in new_parts - old_parts:

From f46539fc63a922441b6349508bc3a490934607d2 Mon Sep 17 00:00:00 2001
From: Thomas Markovich
Date: Thu, 17 Feb 2022 21:41:22 -0500
Subject: [PATCH 3/3] hacks

---
 torchbiggraph/graph_storages.py |  7 ++++++-
 torchbiggraph/train_cpu.py      | 30 ++++++++++++++++++++++++------
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/torchbiggraph/graph_storages.py b/torchbiggraph/graph_storages.py
index 49665685..d15a2f0b 100644
--- a/torchbiggraph/graph_storages.py
+++ b/torchbiggraph/graph_storages.py
@@ -416,6 +416,9 @@ def load_chunk_of_edges(
         shared: bool = False,
     ) -> EdgeList:
         file_path = self.get_edges_file(lhs_p, rhs_p)
+        bin_path = self.path / f"edges_{lhs_p}_{rhs_p}_{chunk_idx}.pt"
+        if bin_path.is_file():
+            return torch.load(open(bin_path, mode='rb'))
         try:
             with h5py.File(file_path, "r") as hf:
                 if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
@@ -453,9 +456,11 @@ def load_chunk_of_edges(
                 )
             else:
                 weight = None
-            return EdgeList(
+            ret = EdgeList(
                 EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel, weight
             )
+            torch.save(ret, open(bin_path, mode='wb'))
+            return ret
         except OSError as err:
             # h5py refuses to make it easy to figure out what went wrong. The errno
             # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
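
The hunk above caches each decoded edge chunk as a torch-serialized file next to the HDF5 data, so later epochs skip the HDF5 parse. A minimal standalone sketch of that pattern, assuming an illustrative callable `load_from_hdf5` and a `cache_dir` directory (neither is PyTorch-BigGraph API):

from pathlib import Path

import torch


def load_chunk_cached(cache_dir: Path, lhs_p: int, rhs_p: int, chunk_idx: int, load_from_hdf5):
    # Cache file keyed by bucket and chunk, mirroring the naming used in the hunk above.
    cache_path = cache_dir / f"edges_{lhs_p}_{rhs_p}_{chunk_idx}.pt"
    if cache_path.is_file():
        # Cache hit: skip the expensive HDF5 read entirely.
        return torch.load(cache_path)
    chunk = load_from_hdf5(lhs_p, rhs_p, chunk_idx)
    # Populate the cache so subsequent epochs take the fast path.
    torch.save(chunk, cache_path)
    return chunk
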
diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py
index 82aae905..7dca9b1f 100644
--- a/torchbiggraph/train_cpu.py
+++ b/torchbiggraph/train_cpu.py
@@ -553,8 +553,9 @@ def train(self) -> None:
                 eval_stats_after,
                 eval_stats_chunk_avg,
             )
-        FIRST = True
+        first = True
         for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
+            epoch_start = time.perf_counter()
             logger.info(
                 f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
                 f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
@@ -601,10 +602,16 @@ def train(self) -> None:
             self.bucket_logger = bucket_logger

             io_bytes = 0
-            if FIRST:
+            if first:
+                start = time.perf_counter()
                 io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
-                FIRST = False
+                end = time.perf_counter()
+                logger.debug(f"Loading embeddings took {(end - start):.2f} seconds")
+                first = False
+            start = time.perf_counter()
             self.model.set_all_embeddings(holder, cur_b)
+            end = time.perf_counter()
+            logger.debug(f"Setting all embeddings took {(end - start):.2f} seconds")

             current_index = (
                 (iteration_manager.iteration_idx + 1) * total_buckets
@@ -613,6 +620,7 @@ def train(self) -> None:
             )

             bucket_logger.debug("Loading edges")
+            start = time.perf_counter()
             edges = edge_storage.load_chunk_of_edges(
                 cur_b.lhs,
                 cur_b.rhs,
@@ -620,6 +628,8 @@ def train(self) -> None:
                 iteration_manager.num_edge_chunks,
                 shared=True,
             )
+            end = time.perf_counter()
+            logger.debug(f"Loading edges took {(end - start):.2f} seconds")
             num_edges = len(edges)

             # this might be off in the case of tensorlist or extra edge fields
@@ -690,7 +700,7 @@ def train(self) -> None:
             )

             if total_buckets > 1:
-                logging.info("Clearing all embeddings")
+                logger.info("Clearing all embeddings")
                 self.model.clear_all_embeddings()

             cur_stats = BucketStats(
@@ -707,10 +717,14 @@ def train(self) -> None:
             final: bool = (epoch_idx + 1 == iteration_manager.num_epochs) \
                 and (edge_path_idx + 1 == iteration_manager.num_edge_paths) \
                 and (edge_chunk_idx + 1 == iteration_manager.num_edge_chunks)
-            to_write: bool = (final == True) or ((epoch_idx + 1) % 10 == 0 and edge_chunk_idx == 0)
+            to_write: bool = (final == True) or ((epoch_idx + 1) % 5 == 0 and edge_chunk_idx == 0)
             if to_write:
-                logging.debug("Nondestructively writing the embeddings")
+                logger.debug("Nondestructively writing the embeddings")
+                start = time.perf_counter()
                 self._nondestructive_write_embedding(cur_b)
+                end = time.perf_counter()
+                logger.debug(f"Writing embeddings took {(end - start):.2f} seconds")
+
             self._write_stats(cur_b, cur_stats)
             # self._swap_partitioned_embeddings(cur_b, None, cur_stats, to_write)

@@ -718,9 +732,13 @@ def train(self) -> None:
             self._barrier()

             current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1
+            start = time.perf_counter()
             self._maybe_write_checkpoint(
                 epoch_idx, edge_path_idx, edge_chunk_idx, current_index
             )
+            end = time.perf_counter()
+            logger.debug(f"Writing checkpoint took {(end - start):.2f} seconds")
+            logger.debug(f"Epoch took {(time.perf_counter() - epoch_start):.2f} seconds")

         # now we're sure that all partition files exist,
         # so be strict about loading them
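
Taken together, patches 2 and 3 replace the per-bucket embedding writes with a thinned schedule: embeddings are flushed only on the final (epoch, edge path, chunk) iteration, or every k-th epoch at chunk 0 (k is hard-coded to 10 in patch 2 and 5 in patch 3). A rough sketch of that predicate, with `should_write` as an illustrative name rather than anything in the codebase:

def should_write(epoch_idx, edge_path_idx, edge_chunk_idx,
                 num_epochs, num_edge_paths, num_edge_chunks, k=5):
    # True only on the very last iteration of the whole run...
    final = (
        epoch_idx + 1 == num_epochs
        and edge_path_idx + 1 == num_edge_paths
        and edge_chunk_idx + 1 == num_edge_chunks
    )
    # ...or on every k-th epoch, and then only for the first edge chunk.
    return final or ((epoch_idx + 1) % k == 0 and edge_chunk_idx == 0)

# With 12 epochs, one edge path, one chunk and k=5, embeddings are written
# after epochs 5, 10 and 12 (the final one) instead of after every bucket.
assert [e + 1 for e in range(12) if should_write(e, 0, 0, 12, 1, 1)] == [5, 10, 12]
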