Skip to content

Commit

Permalink
fix: respect aliasing rule by not reading past the end of a reference
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Jan 20, 2025
1 parent 6af8bdd commit 62aa62b
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 213 deletions.
145 changes: 75 additions & 70 deletions src/datatype/memory_pgvector_halfvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use std::ops::Deref;
use std::marker::PhantomData;
use std::ptr::NonNull;
use vector::VectorBorrowed;
use vector::vect::VectBorrowed;
Expand All @@ -18,7 +18,7 @@ pub struct PgvectorHalfvecHeader {
varlena: u32,
dims: u16,
unused: u16,
phantom: [f16; 0],
elements: [f16; 0],
}

impl PgvectorHalfvecHeader {
Expand All @@ -28,49 +28,52 @@ impl PgvectorHalfvecHeader {
}
(size_of::<Self>() + size_of::<f16>() * len).next_multiple_of(8)
}
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
pub unsafe fn as_borrowed<'a>(this: NonNull<Self>) -> VectBorrowed<'a, f16> {
unsafe {
let this = this.as_ptr();
VectBorrowed::new_unchecked(std::slice::from_raw_parts(
self.phantom.as_ptr(),
self.dims as usize,
(&raw const (*this).elements).cast(),
(&raw const (*this).dims).read() as usize,
))
}
}
}

pub enum PgvectorHalfvecInput<'a> {
Owned(PgvectorHalfvecOutput),
Borrowed(&'a PgvectorHalfvecHeader),
}
/// A halfvec argument, held as a raw pointer to its header rather than as a reference.
///
/// * Field 0 is the pointer returned by `pg_detoast_datum` in `from_ptr`.
/// * Field 1 is a zero-sized marker that ties the value to the lifetime `'a`
///   without storing an actual reference.
/// * Field 2 is `true` when detoasting returned a pointer different from the one
///   passed in (`p != q`); `Drop` calls `pfree` on field 0 only in that case.
pub struct PgvectorHalfvecInput<'a>(NonNull<PgvectorHalfvecHeader>, PhantomData<&'a ()>, bool);

impl PgvectorHalfvecInput<'_> {
unsafe fn new(p: NonNull<PgvectorHalfvecHeader>) -> Self {
unsafe fn from_ptr(p: NonNull<PgvectorHalfvecHeader>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
};
if p != q {
PgvectorHalfvecInput::Owned(PgvectorHalfvecOutput(q))
} else {
unsafe { PgvectorHalfvecInput::Borrowed(p.as_ref()) }
}
PgvectorHalfvecInput(q, PhantomData, p != q)
}
/// Returns a borrowed view of the vector's elements, valid for as long as `self` is borrowed.
///
/// Delegates to `PgvectorHalfvecHeader::as_borrowed` with the wrapped header pointer.
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
// NOTE(review): soundness relies on `self.0` being the pointer produced by
// `pg_detoast_datum` in `from_ptr` — confirm no other constructor exists.
unsafe { PgvectorHalfvecHeader::as_borrowed(self.0) }
}
}

impl Deref for PgvectorHalfvecInput<'_> {
type Target = PgvectorHalfvecHeader;

fn deref(&self) -> &Self::Target {
match self {
PgvectorHalfvecInput::Owned(x) => x,
PgvectorHalfvecInput::Borrowed(x) => x,
impl Drop for PgvectorHalfvecInput<'_> {
fn drop(&mut self) {
if self.2 {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr().cast());
}
}
}
}

/// An owned halfvec value: a non-null pointer to a `PgvectorHalfvecHeader` allocation.
///
/// Dropping the value releases the allocation with `pfree`; `into_raw` hands the
/// pointer out without freeing it.
pub struct PgvectorHalfvecOutput(NonNull<PgvectorHalfvecHeader>);

