diff --git a/pe/api/image/improved_diffusion_api.py b/pe/api/image/improved_diffusion_api.py index e8812bd..99b0c9c 100644 --- a/pe/api/image/improved_diffusion_api.py +++ b/pe/api/image/improved_diffusion_api.py @@ -131,7 +131,7 @@ def random_api(self, label_info, num_samples): :param num_samples: The number of random samples to generate :type num_samples: int :return: The data object of the generated synthetic data - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ label_name = label_info.name execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}") @@ -164,9 +164,9 @@ def variation_api(self, syn_data): """Generating variations of the synthetic data. :param syn_data: The data object of the synthetic data - :type syn_data: :py:class:`pe.data.data.Data` + :type syn_data: :py:class:`pe.data.Data` :return: The data object of the variation of the input synthetic data - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples") images = np.stack(syn_data.data_frame[IMAGE_DATA_COLUMN_NAME].values) diff --git a/pe/embedding/text/sentence_transformer.py b/pe/embedding/text/sentence_transformer.py index 0202487..1cab843 100644 --- a/pe/embedding/text/sentence_transformer.py +++ b/pe/embedding/text/sentence_transformer.py @@ -32,9 +32,9 @@ def compute_embedding(self, data): """Compute the Sentence Transformers embedding of text. 
:param data: The data object containing the text - :type data: :py:class:`pe.data.data.Data` + :type data: :py:class:`pe.data.Data` :return: The data object with the computed embedding - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ uncomputed_data = self.filter_uncomputed_rows(data) if len(uncomputed_data.data_frame) == 0: diff --git a/pe/llm/huggingface/huggingface.py b/pe/llm/huggingface/huggingface.py index e8420c6..343f64b 100644 --- a/pe/llm/huggingface/huggingface.py +++ b/pe/llm/huggingface/huggingface.py @@ -12,7 +12,11 @@ class HuggingfaceLLM(LLM): def __init__(self, model_name_or_path, batch_size=128, dry_run=False, **generation_args): """Constructor. - :param model_name_or_path: The model name or path of the Huggingface model + :param model_name_or_path: The model name or path of the Huggingface model. Note that we use the FastChat + library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation + template of your desired model is not available in FastChat, please register the conversation template in + the FastChat library. See the following link for an example: + https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py :type model_name_or_path: str :param batch_size: The batch size to use for generating the responses, defaults to 128 :type batch_size: int, optional @@ -91,7 +95,7 @@ def get_responses(self, requests, **generation_args): """Get the responses from the LLM. :param requests: The requests - :type requests: list[:py:class:`pe.llm.request.Request`] + :type requests: list[:py:class:`pe.llm.Request`] :param \\*\\*generation_args: The generation arguments. 
The priority of the generation arguments from the highest to the lowest is in the order of: the arguments set in the requests > the arguments passed to this function > and the arguments passed to the constructor diff --git a/pe/population/pe_population.py b/pe/population/pe_population.py index 686b664..d6590c3 100644 --- a/pe/population/pe_population.py +++ b/pe/population/pe_population.py @@ -24,7 +24,7 @@ def __init__( """Constructor. :param api: The API object that contains the random and variation APIs - :type api: :py:class:`pe.api.api.API` + :type api: :py:class:`pe.api.API` :param histogram_threshold: The threshold for clipping the histogram. None means no clipping. Defaults to None :type histogram_threshold: float, optional :param initial_variation_api_fold: The number of variations to apply to the initial synthetic data, defaults to @@ -61,7 +61,7 @@ def initial(self, label_info, num_samples): :param num_samples: The number of samples to generate :type num_samples: int :return: The initial synthetic data - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ execution_logger.info( f"Population: generating {num_samples}*{self._initial_variation_api_fold + 1} initial " @@ -83,10 +83,10 @@ def _post_process_histogram(self, syn_data): """Post process the histogram of synthetic data (e.g., clipping). :param syn_data: The synthetic data - :type syn_data: :py:class:`pe.data.data.Data` + :type syn_data: :py:class:`pe.data.Data` :return: The synthetic data with post-processed histogram in the column :py:const:`pe.constant.data.POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME` - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ count = syn_data.data_frame[DP_HISTOGRAM_COLUMN_NAME].to_numpy() if self._histogram_threshold is not None: @@ -101,12 +101,12 @@ def _select_data(self, syn_data, num_samples): """Select data from the synthetic data according to `selection_mode`. 
:param syn_data: The synthetic data - :type syn_data: :py:class:`pe.data.data.Data` + :type syn_data: :py:class:`pe.data.Data` :param num_samples: The number of samples to select :type num_samples: int :raises ValueError: If the selection mode is not supported :return: The selected data - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ if self._selection_mode == "sample": count = syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME].to_numpy() @@ -128,11 +128,11 @@ def next(self, syn_data, num_samples): """Generate the next synthetic data. :param syn_data: The synthetic data - :type syn_data: :py:class:`pe.data.data.Data` + :type syn_data: :py:class:`pe.data.Data` :param num_samples: The number of samples to generate :type num_samples: int :return: The next synthetic data - :rtype: :py:class:`pe.data.data.Data` + :rtype: :py:class:`pe.data.Data` """ execution_logger.info( f"Population: generating {num_samples}*{self._next_variation_api_fold} " "next synthetic samples"