Skip to content

Commit

Permalink
polish doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
fjxmlzn committed Jan 14, 2025
1 parent 8aba343 commit 7b1cd20
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 15 deletions.
6 changes: 3 additions & 3 deletions pe/api/image/improved_diffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def random_api(self, label_info, num_samples):
:param num_samples: The number of random samples to generate
:type num_samples: int
:return: The data object of the generated synthetic data
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
label_name = label_info.name
execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}")
Expand Down Expand Up @@ -164,9 +164,9 @@ def variation_api(self, syn_data):
"""Generating variations of the synthetic data.
:param syn_data: The data object of the synthetic data
:type syn_data: :py:class:`pe.data.data.Data`
:type syn_data: :py:class:`pe.data.Data`
:return: The data object of the variation of the input synthetic data
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples")
images = np.stack(syn_data.data_frame[IMAGE_DATA_COLUMN_NAME].values)
Expand Down
4 changes: 2 additions & 2 deletions pe/embedding/text/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def compute_embedding(self, data):
"""Compute the Sentence Transformers embedding of text.
:param data: The data object containing the text
:type data: :py:class:`pe.data.data.Data`
:type data: :py:class:`pe.data.Data`
:return: The data object with the computed embedding
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
uncomputed_data = self.filter_uncomputed_rows(data)
if len(uncomputed_data.data_frame) == 0:
Expand Down
8 changes: 6 additions & 2 deletions pe/llm/huggingface/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ class HuggingfaceLLM(LLM):
def __init__(self, model_name_or_path, batch_size=128, dry_run=False, **generation_args):
"""Constructor.
:param model_name_or_path: The model name or path of the Huggingface model
:param model_name_or_path: The model name or path of the Huggingface model. Note that we use the FastChat
library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation
template of your desired model is not available in FastChat, please register the conversation template in
the FastChat library. See the following link for an example:
https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py
:type model_name_or_path: str
:param batch_size: The batch size to use for generating the responses, defaults to 128
:type batch_size: int, optional
Expand Down Expand Up @@ -91,7 +95,7 @@ def get_responses(self, requests, **generation_args):
"""Get the responses from the LLM.
:param requests: The requests
:type requests: list[:py:class:`pe.llm.request.Request`]
:type requests: list[:py:class:`pe.llm.Request`]
:param \\*\\*generation_args: The generation arguments. The priority of the generation arguments from the
highest to the lowerest is in the order of: the arguments set in the requests > the arguments passed to
this function > and the arguments passed to the constructor
Expand Down
16 changes: 8 additions & 8 deletions pe/population/pe_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
"""Constructor.
:param api: The API object that contains the random and variation APIs
:type api: :py:class:`pe.api.api.API`
:type api: :py:class:`pe.api.API`
:param histogram_threshold: The threshold for clipping the histogram. None means no clipping. Defaults to None
:type histogram_threshold: float, optional
:param initial_variation_api_fold: The number of variations to apply to the initial synthetic data, defaults to
Expand Down Expand Up @@ -61,7 +61,7 @@ def initial(self, label_info, num_samples):
:param num_samples: The number of samples to generate
:type num_samples: int
:return: The initial synthetic data
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
execution_logger.info(
f"Population: generating {num_samples}*{self._initial_variation_api_fold + 1} initial "
Expand All @@ -83,10 +83,10 @@ def _post_process_histogram(self, syn_data):
"""Post process the histogram of synthetic data (e.g., clipping).
:param syn_data: The synthetic data
:type syn_data: :py:class:`pe.data.data.Data`
:type syn_data: :py:class:`pe.data.Data`
:return: The synthetic data with post-processed histogram in the column
:py:const:`pe.constant.data.POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME`
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
count = syn_data.data_frame[DP_HISTOGRAM_COLUMN_NAME].to_numpy()
if self._histogram_threshold is not None:
Expand All @@ -101,12 +101,12 @@ def _select_data(self, syn_data, num_samples):
"""Select data from the synthetic data according to `selection_mode`.
:param syn_data: The synthetic data
:type syn_data: :py:class:`pe.data.data.Data`
:type syn_data: :py:class:`pe.data.Data`
:param num_samples: The number of samples to select
:type num_samples: int
:raises ValueError: If the selection mode is not supported
:return: The selected data
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
if self._selection_mode == "sample":
count = syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME].to_numpy()
Expand All @@ -128,11 +128,11 @@ def next(self, syn_data, num_samples):
"""Generate the next synthetic data.
:param syn_data: The synthetic data
:type syn_data: :py:class:`pe.data.data.Data`
:type syn_data: :py:class:`pe.data.Data`
:param num_samples: The number of samples to generate
:type num_samples: int
:return: The next synthetic data
:rtype: :py:class:`pe.data.data.Data`
:rtype: :py:class:`pe.data.Data`
"""
execution_logger.info(
f"Population: generating {num_samples}*{self._next_variation_api_fold} " "next synthetic samples"
Expand Down

0 comments on commit 7b1cd20

Please sign in to comment.