impl PgvectorHalfvecOutput {
pub fn new(vector: VectBorrowed<'_, f16>) -> PgvectorHalfvecOutput {
/// Builds an owned value from a datum pointer by asking PostgreSQL for a detoasted copy.
///
/// # Safety
///
/// `p` is handed straight to `pg_detoast_datum_copy`, so it must be a pointer that
/// function accepts.
///
/// # Panics
///
/// Panics if `pg_detoast_datum_copy` returns a null pointer.
unsafe fn from_ptr(p: NonNull<PgvectorHalfvecHeader>) -> Self {
    let copied = unsafe { pgrx::pg_sys::pg_detoast_datum_copy(p.as_ptr().cast()) };
    Self(NonNull::new(copied.cast()).unwrap())
}
#[allow(dead_code)]
pub fn new(vector: VectBorrowed<'_, f16>) -> Self {
unsafe {
let slice = vector.slice();
let size = PgvectorHalfvecHeader::size_of(slice.len());
Expand All @@ -79,47 +82,61 @@ impl PgvectorHalfvecOutput {
(&raw mut (*ptr).varlena).write((size << 2) as u32);
(&raw mut (*ptr).dims).write(vector.dims() as _);
(&raw mut (*ptr).unused).write(0);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
PgvectorHalfvecOutput(NonNull::new(ptr).unwrap())
std::ptr::copy_nonoverlapping(
slice.as_ptr(),
(&raw mut (*ptr).elements).cast(),
slice.len(),
);
Self(NonNull::new(ptr).unwrap())
}
}
/// Returns a borrowed view of the vector's elements, valid for as long as `self` is borrowed.
///
/// Delegates to `PgvectorHalfvecHeader::as_borrowed` with the wrapped header pointer.
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
unsafe { PgvectorHalfvecHeader::as_borrowed(self.0) }
}
/// Consumes the value and returns the raw header pointer without freeing it.
///
/// The destructor (which calls `pfree`) is suppressed, so the allocation stays
/// live after this call.
pub fn into_raw(self) -> *mut PgvectorHalfvecHeader {
    // `ManuallyDrop` prevents `Drop::drop` from running, just like `mem::forget`.
    let this = std::mem::ManuallyDrop::new(self);
    this.0.as_ptr()
}
}

impl Deref for PgvectorHalfvecOutput {
type Target = PgvectorHalfvecHeader;

fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
}
}

impl Drop for PgvectorHalfvecOutput {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr() as _);
pgrx::pg_sys::pfree(self.0.as_ptr().cast());
}
}
}

// FromDatum

impl FromDatum for PgvectorHalfvecInput<'_> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<PgvectorHalfvecHeader>()).unwrap();
unsafe { Some(PgvectorHalfvecInput::new(ptr)) }
let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap();
unsafe { Some(Self::from_ptr(ptr)) }
}
}
}

impl FromDatum for PgvectorHalfvecOutput {
    /// Converts a datum into an owned halfvec value.
    ///
    /// Returns `None` when `is_null` is set. Otherwise the datum is reinterpreted as
    /// a header pointer and passed to `Self::from_ptr`. The type OID is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the datum is flagged non-null but its pointer is null.
    unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
        if is_null {
            return None;
        }
        let header = NonNull::new(datum.cast_mut_ptr()).unwrap();
        Some(unsafe { Self::from_ptr(header) })
    }
}

// IntoDatum

impl IntoDatum for PgvectorHalfvecOutput {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
Some(Datum::from(self.into_raw()))
}

fn type_oid() -> Oid {
Expand All @@ -131,46 +148,23 @@ impl IntoDatum for PgvectorHalfvecOutput {
}
}

impl FromDatum for PgvectorHalfvecOutput {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let p = NonNull::new(datum.cast_mut_ptr::<PgvectorHalfvecHeader>())?;
let q =
unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? };
if p != q {
Some(PgvectorHalfvecOutput(q))
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
Some(PgvectorHalfvecOutput::new(vector))
}
}
}
}
// UnboxDatum

unsafe impl pgrx::datum::UnboxDatum for PgvectorHalfvecOutput {
type As<'src> = PgvectorHalfvecOutput;
#[inline]
unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src>
unsafe fn unbox<'src>(datum: pgrx::datum::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::<PgvectorHalfvecHeader>()).unwrap();
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
PgvectorHalfvecOutput(q)
} else {
let header = p.as_ptr();
let vector = unsafe { (*header).as_borrowed() };
PgvectorHalfvecOutput::new(vector)
}
let datum = datum.sans_lifetime();
let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap();
unsafe { Self::from_ptr(ptr) }
}
}

// SqlTranslatable

unsafe impl SqlTranslatable for PgvectorHalfvecInput<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("halfvec")))
Expand All @@ -189,17 +183,28 @@ unsafe impl SqlTranslatable for PgvectorHalfvecOutput {
}
}

// ArgAbi

unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for PgvectorHalfvecInput<'fcx> {
unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self {
unsafe { arg.unbox_arg_using_from_datum().unwrap() }
let index = arg.index();
unsafe {
arg.unbox_arg_using_from_datum()
.unwrap_or_else(|| panic!("argument {index} must not be null"))
}
}
}

// BoxRet

unsafe impl pgrx::callconv::BoxRet for PgvectorHalfvecOutput {
unsafe fn box_into<'fcx>(
self,
fcinfo: &mut pgrx::callconv::FcInfo<'fcx>,
) -> pgrx::datum::Datum<'fcx> {
unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) }
match self.into_datum() {
Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
None => fcinfo.return_null(),
}
}
}
Loading

0 comments on commit 62aa62b

Please sign in to comment.