Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed GPU acceleration #25

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
60da0e0
added verbose flag to reduce print statements
keeminlee Jul 15, 2024
55b52c7
merging
keeminlee Jul 15, 2024
44036fb
merged
keeminlee Jul 15, 2024
7b2e32c
Merge branch 'main' of https://github.com/paninski-lab/eks into slp-c…
keeminlee Jul 15, 2024
40ab165
more verbose print removal
keeminlee Jul 15, 2024
fa6808b
Merge branch 'slp-converter' of https://github.com/paninski-lab/eks i…
keeminlee Jul 15, 2024
ac25be8
more, more verbose print removal
keeminlee Jul 16, 2024
576622d
Merge branch 'slp-converter' of https://github.com/paninski-lab/eks i…
keeminlee Jul 16, 2024
d3e80c9
outputs nll values
keeminlee Aug 21, 2024
cf99cac
zscore threshold set
keeminlee Aug 21, 2024
8af094b
eks scalar covariance inflation, initial pytest setup
keeminlee Oct 23, 2024
aa94ac3
removed SLEAP fish workaround
keeminlee Oct 23, 2024
6900f65
merge
keeminlee Oct 23, 2024
c7e23ff
Merge branch 'slp-converter' of https://github.com/paninski-lab/eks i…
keeminlee Oct 23, 2024
167a655
added posterior var to eks output csvs
keeminlee Oct 28, 2024
64cf946
merge
keeminlee Oct 28, 2024
754206f
ens var dynamic update fix
keeminlee Oct 31, 2024
df500e9
Merge branch 'slp-converter' of https://github.com/paninski-lab/eks i…
keeminlee Oct 31, 2024
747e58f
merge
keeminlee Oct 31, 2024
2dc4789
Merge branch 'slp-converter' of https://github.com/paninski-lab/eks i…
keeminlee Oct 31, 2024
64d1ad0
removed debug prints
keeminlee Oct 31, 2024
62e53b7
fixed zscore indexing
keeminlee Nov 1, 2024
4669928
merged
keeminlee Nov 1, 2024
3149b13
removed debug print for covariance scaling
keeminlee Nov 3, 2024
6ada39b
flake8
keeminlee Nov 3, 2024
aa25a02
pytests for core functions WIP
keeminlee Nov 5, 2024
24edd5e
pytests and refactoring for cleaner file i/o
keeminlee Nov 15, 2024
e0f0534
Delete scripts/plotting_aeks.py
keeminlee Nov 20, 2024
4619ab7
Delete tests/run_tests.py
keeminlee Nov 20, 2024
00aff94
Merged from main
keeminlee Nov 20, 2024
097df32
Merge branch 'keeminlee' of https://github.com/paninski-lab/eks into …
keeminlee Nov 20, 2024
e1ecc97
added comment for E_blocks
keeminlee Nov 20, 2024
446a472
merging
keeminlee Dec 19, 2024
c254d0a
mirrored and unmirrored multicam functions
keeminlee Jan 7, 2025
61cdccc
resolving PR edit requests: more tests + indentation fix
keeminlee Jan 7, 2025
78ab00d
print type format compatibility fix for s=10
keeminlee Jan 7, 2025
c90a1d8
GPU accel removed
keeminlee Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 0 additions & 160 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,166 +528,6 @@ def single_timestep_nll(innovation, innovation_cov):
nll_increment = 0.5 * jnp.abs(log_det_S + quadratic_term + c)
return nll_increment


# ----- Parallel Functions for GPU -----

