diff --git a/sqlx-postgres/src/bind_iter.rs b/sqlx-postgres/src/bind_iter.rs
new file mode 100644
index 0000000000..0f44f19e3d
--- /dev/null
+++ b/sqlx-postgres/src/bind_iter.rs
@@ -0,0 +1,154 @@
+use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
+use core::cell::Cell;
+use sqlx_core::{
+    database::Database,
+    encode::{Encode, IsNull},
+    error::BoxDynError,
+    types::Type,
+};
+
+// not exported but pub because it is used in the extension trait
+pub struct PgBindIter<I>(Cell<Option<I>>);
+
+/// Iterator extension trait enabling iterators to encode arrays in Postgres.
+///
+/// Because of the blanket impl of `PgHasArrayType` for all references
+/// we can borrow instead of needing to clone or copy in the iterators
+/// and it still works
+///
+/// Previously, 3 separate arrays would be needed in this example which
+/// requires iterating 3 times to collect items into the array and then
+/// iterating over them again to encode.
+///
+/// This now requires only iterating over the array once for each field
+/// while using less memory giving both speed and memory usage improvements
+/// along with allowing much more flexibility in the underlying collection.
+///
+/// ```rust,no_run
+/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
+/// # use sqlx::types::chrono::{DateTime, Utc};
+/// # use sqlx::Connection;
+/// # fn people() -> &'static [Person] {
+/// #    &[]
+/// # }
+/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
+/// use sqlx::postgres::PgBindIterExt;
+///
+/// #[derive(sqlx::FromRow)]
+/// struct Person {
+///     id: i64,
+///     name: String,
+///     birthdate: DateTime<Utc>,
+/// }
+///
+/// # let people: &[Person] = people();
+/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
+///     .bind(people.iter().map(|p| p.id).bind_iter())
+///     .bind(people.iter().map(|p| &p.name).bind_iter())
+///     .bind(people.iter().map(|p| &p.birthdate).bind_iter())
+///     .execute(&mut conn)
+///     .await?;
+///
+/// # Ok(())
+/// # }
+/// ```
+pub trait PgBindIterExt: Iterator + Sized {
+    fn bind_iter(self) -> PgBindIter<Self>;
+}
+
+impl<I: Iterator + Sized> PgBindIterExt for I {
+    fn bind_iter(self) -> PgBindIter<I> {
+        PgBindIter(Cell::new(Some(self)))
+    }
+}
+
+impl<I> Type<Postgres> for PgBindIter<I>
+where
+    I: Iterator,
+    <I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
+{
+    fn type_info() -> <Postgres as Database>::TypeInfo {
+        <I as Iterator>::Item::array_type_info()
+    }
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        <I as Iterator>::Item::array_compatible(ty)
+    }
+}
+
+impl<'q, I> PgBindIter<I>
+where
+    I: Iterator,
+    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
+{
+    fn encode_inner(
+        // need ownership to iterate
+        mut iter: I,
+        buf: &mut PgArgumentBuffer,
+    ) -> Result<IsNull, BoxDynError> {
+        let lower_size_hint = iter.size_hint().0;
+        let first = iter.next();
+        let type_info = first
+            .as_ref()
+            .and_then(Encode::produces)
+            .unwrap_or_else(<I as Iterator>::Item::type_info);
+
+        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
+        buf.extend(&0_i32.to_be_bytes()); // flags
+
+        match type_info.0 {
+            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
+            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
+
+            ty => {
+                buf.extend(&ty.oid().0.to_be_bytes());
+            }
+        }
+
+        let len_start = buf.len();
+        buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
+        buf.extend(1_i32.to_be_bytes()); // lower bound
+
+        match first {
+            Some(first) => buf.encode(first)?,
+            None => return Ok(IsNull::No),
+        }
+
+        let mut count = 1_i32;
+        const MAX: usize = i32::MAX as usize - 1;
+
+        for value in (&mut iter).take(MAX) {
+            buf.encode(value)?;
+            count += 1;
+        }
+
+        const OVERFLOW: usize = i32::MAX as usize + 1;
+        if iter.next().is_some() {
+            let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
+            return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
+        }
+
+        // set the length now that we know what it is.
+        buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
+
+        Ok(IsNull::No)
+    }
+}
+
+impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
+where
+    I: Iterator,
+    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
+{
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
+    }
+    fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
+    where
+        Self: Sized,
+    {
+        Self::encode_inner(
+            self.0.into_inner().expect("PgBindIter is only used once"),
+            buf,
+        )
+    }
+}
diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs
index c50f53067e..76b0eb3206 100644
--- a/sqlx-postgres/src/lib.rs
+++ b/sqlx-postgres/src/lib.rs
@@ -7,6 +7,7 @@ use crate::executor::Executor;
 
 mod advisory_lock;
 mod arguments;
