* add geneformer
* add to workflow
* inherit from base method
* fix variable name
* fix mapping
* cleanup script
* update info
* minor changes

Commit ec64db5 (parent: 359f30c)

Showing 4 changed files with 261 additions and 0 deletions.

New file: Viash method configuration

@@ -0,0 +1,46 @@
__merge__: ../../api/base_method.yaml

name: geneformer
label: Geneformer
summary: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
description: |
  Geneformer is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. Here, a pre-trained Geneformer model is fine-tuned and used to predict cell type labels for an unlabelled dataset.
info:
  preferred_normalization: counts

references:
  doi:
    - 10.1038/s41586-023-06139-9
    - 10.1101/2024.08.16.608180

links:
  documentation: https://geneformer.readthedocs.io/en/latest/index.html
  repository: https://huggingface.co/ctheodoris/Geneformer

arguments:
  - name: "--model"
    type: "string"
    description: The Geneformer model variant to use
    choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"]
    default: "gf-12L-95M-i4096"

resources:
  - type: python_script
    path: script.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pip:
          - pyarrow<15.0.0a0,>=14.0.1
          - huggingface_hub
          - git+https://huggingface.co/ctheodoris/Geneformer.git

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]

New file: script.py

@@ -0,0 +1,213 @@
import os
import pickle
from tempfile import TemporaryDirectory

import anndata as ad
import datasets
import numpy as np
from geneformer import Classifier, TranscriptomeTokenizer, DataCollatorForCellClassification
from huggingface_hub import hf_hub_download
from transformers import BertForSequenceClassification, Trainer
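
# `par` and `meta` are injected by Viash at runtime; the values between
# "VIASH START" and "VIASH END" are placeholders used for local testing.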
## VIASH START
par = {
    "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad",
    "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad",
    "output": "output.h5ad",
    "model": "gf-12L-95M-i4096",
}
meta = {"name": "geneformer"}
## VIASH END

n_processors = os.cpu_count()

print(">>> Reading input files", flush=True)
input_train = ad.read_h5ad(par["input_train"])
input_test = ad.read_h5ad(par["input_test"])

if input_train.uns["dataset_organism"] != "homo_sapiens":
    raise ValueError(
        "Geneformer can only be used with human data "
        f"(dataset_organism == '{input_train.uns['dataset_organism']}')"
    )

is_ensembl = all(var_name.startswith("ENSG") for var_name in input_train.var_names)
if not is_ensembl:
    raise ValueError("Geneformer requires input_train.var_names to contain Ensembl gene IDs")
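
# Parse the model name, e.g. "gf-12L-95M-i4096" splits into layers "12L",
# pretraining dataset "95M" and input size 4096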
print(f">>> Getting settings for model '{par['model']}'...", flush=True) | ||
model_split = par["model"].split("-") | ||
model_details = { | ||
"layers": model_split[1], | ||
"dataset": model_split[2], | ||
"input_size": int(model_split[3][1:]), | ||
} | ||
print(model_details, flush=True) | ||
|
||
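
# Geneformer ships corpus-specific dictionary files: gene median values (used
# for rank-value normalization), a token vocabulary, a gene name -> ID map and
# an Ensembl mapping dictionary. The 30M and 95M corpora have separate sets.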
print(">>> Getting model dictionary files...", flush=True) | ||
if model_details["dataset"] == "95M": | ||
dictionaries_subfolder = "geneformer" | ||
elif model_details["dataset"] == "30M": | ||
dictionaries_subfolder = "geneformer/gene_dictionaries_30m" | ||
else: | ||
raise ValueError(f"Invalid model dataset: {model_details['dataset']}") | ||
print(f"Dictionaries subfolder: '{dictionaries_subfolder}'") | ||
|
||
dictionary_files = { | ||
"ensembl_mapping": hf_hub_download( | ||
repo_id="ctheodoris/Geneformer", | ||
subfolder=dictionaries_subfolder, | ||
filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl", | ||
), | ||
"gene_median": hf_hub_download( | ||
repo_id="ctheodoris/Geneformer", | ||
subfolder=dictionaries_subfolder, | ||
filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl", | ||
), | ||
"gene_name_id": hf_hub_download( | ||
repo_id="ctheodoris/Geneformer", | ||
subfolder=dictionaries_subfolder, | ||
filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl", | ||
), | ||
"token": hf_hub_download( | ||
repo_id="ctheodoris/Geneformer", | ||
subfolder=dictionaries_subfolder, | ||
filename=f"token_dictionary_gc{model_details['dataset']}.pkl", | ||
), | ||
} | ||

print(">>> Creating working directory...", flush=True)
work_dir = TemporaryDirectory()

input_train_dir = os.path.join(work_dir.name, "input_train")
os.makedirs(input_train_dir)
tokenized_train_dir = os.path.join(work_dir.name, "tokenized_train")
os.makedirs(tokenized_train_dir)
classifier_train_dir = os.path.join(work_dir.name, "classifier_train")
os.makedirs(classifier_train_dir)
classifier_fine_tuned_dir = os.path.join(work_dir.name, "classifier_fine_tuned")
os.makedirs(classifier_fine_tuned_dir)
input_test_dir = os.path.join(work_dir.name, "input_test")
os.makedirs(input_test_dir)
tokenized_test_dir = os.path.join(work_dir.name, "tokenized_test")
os.makedirs(tokenized_test_dir)

print(f"Working directory: '{work_dir.name}'", flush=True)

print(f">>> Getting model files for model '{par['model']}'...", flush=True)
model_files = {
    "model": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=par["model"],
        filename="model.safetensors",
    ),
    "config": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=par["model"],
        filename="config.json",
    ),
}
model_dir = os.path.dirname(model_files["model"])
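
# Geneformer's tokenizer expects raw counts in .X, an "ensembl_id" column in
# .var and an "n_counts" column in .obs; "celltype" is carried along so the
# classifier can use it as the label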
print(">>> Preparing input data...", flush=True) | ||
input_train.X = input_train.layers["counts"] | ||
input_train.var["ensembl_id"] = input_train.var["feature_id"] | ||
input_train.obs["n_counts"] = input_train.layers["counts"].sum(axis=1) | ||
input_train.obs["celltype"] = input_train.obs["label"] | ||
num_types = len(input_train.obs["celltype"].unique()) | ||
input_train.write_h5ad(os.path.join(input_train_dir, "input_train.h5ad")) | ||
|
||
input_test.X = input_test.layers["counts"] | ||
input_test.var["ensembl_id"] = input_test.var["feature_id"] | ||
input_test.obs["n_counts"] = input_test.layers["counts"].sum(axis=1) | ||
input_test.write_h5ad(os.path.join(input_test_dir, "input_test.h5ad")) | ||
|
||
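
# Tokenization rank-value encodes each cell: genes are ranked by expression
# normalized by the corpus-wide gene medians, then truncated to the model's
# input size. The 95M-corpus models additionally use special <cls>/<eos> tokens.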
print(">>> Tokenizing train data...", flush=True) | ||
special_token = model_details["dataset"] == "95M" | ||
print(f"Input size: {model_details['input_size']}, Special token: {special_token}") | ||
tokenizer = TranscriptomeTokenizer( | ||
custom_attr_name_dict={"celltype": "celltype"}, | ||
nproc=n_processors, | ||
model_input_size=model_details["input_size"], | ||
special_token=special_token, | ||
gene_median_file=dictionary_files["gene_median"], | ||
token_dictionary_file=dictionary_files["token"], | ||
gene_mapping_file=dictionary_files["ensembl_mapping"], | ||
) | ||
tokenizer.tokenize_data(input_train_dir, tokenized_train_dir, "tokenized", file_format="h5ad") | ||
|
||
print(">>> Tokenizing test data...", flush=True) | ||
special_token = model_details["dataset"] == "95M" | ||
print(f"Input size: {model_details['input_size']}, Special token: {special_token}") | ||
tokenizer = TranscriptomeTokenizer( | ||
nproc=n_processors, | ||
model_input_size=model_details["input_size"], | ||
special_token=special_token, | ||
gene_median_file=dictionary_files["gene_median"], | ||
token_dictionary_file=dictionary_files["token"], | ||
gene_mapping_file=dictionary_files["ensembl_mapping"], | ||
) | ||
tokenizer.tokenize_data(input_test_dir, tokenized_test_dir, "tokenized", file_format="h5ad") | ||
|
||
print(">>> Fine-tuning pre-trained geneformer model for cell state classification...", flush=True) | ||
cc = Classifier( | ||
classifier="cell", | ||
cell_state_dict = {"state_key": "celltype", "states": "all"}, | ||
nproc=n_processors, | ||
token_dictionary_file=dictionary_files["token"], | ||
num_crossval_splits=1, | ||
split_sizes={"train": 0.9, "valid": 0.1, "test": 0.0}, | ||
) | ||
|
||
cc.prepare_data( | ||
input_data_file=os.path.join(tokenized_train_dir, "tokenized.dataset"), | ||
output_directory=classifier_train_dir, | ||
output_prefix="classifier", | ||
) | ||
|
||
train_data = datasets.load_from_disk(classifier_train_dir + "/classifier_labeled.dataset") | ||
|
||
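
# Fine-tune the pretrained checkpoint with a sequence-classification head over
# the cell type labels; no separate evaluation set is passed here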
cc.train_classifier(
    model_directory=model_dir,
    num_classes=num_types,
    train_data=train_data,
    eval_data=None,
    output_directory=classifier_fine_tuned_dir,
    predict=False,
)

print(">>> Generating predictions...", flush=True)

# Dictionary mapping classifier label IDs back to cell type names
with open(f"{classifier_train_dir}/classifier_id_class_dict.pkl", "rb") as f:
    id_class_dict = pickle.load(f)

with open(dictionary_files["token"], "rb") as f:
    token_dict = pickle.load(f)

# Load the fine-tuned model
model = BertForSequenceClassification.from_pretrained(classifier_fine_tuned_dir)
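
# The data collator requires a "label" column, so add dummy labels to the
# unlabelled test set before predicting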
test_data = datasets.load_from_disk(os.path.join(tokenized_test_dir, "tokenized.dataset"))
test_data = test_data.add_column("label", [0] * len(test_data))

# Get predictions
trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification(token_dictionary=token_dict))
predictions = trainer.predict(test_data)

# Select the most likely cell type for each cell from the predicted logits
predicted_label_ids = np.argmax(predictions.predictions, axis=1)
# Logit of the selected class per cell (not used downstream)
predicted_logits = [predictions.predictions[i][predicted_label_ids[i]] for i in range(len(predicted_label_ids))]
input_test.obs["label_pred"] = [id_class_dict[p] for p in predicted_label_ids]
print(">>> Write output AnnData to file", flush=True) | ||
output = ad.AnnData( | ||
obs=input_test.obs[["label_pred"]], | ||
uns={ | ||
'method_id': meta['name'], | ||
'dataset_id': input_test.uns['dataset_id'], | ||
'normalization_id': input_test.uns['normalization_id'] | ||
} | ||
) | ||
output.write_h5ad(par['output'], compression='gzip') |
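
As a quick sanity check, the resulting file can be inspected with a few lines of Python (a minimal sketch, assuming the script was run with the default output path above):

    import anndata as ad

    # Load the predictions written by the component
    output = ad.read_h5ad("output.h5ad")

    # One predicted cell type per test cell, plus the metadata set above
    print(output.obs["label_pred"].value_counts())
    print(output.uns["method_id"], output.uns["dataset_id"], output.uns["normalization_id"])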

Modified: benchmark workflow (adds geneformer to the methods list)

@@ -11,6 +11,7 @@ methods = [
   majority_vote,
   random_labels,
   true_labels,
+  geneformer,
   knn,
   logistic_regression,
   mlp,