def first_filtering_element(C, A, Q, R, m0, P0, y):
    """Build the first associative-scan element for the parallel Kalman filter.

    Combines the prior (m0, P0) with the first observation y to produce the
    tuple (A_updated, b, C_updated, J, eta) consumed by filtering_operator.
    A_updated is all zeros here because the first element has no predecessor
    state to propagate.

    Params:
        C: observation coefficient matrix (model.H)
        A: process coefficient matrix (model.F)
        Q: process noise covariance matrix
        R: observation noise covariance matrix
        m0: prior mean
        P0: prior covariance
        y: first observation

    Returns:
        tuple (A_updated, b, C_updated, J, eta)
    """
    # model.F = A, model.H = C,
    S = C @ Q @ C.T + R
    CF, low = jsc.linalg.cho_factor(S)  # note the jsc

    # One prediction step from the prior, then a standard Kalman update
    # with gain K1 against the first observation.
    m1 = A @ m0
    P1 = A @ P0 @ A.T + Q
    S1 = C @ P1 @ C.T + R
    K1 = jsc.linalg.solve(S1, C @ P1, assume_a='pos').T  # note the jsc

    A_updated = jnp.zeros_like(A)
    b = m1 + K1 @ (y - C @ m1)
    C_updated = P1 - K1 @ S1 @ K1.T

    # Information-form terms reused by the associative combination.
    # note the jsc
    eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y)
    J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A)
    return A_updated, b, C_updated, J, eta


def generic_filtering_element(C, A, Q, R, y):
    """Build one associative-scan element for a non-initial observation y.

    Produces the tuple (A_updated, b, C_updated, J, eta) for a single
    timestep; these are combined across time by filtering_operator. Unlike
    first_filtering_element, no prior is involved here.

    Params:
        C: observation coefficient matrix
        A: process coefficient matrix
        Q: process noise covariance matrix
        R: observation noise covariance matrix
        y: observation for this timestep

    Returns:
        tuple (A_updated, b, C_updated, J, eta)
    """
    S = C @ Q @ C.T + R
    CF, low = jsc.linalg.cho_factor(S)  # note the jsc
    K = jsc.linalg.cho_solve((CF, low), C @ Q).T  # note the jsc
    A_updated = A - K @ C @ A
    b = K @ y
    C_updated = Q - K @ C @ Q

    # Information-form terms reused by the associative combination.
    # note the jsc
    eta = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), y)
    J = A.T @ C.T @ jsc.linalg.cho_solve((CF, low), C @ A)
    return A_updated, b, C_updated, J, eta


def make_associative_filtering_elements(C, A, Q, R, m0, P0, observations):
    """Construct the per-timestep elements fed to the associative scan.

    The first timestep incorporates the prior (m0, P0); the remaining
    timesteps are built in parallel with vmap. Returns a tuple of stacked
    arrays — one per element component (A, b, C, J, eta) — each with a
    leading time dimension of len(observations).
    """
    first_elems = first_filtering_element(C, A, Q, R, m0, P0, observations[0])
    generic_elems = vmap(lambda o: generic_filtering_element(C, A, Q, R, o))(observations[1:])
    # Prepend the first element to each vmapped component stack.
    return tuple(jnp.concatenate([jnp.expand_dims(first_e, 0), gen_es])
                 for first_e, gen_es in zip(first_elems, generic_elems))


@partial(vmap)
def filtering_operator(elem1, elem2):
    """Associatively combine two filtering elements (A, b, C, J, eta).

    This is the binary operator handed to associative_scan in pkf; vmap
    applies it element-wise across the leading (time) axis of the stacked
    element arrays. NOTE(review): the algebra appears to follow the
    parallel-in-time Kalman filtering formulation — verify against the
    source derivation before modifying.
    """
    # # note the jsc everywhere
    A1, b1, C1, J1, eta1 = elem1
    A2, b2, C2, J2, eta2 = elem2
    dim = A1.shape[0]
    I_var = jnp.eye(dim)  # note the jnp

    # Combine the forward (mean/covariance) components.
    I_C1J2 = I_var + C1 @ J2
    temp = jsc.linalg.solve(I_C1J2.T, A2.T).T
    A = temp @ A1
    b = temp @ (b1 + C1 @ eta2) + b2
    C = temp @ C1 @ A2.T + C2

    # Combine the information-form (eta/J) components.
    I_J2C1 = I_var + J2 @ C1
    temp = jsc.linalg.solve(I_J2C1.T, A1).T

    eta = temp @ (eta2 - J2 @ b1) + eta1
    J = temp @ J2 @ A1 + J1

    return A, b, C, J, eta