+mod bind_iter;
 mod column;
 mod connection;
 mod copy;
@@ -44,6 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;
 
 pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
 pub use arguments::{PgArgumentBuffer, PgArguments};
+pub use bind_iter::PgBindIterExt;
 pub use column::PgColumn;
 pub use connection::PgConnection;
 pub use copy::{PgCopyIn, PgPoolCopyExt};
diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs
index 9b8be63412..372c2891a8 100644
--- a/sqlx-postgres/src/types/array.rs
+++ b/sqlx-postgres/src/types/array.rs
@@ -5,7 +5,6 @@ use std::borrow::Cow;
 use crate::decode::Decode;
 use crate::encode::{Encode, IsNull};
 use crate::error::BoxDynError;
-use crate::type_info::PgType;
 use crate::types::Oid;
 use crate::types::Type;
 use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
@@ -156,39 +155,14 @@ where
     T: Encode<'q, Postgres> + Type<Postgres>,
 {
     fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
-        let type_info = self
-            .first()
-            .and_then(Encode::produces)
-            .unwrap_or_else(T::type_info);
-
-        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
-        buf.extend(&0_i32.to_be_bytes()); // flags
-
-        // element type
-        match type_info.0 {
-            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
-            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
-
-            ty => {
-                buf.extend(&ty.oid().0.to_be_bytes());
-            }
-        }
-
-        let array_len = i32::try_from(self.len()).map_err(|_| {
+        // do the length check early to avoid doing unnecessary work
+        i32::try_from(self.len()).map_err(|_| {
             format!(
                 "encoded array length is too large for Postgres: {}",
                 self.len()
             )
         })?;
-
-        buf.extend(array_len.to_be_bytes()); // len
-        buf.extend(&1_i32.to_be_bytes()); // lower bound
-
-        for element in self.iter() {
-            buf.encode(element)?;
-        }
-
-        Ok(IsNull::No)
+        crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
     }
 }
diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs
index 87a18db510..e1c2f086d5 100644
--- a/tests/postgres/postgres.rs
+++ b/tests/postgres/postgres.rs
@@ -2042,3 +2042,61 @@ async fn test_issue_3052() {
         "expected encode error, got {too_large_error:?}",
     );
 }
+
+#[sqlx_macros::test]
+async fn test_bind_iter() -> anyhow::Result<()> {
+    use sqlx::postgres::PgBindIterExt;
+    use sqlx::types::chrono::{DateTime, Utc};
+
+    let mut conn = new::<Postgres>().await?;
+
+    #[derive(sqlx::FromRow, PartialEq, Debug)]
+    struct Person {
+        id: i64,
+        name: String,
+        birthdate: DateTime<Utc>,
+    }
+
+    let people: Vec<Person> = vec![
+        Person {
+            id: 1,
+            name: "Alice".into(),
+            birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
+        },
+        Person {
+            id: 2,
+            name: "Bob".into(),
+            birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
+        },
+    ];
+
+    sqlx::query(
+        r#"
+create temporary table person(
+    id int8 primary key,
+    name text not null,
+    birthdate timestamptz not null
+)"#,
+    )
+    .execute(&mut conn)
+    .await?;
+
+    let rows_affected =
+        sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
+            // owned value
+            .bind(people.iter().map(|p| p.id).bind_iter())
+            // borrowed value
+            .bind(people.iter().map(|p| &p.name).bind_iter())
+            .bind(people.iter().map(|p| &p.birthdate).bind_iter())
+            .execute(&mut conn)
+            .await?
+            .rows_affected();
+    assert_eq!(rows_affected, 2);
+
+    let p_query = sqlx::query_as::<_, Person>("select * from person order by id")
+        .fetch_all(&mut conn)
+        .await?;
+
+    assert_eq!(people, p_query);
+    Ok(())
+}