Skip to content

Commit

Permalink
panic-safe delta lookback decoding (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwlon authored Dec 8, 2024
1 parent 1a7d18a commit 6abb0f8
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 35 deletions.
2 changes: 2 additions & 0 deletions docs/format.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ and the deltas `[0, 10, 0]` would decode to the latents `[1, 3, 5, 17, 29]`.
Letting `lookback` be the delta latent variable,
mode latents are decoded via `l[i] += l[i - lookback[i]]`.

The decompressor should return an error if any lookback exceeds the window size.

### Modes

Based on the mode, latents are joined into the finalized numbers.
Expand Down
79 changes: 62 additions & 17 deletions pco/src/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,34 +256,53 @@ pub fn new_lookback_window_buffer_and_pos<L: Latent>(
(res, window_n)
}

// returns the new position
// returns whether it was corrupt
/// Decodes one batch of lookback-delta-encoded latents in place.
///
/// Each decoded value is `latent.wrapping_add(window_buffer[pos - lookback])`,
/// written into `window_buffer` at `pos`. Afterwards `latents` is overwritten
/// with the slice of `window_buffer` shifted back by `state_n`, and
/// `*window_buffer_pos` is advanced by `latents.len()`.
///
/// Returns `true` if any lookback was greater than `window_n`. Such a
/// lookback is replaced by 1 so decoding can continue without reading out of
/// the window; the caller is expected to treat a `true` result as corruption.
pub fn decode_with_lookbacks_in_place<L: Latent>(
config: DeltaLookbackConfig,
lookbacks: &[DeltaLookback],
window_buffer_pos: &mut usize,
window_buffer: &mut [L],
latents: &mut [L],
) {
) -> bool {
toggle_center_in_place(latents);

let (window_n, state_n) = (config.window_n(), config.state_n());
let mut pos = *window_buffer_pos;
let mut start_pos = *window_buffer_pos;
// Lookbacks can be shorter than latents in the final batch,
// but we always decompress latents.len() numbers
let batch_n = latents.len();
if pos + batch_n > window_buffer.len() {
if start_pos + batch_n > window_buffer.len() {
// we need to cycle the buffer
// i.e. move the most recent window_n elements to the front of the buffer
// and continue writing right after them
for i in 0..window_n {
window_buffer[i] = window_buffer[i + pos - window_n];
}
pos = window_n;
window_buffer.copy_within(start_pos - window_n..start_pos, 0);
start_pos = window_n;
}
let mut has_oob_lookbacks = false;

for (i, (&latent, &lookback)) in latents.iter().zip(lookbacks).enumerate() {
window_buffer[pos + i] = latent.wrapping_add(window_buffer[pos + i - lookback as usize]);
let pos = start_pos + i;
// Here we return whether the data is corrupt because it's
// better than the alternatives:
// * Taking min(lookback, window_n) or modulo is just as slow but silences
// the problem.
// * Doing a checked set is slower, panics, and get doesn't catch all
// cases.
let lookback = if lookback <= window_n as DeltaLookback {
lookback as usize
} else {
has_oob_lookbacks = true;
1
};
// SAFETY (review note): `lookback <= window_n` at this point, so
// `pos - lookback` cannot underflow provided `start_pos >= window_n`.
// `i < batch_n`, so `pos < start_pos + batch_n`, which the cycling branch
// above is meant to keep within `window_buffer.len()`. Both conditions
// depend on how the buffer was sized and positioned by the caller --
// TODO confirm against `new_lookback_window_buffer_and_pos`.
unsafe {
*window_buffer.get_unchecked_mut(pos) =
latent.wrapping_add(*window_buffer.get_unchecked(pos - lookback));
}
}

let new_pos = pos + batch_n;
latents.copy_from_slice(&window_buffer[pos - state_n..new_pos - state_n]);
*window_buffer_pos = new_pos;
let end_pos = start_pos + batch_n;
// The output is read `state_n` positions behind the write position.
latents.copy_from_slice(&window_buffer[start_pos - state_n..end_pos - state_n]);
*window_buffer_pos = end_pos;

has_oob_lookbacks
}

pub fn compute_delta_latent_var(
Expand Down Expand Up @@ -375,6 +394,10 @@ mod tests {
state_n_log: 1,
secondary_uses_delta: false,
};
let window_n = config.window_n();
assert_eq!(window_n, 16);
let state_n = config.state_n();
assert_eq!(state_n, 2);

let mut deltas = original_latents.clone();
let lookbacks = choose_lookbacks(config, &original_latents);
Expand All @@ -389,21 +412,43 @@ mod tests {
// Encoding left junk deltas at the front,
// but for decoding we need junk deltas at the end.
let mut deltas_to_decode = Vec::<u32>::new();
deltas_to_decode.extend(&deltas[2..]);
for _ in 0..2 {
deltas_to_decode.extend(&deltas[state_n..]);
for _ in 0..state_n {
deltas_to_decode.push(1337);
}

let (mut window_buffer, mut pos) = new_lookback_window_buffer_and_pos(config, &state);
assert_eq!(pos, 16);
decode_with_lookbacks_in_place(
assert_eq!(pos, window_n);
let has_oob_lookbacks = decode_with_lookbacks_in_place(
config,
&lookbacks,
&mut pos,
&mut window_buffer,
&mut deltas_to_decode,
);
assert!(!has_oob_lookbacks);
assert_eq!(deltas_to_decode, original_latents);
assert_eq!(pos, 16 + original_latents.len());
assert_eq!(pos, window_n + original_latents.len());
}

#[test]
fn test_corrupt_lookbacks_do_not_panic() {
  // The first lookback (5) is deliberately too large for this config
  // (presumably `window_n_log: 2` means a window of 4 -- confirm against
  // `DeltaLookbackConfig::window_n`). Decoding must report it via the return
  // value instead of panicking.
  let cfg = DeltaLookbackConfig {
    state_n_log: 0,
    window_n_log: 2,
    secondary_uses_delta: false,
  };
  let initial_state = vec![0_u32];
  let (mut buffer, mut buffer_pos) = new_lookback_window_buffer_and_pos(cfg, &initial_state);

  let corrupt_lookbacks = vec![5, 1, 1, 1];
  let mut batch = vec![1_u32, 2, 3, 4];
  let flagged = decode_with_lookbacks_in_place(
    cfg,
    &corrupt_lookbacks,
    &mut buffer_pos,
    &mut buffer,
    &mut batch,
  );
  assert!(flagged);
}
}
18 changes: 13 additions & 5 deletions pco/src/latent_page_decompressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::ans::{AnsState, Spec};
use crate::bit_reader::BitReader;
use crate::constants::{Bitlen, DeltaLookback, ANS_INTERLEAVING, FULL_BATCH_N};
use crate::data_types::Latent;
use crate::errors::PcoResult;
use crate::errors::{PcoError, PcoResult};
use crate::metadata::{bins, Bin, DeltaEncoding, DynLatents};
use crate::{ans, bit_reader, delta, read_write_uint};

Expand Down Expand Up @@ -282,7 +282,7 @@ impl<L: Latent> LatentPageDecompressor<L> {
n_remaining_in_page: usize,
reader: &mut BitReader,
dst: &mut [L],
) {
) -> PcoResult<()> {
let n_remaining_pre_delta =
n_remaining_in_page.saturating_sub(self.delta_encoding.n_latents_per_state());
let pre_delta_len = if dst.len() <= n_remaining_pre_delta {
Expand All @@ -297,12 +297,13 @@ impl<L: Latent> LatentPageDecompressor<L> {
self.decompress_batch_pre_delta(reader, &mut dst[..pre_delta_len]);

match self.delta_encoding {
DeltaEncoding::None => (),
DeltaEncoding::None => Ok(()),
DeltaEncoding::Consecutive(_) => {
delta::decode_consecutive_in_place(&mut self.state.delta_state, dst)
delta::decode_consecutive_in_place(&mut self.state.delta_state, dst);
Ok(())
}
DeltaEncoding::Lookback(config) => {
delta::decode_with_lookbacks_in_place(
let has_oob_lookbacks = delta::decode_with_lookbacks_in_place(
config,
delta_latents
.unwrap()
Expand All @@ -312,6 +313,13 @@ impl<L: Latent> LatentPageDecompressor<L> {
&mut self.state.delta_state,
dst,
);
if has_oob_lookbacks {
Err(PcoError::corruption(
"delta lookback exceeded window n",
))
} else {
Ok(())
}
}
}
}
Expand Down
26 changes: 25 additions & 1 deletion pco/src/metadata/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use better_io::BetterBufRead;

use crate::bit_reader::BitReaderBuilder;
use crate::bit_writer::BitWriter;
use crate::constants::DeltaLookback;
use crate::data_types::LatentType;
use crate::errors::PcoResult;
use crate::errors::{PcoError, PcoResult};
use crate::metadata::chunk_latent_var::ChunkLatentVarMeta;
use crate::metadata::delta_encoding::DeltaEncoding;
use crate::metadata::format_version::FormatVersion;
Expand Down Expand Up @@ -50,6 +51,29 @@ impl ChunkMeta {
bit_size.div_ceil(8)
}

/// Checks that the delta latent variable's bins are consistent with the
/// chunk's delta encoding.
///
/// With lookback delta encoding, every bin's lower bound must lie in
/// `[1, window_n]`; the first bin that violates this produces a corruption
/// error. With no delta encoding or consecutive delta encoding (and no delta
/// latent variable) there is nothing to check. Only bin lower bounds are
/// inspected here.
///
/// Any other pairing of delta encoding and delta latent variable hits
/// `unreachable!()`.
pub(crate) fn validate_delta_encoding(&self) -> PcoResult<()> {
  let (config, latent_var) = match (self.delta_encoding, &self.per_latent_var.delta) {
    (DeltaEncoding::Lookback(config), Some(latent_var)) => (config, latent_var),
    (DeltaEncoding::None, None) | (DeltaEncoding::Consecutive(_), None) => return Ok(()),
    _ => unreachable!(),
  };

  let window_n = config.window_n() as DeltaLookback;
  let bins = latent_var.bins.downcast_ref::<DeltaLookback>().unwrap();
  for bin in bins.iter() {
    // A valid lookback lower bound is at least 1 and at most the window size.
    if !(1..=window_n).contains(&bin.lower) {
      return Err(PcoError::corruption(format!(
        "delta lookback bin had invalid lower bound of {} outside window [1, {}]",
        bin.lower, window_n
      )));
    }
  }
  Ok(())
}

pub(crate) unsafe fn read_from<R: BetterBufRead>(
reader_builder: &mut BitReaderBuilder<R>,
version: &FormatVersion,
Expand Down
17 changes: 9 additions & 8 deletions pco/src/wrapped/chunk_decompressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ pub struct ChunkDecompressor<T: Number> {

impl<T: Number> ChunkDecompressor<T> {
/// Builds a chunk decompressor from chunk metadata.
///
/// Returns a corruption error if `meta.mode` is not valid for the number
/// type `T`, and propagates any error from `meta.validate_delta_encoding()`.
pub(crate) fn new(meta: ChunkMeta) -> PcoResult<Self> {
if T::mode_is_valid(meta.mode) {
Ok(Self {
meta,
phantom: PhantomData,
})
} else {
Err(PcoError::corruption(format!(
if !T::mode_is_valid(meta.mode) {
return Err(PcoError::corruption(format!(
"invalid mode for data type: {:?}",
meta.mode
)))
)));
}
// Reject metadata whose delta encoding is inconsistent before constructing.
meta.validate_delta_encoding()?;

Ok(Self {
meta,
phantom: PhantomData,
})
}

/// Returns pre-computed information about the chunk.
Expand Down
6 changes: 2 additions & 4 deletions pco/src/wrapped/page_decompressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ impl<T: Number, R: BetterBufRead> PageDecompressor<T, R> {
n_remaining,
reader,
primary_dst,
);
Ok(())
)
})?;

// SECONDARY LATENTS
Expand All @@ -226,8 +225,7 @@ impl<T: Number, R: BetterBufRead> PageDecompressor<T, R> {
&mut dst.downcast_mut::<L>().unwrap()[..batch_n]
)
}
);
Ok(())
)
})?;
}

Expand Down

0 comments on commit 6abb0f8

Please sign in to comment.