def pkf(y, m0, cov0, A, Q, C, R):
    """Parallel Kalman filter.

    Builds one filtering element per timestep from the observations y and
    the prior (m0, cov0), then reduces them with an associative scan.
    Returns the scanned element tuple (A, b, C, J, eta), each component
    stacked along the leading time axis.
    """
    initial_elements = make_associative_filtering_elements(C, A, Q, R, m0, cov0, y)
    final_elements = associative_scan(filtering_operator, initial_elements)
    return final_elements


# JIT-compiled entry point for the parallel Kalman filter.
pkf_func = jit(pkf)


def get_kalman_means(A_scan, b_scan, m0):
    """
    Computes the Kalman mean at a single timepoint, the result is:
    A_scan @ m0 + b_scan

    Returned shape: (state_dimension, 1)
    """
    prior_col = jnp.expand_dims(m0, axis=1)
    offset_col = jnp.expand_dims(b_scan, axis=1)
    return A_scan @ prior_col + offset_col


def get_kalman_variances(C):
    """Return the filtered covariance for one timepoint.

    The scan output C is already the covariance, so this is a pass-through,
    kept for symmetry with get_kalman_means.
    """
    return C


def get_next_cov(A, C, Q, R, filter_cov, filter_mean):
    """
    Given the moments of p(x_t | y_1, ..., y_t) (normal filter distribution),
    compute the moments of the distribution for:
    p(y_{t+1} | y_1, ..., y_t)

    Params:
        A (np.ndarray): Shape (state_dimension, state_dimension) Process coeff matrix
        C (np.ndarray): Shape (obs_dimension, state_dimension) Observation coeff matrix
        Q (np.ndarray): Shape (state_dimension, state_dimension). Process noise covariance matrix.
        R (np.ndarray): Shape (obs_dimension, obs_dimension). Observation noise covariance matrix.
        filter_cov (np.ndarray). Shape (state_dimension, state_dimension). Filtered covariance
        filter_mean (np.ndarray). Shape (state_dimension, 1). Filter mean

    Returns:
        mean (np.ndarray). Shape (obs_dimension, 1)
        cov (np.ndarray). Shape (obs_dimension, obs_dimension).
    """
    # Predict the state covariance one step ahead, then map both moments
    # through the observation model.
    predicted_state_cov = A @ filter_cov @ A.T + Q
    predictive_mean = C @ A @ filter_mean
    predictive_cov = C @ predicted_state_cov @ C.T + R
    return predictive_mean, predictive_cov


def compute_marginal_nll(value, mean, covariance):
    """Negative log-density of `value` under a multivariate normal N(mean, covariance)."""
    log_density = jax.scipy.stats.multivariate_normal.logpdf(value, mean, covariance)
    return -1 * log_density


def parallel_loss_single(A_scan, b_scan, C_scan, A, C, Q, R, next_observation, m0):
    """Filtered moments at one timestep plus the NLL of the next observation.

    Recovers the filtered mean/covariance from the scan outputs, forms the
    one-step-ahead predictive distribution, and scores `next_observation`
    under it.

    Returns:
        (filtered_mean, filtered_cov, nll_of_next_observation)
    """
    curr_mean = get_kalman_means(A_scan, b_scan, m0)
    curr_cov = get_kalman_variances(C_scan)  # pass-through; returns C_scan unchanged

    next_mean, next_cov = get_next_cov(A, C, Q, R, curr_cov, curr_mean)
    return jnp.squeeze(curr_mean), curr_cov, compute_marginal_nll(jnp.squeeze(next_observation),
                                                                  jnp.squeeze(next_mean), next_cov)


# Vectorize parallel_loss_single over time (axis 0 of the scan outputs and
# the observations), holding the model matrices and m0 fixed; then JIT.
parallel_loss_func_vmap = jit(
    vmap(parallel_loss_single, in_axes=(0, 0, 0, None, None, None, None, 0, None),
         out_axes=(0, 0, 0)))


