Skip to content

Commit

Permalink
fix: respect aliasing rule by not reading past of reference
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Jan 17, 2025
1 parent e26e5e3 commit f8a1e73
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 84 deletions.
51 changes: 24 additions & 27 deletions src/datatype/memory_pgvector_halfvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use std::ops::Deref;
use std::marker::PhantomData;
use std::ptr::NonNull;
use vector::VectorBorrowed;
use vector::vect::VectBorrowed;
Expand All @@ -28,19 +28,19 @@ impl PgvectorHalfvecHeader {
}
(size_of::<Self>() + size_of::<f16>() * len).next_multiple_of(8)
}
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
pub unsafe fn as_borrowed<'a>(this: *const Self) -> VectBorrowed<'a, f16> {
unsafe {
VectBorrowed::new_unchecked(std::slice::from_raw_parts(
self.phantom.as_ptr(),
self.dims as usize,
(&raw const (*this).phantom).cast(),
(*this).dims as usize,
))
}
}
}

pub enum PgvectorHalfvecInput<'a> {
Owned(PgvectorHalfvecOutput),
Borrowed(&'a PgvectorHalfvecHeader),
Borrowed(NonNull<PgvectorHalfvecHeader>, PhantomData<&'a ()>),
}

impl PgvectorHalfvecInput<'_> {
Expand All @@ -51,19 +51,17 @@ impl PgvectorHalfvecInput<'_> {
if p != q {
PgvectorHalfvecInput::Owned(PgvectorHalfvecOutput(q))
} else {
unsafe { PgvectorHalfvecInput::Borrowed(p.as_ref()) }
Self::Borrowed(p, PhantomData)
}
}
}

impl Deref for PgvectorHalfvecInput<'_> {
type Target = PgvectorHalfvecHeader;

fn deref(&self) -> &Self::Target {
match self {
PgvectorHalfvecInput::Owned(x) => x,
PgvectorHalfvecInput::Borrowed(x) => x,
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
let header = match self {
Self::Owned(x) => x.0,
Self::Borrowed(x, _) => *x,
}
.as_ptr()
.cast_const();
unsafe { PgvectorHalfvecHeader::as_borrowed(header) }
}
}

Expand All @@ -79,7 +77,11 @@ impl PgvectorHalfvecOutput {
(&raw mut (*ptr).varlena).write((size << 2) as u32);
(&raw mut (*ptr).dims).write(vector.dims() as _);
(&raw mut (*ptr).unused).write(0);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
std::ptr::copy_nonoverlapping(
slice.as_ptr(),
(&raw mut (*ptr).phantom).cast(),
slice.len(),
);
PgvectorHalfvecOutput(NonNull::new(ptr).unwrap())
}
}
Expand All @@ -88,13 +90,8 @@ impl PgvectorHalfvecOutput {
std::mem::forget(self);
result
}
}

impl Deref for PgvectorHalfvecOutput {
type Target = PgvectorHalfvecHeader;

fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
unsafe { PgvectorHalfvecHeader::as_borrowed(self.0.as_ptr().cast_const()) }
}
}

Expand Down Expand Up @@ -142,8 +139,8 @@ impl FromDatum for PgvectorHalfvecOutput {
if p != q {
Some(PgvectorHalfvecOutput(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { PgvectorHalfvecHeader::as_borrowed(header) };
Some(PgvectorHalfvecOutput::new(vector))
}
}
Expand All @@ -164,8 +161,8 @@ unsafe impl pgrx::datum::UnboxDatum for PgvectorHalfvecOutput {
if p != q {
PgvectorHalfvecOutput(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { PgvectorHalfvecHeader::as_borrowed(header) };
PgvectorHalfvecOutput::new(vector)
}
}
Expand Down
51 changes: 24 additions & 27 deletions src/datatype/memory_pgvector_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use std::ops::Deref;
use std::marker::PhantomData;
use std::ptr::NonNull;
use vector::VectorBorrowed;
use vector::vect::VectBorrowed;
Expand All @@ -27,19 +27,19 @@ impl PgvectorVectorHeader {
}
(size_of::<Self>() + size_of::<f32>() * len).next_multiple_of(8)
}
pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> {
pub unsafe fn as_borrowed<'a>(this: *const Self) -> VectBorrowed<'a, f32> {
unsafe {
VectBorrowed::new_unchecked(std::slice::from_raw_parts(
self.phantom.as_ptr(),
self.dims as usize,
(&raw const (*this).phantom).cast(),
(*this).dims as usize,
))
}
}
}

pub enum PgvectorVectorInput<'a> {
Owned(PgvectorVectorOutput),
Borrowed(&'a PgvectorVectorHeader),
Borrowed(NonNull<PgvectorVectorHeader>, PhantomData<&'a ()>),
}

impl PgvectorVectorInput<'_> {
Expand All @@ -50,19 +50,17 @@ impl PgvectorVectorInput<'_> {
if p != q {
PgvectorVectorInput::Owned(PgvectorVectorOutput(q))
} else {
unsafe { PgvectorVectorInput::Borrowed(p.as_ref()) }
Self::Borrowed(p, PhantomData)
}
}
}

impl Deref for PgvectorVectorInput<'_> {
type Target = PgvectorVectorHeader;

fn deref(&self) -> &Self::Target {
match self {
PgvectorVectorInput::Owned(x) => x,
PgvectorVectorInput::Borrowed(x) => x,
pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> {
let header = match self {
Self::Owned(x) => x.0,
Self::Borrowed(x, _) => *x,
}
.as_ptr()
.cast_const();
unsafe { PgvectorVectorHeader::as_borrowed(header) }
}
}

Expand All @@ -78,7 +76,11 @@ impl PgvectorVectorOutput {
(&raw mut (*ptr).varlena).write((size << 2) as u32);
(&raw mut (*ptr).dims).write(vector.dims() as _);
(&raw mut (*ptr).unused).write(0);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
std::ptr::copy_nonoverlapping(
slice.as_ptr(),
(&raw mut (*ptr).phantom).cast(),
slice.len(),
);
PgvectorVectorOutput(NonNull::new(ptr).unwrap())
}
}
Expand All @@ -87,13 +89,8 @@ impl PgvectorVectorOutput {
std::mem::forget(self);
result
}
}

