Skip to content

Change how guide assignment model saves params from uns to var #758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def assign_mixture_model(
multiple_grna_assigned_key: str = "multiple",
multiple_grna_assignment_string: str = "+",
only_return_results: bool = False,
uns_key: str = "guide_assignment_params",
show_progress: bool = False,
**mixture_model_kwargs,
) -> np.ndarray | None:
Expand All @@ -227,7 +226,6 @@ def assign_mixture_model(
multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
only_return_results: Whether input AnnData is not modified and the result is returned as an np.ndarray.
uns_key: Key to store guide assignment parameters in.
show_progress: Whether to shows progress bar.
mixture_model_kwargs: Are passed to the mixture model.

Expand All @@ -243,11 +241,6 @@ def assign_mixture_model(
else:
raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

if uns_key not in adata.uns:
adata.uns[uns_key] = {}
elif type(adata.uns[uns_key]) is not dict:
raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")

res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
fct = track if show_progress else lambda iterable: iterable
for gene in fct(adata.var_names):
Expand All @@ -271,7 +264,18 @@ def assign_mixture_model(
data = np.log2(data)
assignments = mixture_model.run_model(data)
res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
adata.uns[uns_key][gene] = mixture_model.params

# Add the parameters to the adata.var DataFrame
for params_name, param in mixture_model.params.items():
if param.ndim == 0:
if params_name not in adata.var.columns:
adata.var[params_name] = np.nan
adata.var.loc[gene, params_name] = param.item()
else:
for i, p in enumerate(param):
if f"{params_name}_{i}" not in adata.var.columns:
adata.var[f"{params_name}_{i}"] = np.nan
adata.var.loc[gene, f"{params_name}_{i}"] = p

# Assign guides to cells
# Some cells might have multiple guides assigned
Expand Down
Loading