@partial(jit)
def y1_given_x0_nll(C, A, Q, R, m0, cov0, obs):
y1_predictive_mean = C @ A @ jnp.expand_dims(m0, axis=1)
y1_predictive_cov = C @ (A @ cov0 @ A.T + Q) @ C.T + R
addend = -1 * jax.scipy.stats.multivariate_normal.logpdf(obs, jnp.squeeze(y1_predictive_mean),
y1_predictive_cov)
return addend


def pkf_and_loss(y, m0, cov0, A, Q, C, R):
    """Run the parallel Kalman filter and accumulate the observation NLL.

    Params:
        y: stacked observations; y[0] is scored against the prior, the rest
            against their one-step-ahead predictive distributions.
        m0, cov0: prior mean and covariance.
        A, Q: process coefficient and process noise covariance matrices.
        C, R: observation coefficient and observation noise covariance matrices.

    Returns:
        filtered_states: filtered means, one row per timestep.
        filtered_variances: filtered covariances, one per timestep.
        Scalar total negative log-likelihood: the summed per-step NLLs plus
        the NLL of the first observation given the prior.
    """
    A_scan, b_scan, C_scan, _, _ = pkf_func(y, m0, cov0, A, Q, C, R)

    # Gives us the NLL for p(y_i | y_1, ..., y_{i-1}) for i > 1.
    # Need to use the parallel scan outputs for this. i = 1 handled below
    filtered_states, filtered_covariances, losses = parallel_loss_func_vmap(A_scan[:-1],
                                                                            b_scan[:-1],
                                                                            C_scan[:-1], A, C, Q,
                                                                            R, y[1:], m0)

    # Gives us the NLL for p_y(y_1 | x_0)
    addend = y1_given_x0_nll(C, A, Q, R, m0, cov0, y[0])

    # The vmapped loss skips the final timestep; append its mean/covariance here.
    final_mean = get_kalman_means(A_scan[-1], b_scan[-1], m0).T
    final_covariance = jnp.expand_dims(get_kalman_variances(C_scan[-1]), axis=0)
    filtered_states = jnp.concatenate([filtered_states, final_mean], axis=0)
    filtered_variances = jnp.concatenate([filtered_covariances, final_covariance], axis=0)
    return filtered_states, filtered_variances, jnp.sum(losses) + addend


# -------------------------------------------------------------------------------------
# Misc: These miscellaneous functions generally have specific computations used by the
# core functions or the smoothers
Expand Down
1 change: 0 additions & 1 deletion eks/ibl_paw_multiview_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from eks.core import backward_pass, eks_zscore, ensemble, forward_pass
from eks.utils import make_dlc_pandas_index


# TODO:
# - allow conf_weighted_mean for ensemble variance computation

Expand Down
7 changes: 4 additions & 3 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd
from scipy.optimize import minimize

from eks.core import backward_pass, compute_nll, eks_zscore, ensemble, forward_pass
from eks.utils import crop_frames, make_dlc_pandas_index, format_data
from eks.core import backward_pass, compute_nll, ensemble, forward_pass
from eks.utils import crop_frames, format_data, make_dlc_pandas_index


def get_pupil_location(dlc):
Expand Down Expand Up @@ -276,7 +276,8 @@ def ensemble_kalman_smoother_ibl_pupil(
# compute zscore for EKS to see how it deviates from the ensemble
# eks_predictions = \
# np.asarray([processed_arr_dict[key_pair[0]], processed_arr_dict[key_pair[1]]]).T
# ensemble_preds_curr = ensemble_preds[:, ensemble_indices[i][0]: ensemble_indices[i][1] + 1]
# ensemble_preds_curr = ensemble_preds
# [:, ensemble_indices[i][0]: ensemble_indices[i][1] + 1]
# ensemble_vars_curr = ensemble_vars[:, ensemble_indices[i][0]: ensemble_indices[i][1] + 1]
# zscore, _ = eks_zscore(
# eks_predictions,
Expand Down
Loading
Loading