impl Deref for PgvectorVectorOutput {
type Target = PgvectorVectorHeader;

fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> {
unsafe { PgvectorVectorHeader::as_borrowed(self.0.as_ptr().cast_const()) }
}
}

Expand Down Expand Up @@ -141,8 +138,8 @@ impl FromDatum for PgvectorVectorOutput {
if p != q {
Some(PgvectorVectorOutput(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { PgvectorVectorHeader::as_borrowed(header) };
Some(PgvectorVectorOutput::new(vector))
}
}
Expand All @@ -163,8 +160,8 @@ unsafe impl pgrx::datum::UnboxDatum for PgvectorVectorOutput {
if p != q {
PgvectorVectorOutput(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { PgvectorVectorHeader::as_borrowed(header) };
PgvectorVectorOutput::new(vector)
}
}
Expand Down
60 changes: 30 additions & 30 deletions src/datatype/memory_scalar8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use std::ops::Deref;
use std::marker::PhantomData;
use std::ptr::NonNull;
use vector::VectorBorrowed;
use vector::scalar8::Scalar8Borrowed;
Expand All @@ -31,22 +31,25 @@ impl Scalar8Header {
}
(size_of::<Self>() + size_of::<u8>() * len).next_multiple_of(8)
}
pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> {
pub unsafe fn as_borrowed<'a>(this: *const Self) -> Scalar8Borrowed<'a> {
unsafe {
Scalar8Borrowed::new_unchecked(
self.sum_of_x2,
self.k,
self.b,
self.sum_of_code,
std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize),
(*this).sum_of_x2,
(*this).k,
(*this).b,
(*this).sum_of_code,
std::slice::from_raw_parts(
(&raw const (*this).phantom).cast(),
(*this).dims as usize,
),
)
}
}
}

pub enum Scalar8Input<'a> {
Owned(Scalar8Output),
Borrowed(&'a Scalar8Header),
Borrowed(NonNull<Scalar8Header>, PhantomData<&'a ()>),
}

impl Scalar8Input<'_> {
Expand All @@ -57,19 +60,17 @@ impl Scalar8Input<'_> {
if p != q {
Scalar8Input::Owned(Scalar8Output(q))
} else {
unsafe { Scalar8Input::Borrowed(p.as_ref()) }
Self::Borrowed(p, PhantomData)
}
}
}

impl Deref for Scalar8Input<'_> {
type Target = Scalar8Header;

fn deref(&self) -> &Self::Target {
match self {
Scalar8Input::Owned(x) => x,
Scalar8Input::Borrowed(x) => x,
pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> {
let header = match self {
Self::Owned(x) => x.0,
Self::Borrowed(x, _) => *x,
}
.as_ptr()
.cast_const();
unsafe { Scalar8Header::as_borrowed(header) }
}
}

Expand All @@ -89,7 +90,11 @@ impl Scalar8Output {
(&raw mut (*ptr).k).write(vector.k());
(&raw mut (*ptr).b).write(vector.b());
(&raw mut (*ptr).sum_of_code).write(vector.sum_of_code());
std::ptr::copy_nonoverlapping(code.as_ptr(), (*ptr).phantom.as_mut_ptr(), code.len());
std::ptr::copy_nonoverlapping(
code.as_ptr(),
(&raw mut (*ptr).phantom).cast(),
code.len(),
);
Scalar8Output(NonNull::new(ptr).unwrap())
}
}
Expand All @@ -98,13 +103,8 @@ impl Scalar8Output {
std::mem::forget(self);
result
}
}

impl Deref for Scalar8Output {
type Target = Scalar8Header;

fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> {
unsafe { Scalar8Header::as_borrowed(self.0.as_ptr().cast_const()) }
}
}

Expand Down Expand Up @@ -152,8 +152,8 @@ impl FromDatum for Scalar8Output {
if p != q {
Some(Scalar8Output(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { Scalar8Header::as_borrowed(header) };
Some(Scalar8Output::new(vector))
}
}
Expand All @@ -174,8 +174,8 @@ unsafe impl pgrx::datum::UnboxDatum for Scalar8Output {
if p != q {
Scalar8Output(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
let header = p.as_ptr().cast_const();
let vector = unsafe { Scalar8Header::as_borrowed(header) };
Scalar8Output::new(vector)
}
}
Expand Down

0 comments on commit f8a1e73

Please sign in to comment.