Add geneformer (#7)
* add geneformer

* add to workflow

* inherit from base method

* fix variable name

* fix mapping

* cleanup script

* update info

* minor changes
sainirmayi authored Jan 6, 2025
1 parent 359f30c commit ec64db5
Showing 4 changed files with 261 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/methods/geneformer/config.vsh.yaml
@@ -0,0 +1,46 @@
__merge__: ../../api/base_method.yaml

name: geneformer
label: Geneformer
summary: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single-cell transcriptomes, enabling context-aware predictions in settings with limited data in network biology.
description: |
  Geneformer is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. Here, a pre-trained Geneformer model is fine-tuned and used to predict cell type labels for an unlabelled dataset.
info:
  preferred_normalization: counts
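  # Geneformer tokenizes raw counts itself, so no upstream normalization is applied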

references:
  doi:
    - 10.1038/s41586-023-06139-9
    - 10.1101/2024.08.16.608180

links:
  documentation: https://geneformer.readthedocs.io/en/latest/index.html
  repository: https://huggingface.co/ctheodoris/Geneformer

arguments:
  - name: "--model"
    type: "string"
    description: String representing the Geneformer model to use
    choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"]
    default: "gf-12L-95M-i4096"

resources:
  - type: python_script
    path: script.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pip:
          - pyarrow<15.0.0a0,>=14.0.1
          - huggingface_hub
          - git+https://huggingface.co/ctheodoris/Geneformer.git

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
213 changes: 213 additions & 0 deletions src/methods/geneformer/script.py
@@ -0,0 +1,213 @@
import os
from tempfile import TemporaryDirectory
import anndata as ad
from geneformer import Classifier, TranscriptomeTokenizer, DataCollatorForCellClassification
from huggingface_hub import hf_hub_download
import numpy as np
import datasets
import pickle
from transformers import BertForSequenceClassification, Trainer

## VIASH START
par = {
    'input_train': 'resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad',
    'input_test': 'resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad',
    'output': 'output.h5ad',
    "model": "gf-12L-95M-i4096",
}
meta = {
    'name': 'geneformer'
}
## VIASH END

n_processors = os.cpu_count()

print('>>> Reading input files', flush=True)
input_train = ad.read_h5ad(par['input_train'])
input_test = ad.read_h5ad(par['input_test'])

if input_train.uns["dataset_organism"] != "homo_sapiens":
    raise ValueError(
        f"Geneformer can only be used with human data "
        f"(dataset_organism == '{input_train.uns['dataset_organism']}')"
    )

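# Geneformer's gene dictionaries are keyed by Ensembl IDs, so require them up front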
is_ensembl = all(var_name.startswith("ENSG") for var_name in input_train.var_names)
if not is_ensembl:
    raise ValueError("Geneformer requires input_train.var_names to contain Ensembl gene IDs")

print(f">>> Getting settings for model '{par['model']}'...", flush=True)
model_split = par["model"].split("-")
model_details = {
    "layers": model_split[1],
    "dataset": model_split[2],
    "input_size": int(model_split[3][1:]),
}
print(model_details, flush=True)

print(">>> Getting model dictionary files...", flush=True)
if model_details["dataset"] == "95M":
    dictionaries_subfolder = "geneformer"
elif model_details["dataset"] == "30M":
    dictionaries_subfolder = "geneformer/gene_dictionaries_30m"
else:
    raise ValueError(f"Invalid model dataset: {model_details['dataset']}")
print(f"Dictionaries subfolder: '{dictionaries_subfolder}'", flush=True)

dictionary_files = {
    "ensembl_mapping": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=dictionaries_subfolder,
        filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl",
    ),
    "gene_median": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=dictionaries_subfolder,
        filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl",
    ),
    "gene_name_id": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=dictionaries_subfolder,
        filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl",
    ),
    "token": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=dictionaries_subfolder,
        filename=f"token_dictionary_gc{model_details['dataset']}.pkl",
    ),
}

print(">>> Creating working directory...", flush=True)
work_dir = TemporaryDirectory()

input_train_dir = os.path.join(work_dir.name, "input_train")
os.makedirs(input_train_dir)
tokenized_train_dir = os.path.join(work_dir.name, "tokenized_train")
os.makedirs(tokenized_train_dir)
classifier_train_dir = os.path.join(work_dir.name, "classifier_train")
os.makedirs(classifier_train_dir)
classifier_fine_tuned_dir = os.path.join(work_dir.name, "classifier_fine_tuned")
os.makedirs(classifier_fine_tuned_dir)
input_test_dir = os.path.join(work_dir.name, "input_test")
os.makedirs(input_test_dir)
tokenized_test_dir = os.path.join(work_dir.name, "tokenized_test")
os.makedirs(tokenized_test_dir)

print(f"Working directory: '{work_dir.name}'", flush=True)

print(f">>> Getting model files for model '{par['model']}'...", flush=True)
model_files = {
    "model": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=par["model"],
        filename="model.safetensors",
    ),
    "config": hf_hub_download(
        repo_id="ctheodoris/Geneformer",
        subfolder=par["model"],
        filename="config.json",
    ),
}
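# Files from the same repository snapshot share a cache directory, so the parent
# folder of the downloaded weights also contains config.json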
model_dir = os.path.dirname(model_files["model"])

print(">>> Preparing input data...", flush=True)
input_train.X = input_train.layers["counts"]
input_train.var["ensembl_id"] = input_train.var["feature_id"]
input_train.obs["n_counts"] = input_train.layers["counts"].sum(axis=1)
input_train.obs["celltype"] = input_train.obs["label"]
num_types = len(input_train.obs["celltype"].unique())
input_train.write_h5ad(os.path.join(input_train_dir, "input_train.h5ad"))

input_test.X = input_test.layers["counts"]
input_test.var["ensembl_id"] = input_test.var["feature_id"]
input_test.obs["n_counts"] = input_test.layers["counts"].sum(axis=1)
input_test.write_h5ad(os.path.join(input_test_dir, "input_test.h5ad"))

print(">>> Tokenizing train data...", flush=True)
special_token = model_details["dataset"] == "95M"
print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
tokenizer = TranscriptomeTokenizer(
    custom_attr_name_dict={"celltype": "celltype"},
    nproc=n_processors,
    model_input_size=model_details["input_size"],
    special_token=special_token,
    gene_median_file=dictionary_files["gene_median"],
    token_dictionary_file=dictionary_files["token"],
    gene_mapping_file=dictionary_files["ensembl_mapping"],
)
tokenizer.tokenize_data(input_train_dir, tokenized_train_dir, "tokenized", file_format="h5ad")

print(">>> Tokenizing test data...", flush=True)
special_token = model_details["dataset"] == "95M"
print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
tokenizer = TranscriptomeTokenizer(
    nproc=n_processors,
    model_input_size=model_details["input_size"],
    special_token=special_token,
    gene_median_file=dictionary_files["gene_median"],
    token_dictionary_file=dictionary_files["token"],
    gene_mapping_file=dictionary_files["ensembl_mapping"],
)
tokenizer.tokenize_data(input_test_dir, tokenized_test_dir, "tokenized", file_format="h5ad")

print(">>> Fine-tuning pre-trained geneformer model for cell state classification...", flush=True)
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "celltype", "states": "all"},
    nproc=n_processors,
    token_dictionary_file=dictionary_files["token"],
    num_crossval_splits=1,
    split_sizes={"train": 0.9, "valid": 0.1, "test": 0.0},
)

cc.prepare_data(
    input_data_file=os.path.join(tokenized_train_dir, "tokenized.dataset"),
    output_directory=classifier_train_dir,
    output_prefix="classifier",
)
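# prepare_data writes "<output_prefix>_labeled.dataset" (label-encoded cells) and
# "<output_prefix>_id_class_dict.pkl" (mapping from label ids back to cell type names)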

train_data = datasets.load_from_disk(os.path.join(classifier_train_dir, "classifier_labeled.dataset"))

cc.train_classifier(
    model_directory=model_dir,
    num_classes=num_types,
    train_data=train_data,
    eval_data=None,
    output_directory=classifier_fine_tuned_dir,
    predict=False,
)

print(">>> Generating predictions...", flush=True)

# dictionary mapping labels from classifier to cell types
with open(f"{classifier_train_dir}/classifier_id_class_dict.pkl", "rb") as f:
id_class_dict = pickle.load(f)

with open(dictionary_files["token"], "rb") as f:
token_dict = pickle.load(f)

# Load fine-tuned model
model = BertForSequenceClassification.from_pretrained(classifier_fine_tuned_dir)

test_data = datasets.load_from_disk(os.path.join(tokenized_test_dir, "tokenized.dataset"))
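# The data collator requires a "label" column, so add dummy labels for the unlabelled test cells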
test_data = test_data.add_column("label", [0] * len(test_data))

# Get predictions
trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification(token_dictionary=token_dict))
predictions = trainer.predict(test_data)

# Select the most likely cell type for each cell from the predicted class probabilities
predicted_label_ids = np.argmax(predictions.predictions, axis=1)
input_test.obs['label_pred'] = [id_class_dict[p] for p in predicted_label_ids]

print(">>> Write output AnnData to file", flush=True)
output = ad.AnnData(
    obs=input_test.obs[["label_pred"]],
    uns={
        'method_id': meta['name'],
        'dataset_id': input_test.uns['dataset_id'],
        'normalization_id': input_test.uns['normalization_id']
    }
)
output.write_h5ad(par['output'], compression='gzip')
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -80,6 +80,7 @@ dependencies:
  - name: control_methods/majority_vote
  - name: control_methods/random_labels
  - name: control_methods/true_labels
  - name: methods/geneformer
  - name: methods/knn
  - name: methods/logistic_regression
  - name: methods/mlp
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/main.nf
@@ -11,6 +11,7 @@ methods = [
  majority_vote,
  random_labels,
  true_labels,
  geneformer,
  knn,
  logistic_regression,
  mlp,
