diff --git a/src/methods/geneformer/config.vsh.yaml b/src/methods/geneformer/config.vsh.yaml new file mode 100644 index 0000000..582b9cd --- /dev/null +++ b/src/methods/geneformer/config.vsh.yaml @@ -0,0 +1,46 @@ +__merge__: ../../api/base_method.yaml + +name: geneformer +label: Geneformer +summary: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology. +description: | + Geneformer is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. Here, a pre-trained Geneformer model is fine-tuned and used to predict cell type labels for an unlabelled dataset. + +info: + preferred_normalization: counts + +references: + doi: + - 10.1038/s41586-023-06139-9 + - 10.1101/2024.08.16.608180 + +links: + documentation: https://geneformer.readthedocs.io/en/latest/index.html + repository: https://huggingface.co/ctheodoris/Geneformer + +arguments: + - name: "--model" + type: "string" + description: String representing the Geneformer model to use + choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"] + default: "gf-12L-95M-i4096" + +resources: + - type: python_script + path: script.py + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + pip: + - pyarrow<15.0.0a0,>=14.0.1 + - huggingface_hub + - git+https://huggingface.co/ctheodoris/Geneformer.git + +runners: + - type: executable + - type: nextflow + directives: + label: [midtime, midmem, midcpu, gpu] diff --git a/src/methods/geneformer/script.py b/src/methods/geneformer/script.py new file mode 100644 index 0000000..463c96a --- /dev/null +++ b/src/methods/geneformer/script.py @@ -0,0 +1,213 @@ +import os +from tempfile import TemporaryDirectory +import anndata as ad +from geneformer import Classifier, TranscriptomeTokenizer, DataCollatorForCellClassification +from huggingface_hub import hf_hub_download +import numpy as np +import datasets +import pickle +from transformers import BertForSequenceClassification, Trainer + +## VIASH START +par = { + 'input_train': 'resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad', + 'input_test': 'resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad', + 'output': 'output.h5ad', + "model": "gf-12L-95M-i4096", +} +meta = { + 'name': 'geneformer' +} +## VIASH END + +n_processors = os.cpu_count() + +print('>>> Reading input files', flush=True) +input_train = ad.read_h5ad(par['input_train']) +input_test = ad.read_h5ad(par['input_test']) + +if input_train.uns["dataset_organism"] != "homo_sapiens": + raise ValueError( + f"Geneformer can only be used with human data " + f"(dataset_organism == '{input_train.uns['dataset_organism']}')" + ) + +is_ensembl = all(var_name.startswith("ENSG") for var_name in input_train.var_names) +if not is_ensembl: + raise ValueError(f"Geneformer requires input_train.var_names to contain ENSEMBL gene ids") + +print(f">>> Getting settings for model '{par['model']}'...", flush=True) +model_split = par["model"].split("-") +model_details = { + "layers": model_split[1], + "dataset": model_split[2], + "input_size": int(model_split[3][1:]), +} +print(model_details, flush=True) + +print(">>> Getting model dictionary files...", flush=True) +if model_details["dataset"] == "95M": + dictionaries_subfolder = "geneformer" +elif model_details["dataset"] == "30M": + dictionaries_subfolder = "geneformer/gene_dictionaries_30m" +else: + raise ValueError(f"Invalid model dataset: {model_details['dataset']}") +print(f"Dictionaries subfolder: '{dictionaries_subfolder}'") + +dictionary_files = { + "ensembl_mapping": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=dictionaries_subfolder, + filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl", + ), + "gene_median": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=dictionaries_subfolder, + filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl", + ), + "gene_name_id": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=dictionaries_subfolder, + filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl", + ), + "token": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=dictionaries_subfolder, + filename=f"token_dictionary_gc{model_details['dataset']}.pkl", + ), +} + +print(">>> Creating working directory...", flush=True) +work_dir = TemporaryDirectory() + +input_train_dir = os.path.join(work_dir.name, "input_train") +os.makedirs(input_train_dir) +tokenized_train_dir = os.path.join(work_dir.name, "tokenized_train") +os.makedirs(tokenized_train_dir) +classifier_train_dir = os.path.join(work_dir.name, "classifier_train") +os.makedirs(classifier_train_dir) +classifier_fine_tuned_dir = os.path.join(work_dir.name, "classifier_fine_tuned") +os.makedirs(classifier_fine_tuned_dir) +input_test_dir = os.path.join(work_dir.name, "input_test") +os.makedirs(input_test_dir) +tokenized_test_dir = os.path.join(work_dir.name, "tokenized_test") +os.makedirs(tokenized_test_dir) + +print(f"Working directory: '{work_dir.name}'", flush=True) + +print(f">>> Getting model files for model '{par['model']}'...", flush=True) +model_files = { + "model": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=par["model"], + filename="model.safetensors", + ), + "config": hf_hub_download( + repo_id="ctheodoris/Geneformer", + subfolder=par["model"], + filename="config.json", + ), +} +model_dir = os.path.dirname(model_files["model"]) + +print(">>> Preparing input data...", flush=True) +input_train.X = input_train.layers["counts"] +input_train.var["ensembl_id"] = input_train.var["feature_id"] +input_train.obs["n_counts"] = input_train.layers["counts"].sum(axis=1) +input_train.obs["celltype"] = input_train.obs["label"] +num_types = len(input_train.obs["celltype"].unique()) +input_train.write_h5ad(os.path.join(input_train_dir, "input_train.h5ad")) + +input_test.X = input_test.layers["counts"] +input_test.var["ensembl_id"] = input_test.var["feature_id"] +input_test.obs["n_counts"] = input_test.layers["counts"].sum(axis=1) +input_test.write_h5ad(os.path.join(input_test_dir, "input_test.h5ad")) + +print(">>> Tokenizing train data...", flush=True) +special_token = model_details["dataset"] == "95M" +print(f"Input size: {model_details['input_size']}, Special token: {special_token}") +tokenizer = TranscriptomeTokenizer( + custom_attr_name_dict={"celltype": "celltype"}, + nproc=n_processors, + model_input_size=model_details["input_size"], + special_token=special_token, + gene_median_file=dictionary_files["gene_median"], + token_dictionary_file=dictionary_files["token"], + gene_mapping_file=dictionary_files["ensembl_mapping"], +) +tokenizer.tokenize_data(input_train_dir, tokenized_train_dir, "tokenized", file_format="h5ad") + +print(">>> Tokenizing test data...", flush=True) +special_token = model_details["dataset"] == "95M" +print(f"Input size: {model_details['input_size']}, Special token: {special_token}") +tokenizer = TranscriptomeTokenizer( + nproc=n_processors, + model_input_size=model_details["input_size"], + special_token=special_token, + gene_median_file=dictionary_files["gene_median"], + token_dictionary_file=dictionary_files["token"], + gene_mapping_file=dictionary_files["ensembl_mapping"], +) +tokenizer.tokenize_data(input_test_dir, tokenized_test_dir, "tokenized", file_format="h5ad") + +print(">>> Fine-tuning pre-trained geneformer model for cell state classification...", flush=True) +cc = Classifier( + classifier="cell", + cell_state_dict = {"state_key": "celltype", "states": "all"}, + nproc=n_processors, + token_dictionary_file=dictionary_files["token"], + num_crossval_splits=1, + split_sizes={"train": 0.9, "valid": 0.1, "test": 0.0}, +) + +cc.prepare_data( + input_data_file=os.path.join(tokenized_train_dir, "tokenized.dataset"), + output_directory=classifier_train_dir, + output_prefix="classifier", +) + +train_data = datasets.load_from_disk(classifier_train_dir + "/classifier_labeled.dataset") + +cc.train_classifier( + model_directory=model_dir, + num_classes=num_types, + train_data=train_data, + eval_data=None, + output_directory=classifier_fine_tuned_dir, + predict=False +) + +print(">>> Generating predictions...", flush=True) + +# dictionary mapping labels from classifier to cell types +with open(f"{classifier_train_dir}/classifier_id_class_dict.pkl", "rb") as f: + id_class_dict = pickle.load(f) + +with open(dictionary_files["token"], "rb") as f: + token_dict = pickle.load(f) + +# Load fine-tuned model +model = BertForSequenceClassification.from_pretrained(classifier_fine_tuned_dir) + +test_data = datasets.load_from_disk(tokenized_test_dir + "/tokenized.dataset") +test_data = test_data.add_column("label", [0] * len(test_data)) + +# Get predictions +trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification(token_dictionary=token_dict)) +predictions = trainer.predict(test_data) + +# Select the most likely cell type based on the probability vector from the predictions of each cell +predicted_label_ids = np.argmax(predictions.predictions, axis=1) +predicted_logits = [predictions.predictions[i][predicted_label_ids[i]] for i in range(len(predicted_label_ids))] +input_test.obs['label_pred'] = [id_class_dict[p] for p in predicted_label_ids] + +print(">>> Write output AnnData to file", flush=True) +output = ad.AnnData( + obs=input_test.obs[["label_pred"]], + uns={ + 'method_id': meta['name'], + 'dataset_id': input_test.uns['dataset_id'], + 'normalization_id': input_test.uns['normalization_id'] + } +) +output.write_h5ad(par['output'], compression='gzip') diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index a1c288e..ddc0359 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -80,6 +80,7 @@ dependencies: - name: control_methods/majority_vote - name: control_methods/random_labels - name: control_methods/true_labels + - name: methods/geneformer - name: methods/knn - name: methods/logistic_regression - name: methods/mlp diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index 8d8d30a..793e482 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -11,6 +11,7 @@ methods = [ majority_vote, random_labels, true_labels, + geneformer, knn, logistic_regression, mlp,