From 6887da74a3311a858b6b74f58a1fba5ab0b19891 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 14:51:00 -0800 Subject: [PATCH 01/13] chore: update Cargo.lock These updates are forced as soon as I build. --- Cargo.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 10e24131bd..04969c2af2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -3972,7 +3972,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -4801,7 +4801,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4842,7 +4842,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4885,5 +4885,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] From 494dc12dde5dff58365c86d43446d99fd2f6c96a Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 14:35:54 -0800 Subject: [PATCH 02/13] feat: add Connection::begin_with This patch completes the plumbing of an optional statement from these methods to `TransactionManager::begin` without any validation of the provided statement. There is a new `Error::InvalidSavePoint` which is triggered by any attempt to call `Connection::begin_with` when we are already inside of a transaction. --- sqlx-core/src/acquire.rs | 4 ++-- sqlx-core/src/any/connection/backend.rs | 9 ++++++++- sqlx-core/src/any/connection/mod.rs | 13 ++++++++++++- sqlx-core/src/any/transaction.rs | 8 ++++++-- sqlx-core/src/connection.rs | 11 +++++++++++ sqlx-core/src/error.rs | 3 +++ sqlx-core/src/pool/connection.rs | 2 +- sqlx-core/src/pool/mod.rs | 8 ++++++-- sqlx-core/src/transaction.rs | 18 ++++++++++++----- sqlx-mysql/src/any.rs | 8 ++++++-- sqlx-mysql/src/connection/mod.rs | 13 ++++++++++++- sqlx-mysql/src/transaction.rs | 16 +++++++++++---- sqlx-postgres/src/any.rs | 8 ++++++-- sqlx-postgres/src/connection/mod.rs | 13 ++++++++++++- sqlx-postgres/src/transaction.rs | 15 +++++++++++--- sqlx-sqlite/src/any.rs | 9 +++++++-- sqlx-sqlite/src/connection/mod.rs | 13 ++++++++++++- sqlx-sqlite/src/connection/worker.rs | 26 +++++++++++++++++++++---- sqlx-sqlite/src/transaction.rs | 8 ++++++-- 19 files changed, 169 insertions(+), 36 deletions(-) diff --git a/sqlx-core/src/acquire.rs b/sqlx-core/src/acquire.rs index c9d7fb215c..59bac9fa59 100644 --- a/sqlx-core/src/acquire.rs +++ b/sqlx-core/src/acquire.rs @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool { let conn = self.acquire(); Box::pin(async move { - Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await + Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await }) } } @@ -121,7 +121,7 @@ macro_rules! impl_acquire { 'c, Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>, > { - $crate::transaction::Transaction::begin(self) + $crate::transaction::Transaction::begin(self, None) } } }; diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f3..2fe9ed7656 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -3,6 +3,7 @@ use crate::describe::Describe; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848a..8cf8fc510c 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnectOptions}; use crate::connection::{ConnectOptions, Connection}; @@ -87,7 +88,17 @@ impl Connection for AnyConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce4175626..4972268499 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,4 +1,5 @@ use futures_util::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnection}; use crate::error::Error; @@ -9,8 +10,11 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { - conn.backend.begin() + fn begin<'conn>( + conn: &'conn mut AnyConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + conn.backend.begin(statement) } fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c629..de0a05799d 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -4,6 +4,7 @@ use crate::error::Error; use crate::transaction::Transaction; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; @@ -49,6 +50,16 @@ pub trait Connection: Send { where Self: Sized; + /// Begin a new transaction with a custom statement. + /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized; + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774addd2..8b454575e9 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -111,6 +111,9 @@ pub enum Error { #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), + + #[error("attempted to call begin_with at non-zero transaction depth")] + InvalidSavePointStatement, } impl StdError for Box {} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index bf3a6d4b1c..c029fec6eb 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection futures_core::future::BoxFuture<'c, Result, Error>> { - crate::transaction::Transaction::begin(&mut **self) + crate::transaction::Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index e998618413..438eebf6c1 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -367,13 +367,17 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction. pub async fn begin(&self) -> Result, Error> { - Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + None, + ) + .await } /// Attempts to retrieve a connection and immediately begins a new transaction if successful. pub async fn try_begin(&self) -> Result>, Error> { match self.try_acquire() { - Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn)) + Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None) .await .map(Some), diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3a..d9459c53d4 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -16,9 +16,16 @@ pub trait TransactionManager { type Database: Database; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin( - conn: &mut ::Connection, - ) -> BoxFuture<'_, Result<(), Error>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin<'conn>( + conn: &'conn mut ::Connection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>>; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -83,11 +90,12 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, + statement: Option>, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); Box::pin(async move { - DB::TransactionManager::begin(&mut conn).await?; + DB::TransactionManager::begin(&mut conn, statement).await?; Ok(Self { connection: conn, @@ -237,7 +245,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' #[inline] fn begin(self) -> BoxFuture<'t, Result, Error>> { - Transaction::begin(&mut **self) + Transaction::begin(&mut **self, None) } } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 0466bfc0a4..96190f0bd2 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::borrow::Cow; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - MySqlTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + MySqlTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a7701..e2c671046d 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; @@ -111,7 +112,17 @@ impl Connection for MySqlConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index d8538cc2b3..f287c4a80b 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::connection::Waiting; @@ -14,12 +16,18 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut MySqlConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.inner.transaction_depth = depth + 1; + if statement.is_some() && depth > 0 { + return Err(Error::InvalidSavePointStatement); + } + let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + conn.execute(&*statement).await?; + conn.inner.transaction_depth += 1; Ok(()) }) diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index efa9a044bc..d189301c13 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::borrow::Cow; use std::future; use sqlx_core::any::{ @@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - PgTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e53d..04b9a4c9e2 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -179,7 +180,17 @@ impl Connection for PgConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index e7c78488eb..767d83c52e 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::error::Error; use crate::executor::Executor; @@ -13,11 +14,19 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut PgConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { + let depth = conn.inner.transaction_depth; + if statement.is_some() && depth > 0 { + return Err(Error::InvalidSavePointStatement); + } + let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let rollback = Rollback::new(conn); - let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth); - rollback.conn.queue_simple_query(&query)?; + rollback.conn.queue_simple_query(&statement)?; rollback.conn.inner.transaction_depth += 1; rollback.conn.wait_until_ready().await?; rollback.defuse(); diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d9931..2c74c01494 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -37,8 +39,11 @@ impl AnyConnectionBackend for SqliteConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - SqliteTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + SqliteTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a579b8a605..57194b5ef1 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -235,7 +236,17 @@ impl Connection for SqliteConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de2419c..ff908001aa 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -56,6 +56,7 @@ enum Command { }, Begin { tx: rendezvous_oneshot::Sender>, + statement: Option>, }, Commit { tx: rendezvous_oneshot::Sender>, @@ -180,11 +181,25 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { + Command::Begin { tx, statement } => { let depth = conn.transaction_depth; + + let statement = if depth == 0 { + statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)) + } else { + if statement.is_some() { + if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { + break; + } + continue; + } + + begin_ansi_transaction_sql(depth) + }; + let res = conn.handle - .exec(begin_ansi_transaction_sql(depth)) + .exec(statement) .map(|_| { conn.transaction_depth += 1; }); @@ -331,8 +346,11 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd_with_ack(|tx| Command::Begin { tx }) + pub(crate) async fn begin( + &mut self, + statement: Option>, + ) -> Result<(), Error> { + self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b1..d7c40d4956 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; @@ -10,8 +11,11 @@ pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(conn.worker.begin()) + fn begin<'conn>( + conn: &'conn mut SqliteConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + Box::pin(conn.worker.begin(statement)) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { From 1e24ec062f3de3a66d865851191a766b7020e0aa Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 17:21:53 -0800 Subject: [PATCH 03/13] feat: add Pool::begin_with and Pool::try_begin_with --- sqlx-core/src/pool/mod.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 438eebf6c1..b759bacdda 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,6 +54,7 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::Pin; @@ -385,6 +386,36 @@ impl Pool { } } + /// Retrieves a connection and immediately begins a new transaction using `statement`. + pub async fn begin_with( + &self, + statement: impl Into>, + ) -> Result, Error> { + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + Some(statement.into()), + ) + .await + } + + /// Attempts to retrieve a connection and, if successful, immediately begins a new + /// transaction using `statement`. + pub async fn try_begin_with( + &self, + statement: impl Into>, + ) -> Result>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin( + MaybePoolConnection::PoolConnection(conn), + Some(statement.into()), + ) + .await + .map(Some), + + None => Ok(None), + } + } + /// Shut down the connection pool, immediately waking all tasks waiting for a connection. /// /// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and From a89267bce57861425564d268ba3e4e8612cb9998 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:01:57 -0800 Subject: [PATCH 04/13] feat: add Error::BeginFailed and validate that custom "begin" statements are successful --- sqlx-core/src/error.rs | 3 +++ sqlx-mysql/src/connection/establish.rs | 1 + sqlx-mysql/src/connection/mod.rs | 10 ++++++++++ sqlx-mysql/src/protocol/response/status.rs | 2 +- sqlx-mysql/src/transaction.rs | 3 +++ sqlx-postgres/src/connection/mod.rs | 7 +++++++ sqlx-postgres/src/transaction.rs | 5 ++++- sqlx-sqlite/src/connection/mod.rs | 9 +++++++-- sqlx-sqlite/src/transaction.rs | 13 ++++++++++++- 9 files changed, 48 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 8b454575e9..150d643180 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -114,6 +114,9 @@ pub enum Error { #[error("attempted to call begin_with at non-zero transaction depth")] InvalidSavePointStatement, + + #[error("got unexpected connection status after attempting to begin transaction")] + BeginFailed, } impl StdError for Box {} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 468478e550..f52756d4c1 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -28,6 +28,7 @@ impl MySqlConnection { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, + status_flags: Default::default(), cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index e2c671046d..0a2f5fb839 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -8,6 +8,7 @@ pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; +use crate::protocol::response::Status; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; @@ -35,6 +36,7 @@ pub(crate) struct MySqlConnectionInner { // transaction status pub(crate) transaction_depth: usize, + status_flags: Status, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, @@ -42,6 +44,14 @@ pub(crate) struct MySqlConnectionInner { log_settings: LogSettings, } +impl MySqlConnection { + pub(crate) fn in_transaction(&self) -> bool { + self.inner + .status_flags + .intersects(Status::SERVER_STATUS_IN_TRANS) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() diff --git a/sqlx-mysql/src/protocol/response/status.rs b/sqlx-mysql/src/protocol/response/status.rs index bf5013deed..4a8bb0375a 100644 --- a/sqlx-mysql/src/protocol/response/status.rs +++ b/sqlx-mysql/src/protocol/response/status.rs @@ -1,7 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index f287c4a80b..953735bf9a 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -27,6 +27,9 @@ impl TransactionManager for MySqlTransactionManager { } let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); conn.execute(&*statement).await?; + if !conn.in_transaction() { + return Err(Error::BeginFailed); + } conn.inner.transaction_depth += 1; Ok(()) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 04b9a4c9e2..96e3e2fe12 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -128,6 +128,13 @@ impl PgConnection { Ok(()) } + + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.transaction_status { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } } impl Debug for PgConnection { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 767d83c52e..ec01129d6f 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -27,8 +27,11 @@ impl TransactionManager for PgTransactionManager { let rollback = Rollback::new(conn); rollback.conn.queue_simple_query(&statement)?; - rollback.conn.inner.transaction_depth += 1; rollback.conn.wait_until_ready().await?; + if !rollback.conn.in_transaction() { + return Err(Error::BeginFailed); + } + rollback.conn.inner.transaction_depth += 1; rollback.defuse(); Ok(()) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 57194b5ef1..53c3156e9d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -12,8 +12,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; pub(crate) use handle::ConnectionHandle; @@ -503,6 +503,11 @@ impl LockedSqliteHandle<'_> { pub fn remove_rollback_hook(&mut self) { self.guard.remove_rollback_hook(); } + + pub(crate) fn in_transaction(&mut self) -> bool { + let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; + ret == 0 + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index d7c40d4956..d217cffd61 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -15,7 +15,18 @@ impl TransactionManager for SqliteTransactionManager { conn: &'conn mut SqliteConnection, statement: Option>, ) -> BoxFuture<'conn, Result<(), Error>> { - Box::pin(conn.worker.begin(statement)) + Box::pin(async { + let is_custom_statement = statement.is_some(); + conn.worker.begin(statement).await?; + if is_custom_statement { + // Check that custom statement actually put the connection into a transaction. + let mut handle = conn.lock_handle().await?; + if !handle.in_transaction() { + return Err(Error::BeginFailed); + } + } + Ok(()) + }) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { From 69638a66ea3a9d49b3660bba51810f08f30f74e1 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:34:41 -0800 Subject: [PATCH 05/13] chore: add tests of Error::BeginFailed --- tests/mysql/error.rs | 14 +++++++++++++- tests/postgres/error.rs | 14 +++++++++++++- tests/sqlite/error.rs | 14 +++++++++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 7c84266c32..090cbe1980 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, mysql::MySql, Connection}; +use sqlx::{error::ErrorKind, mysql::MySql, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index d6f78140da..5e52155f33 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, postgres::Postgres, Connection}; +use sqlx::{error::ErrorKind, postgres::Postgres, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 1f6b797e69..2227a14d3b 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Executor}; +use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Error, Executor}; use sqlx_test::new; #[sqlx_macros::test] @@ -70,3 +70,15 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} From 92e71f4b78f8c6a998312b8410cd0ec76d209832 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:37:25 -0800 Subject: [PATCH 06/13] chore: add tests of Error::InvalidSavePointStatement --- tests/mysql/error.rs | 14 ++++++++++++++ tests/postgres/error.rs | 14 ++++++++++++++ tests/sqlite/error.rs | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 090cbe1980..3ee1024fc8 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -86,3 +86,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index 5e52155f33..32bf814770 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -86,3 +86,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 2227a14d3b..8729842b70 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -82,3 +82,17 @@ async fn it_fails_with_begin_failed() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} From 7ab05a62c7eba5f725a254a427035a31035ffba7 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Wed, 27 Nov 2024 18:50:02 -0800 Subject: [PATCH 07/13] chore: test begin_with works for all SQLite "BEGIN" statements --- sqlx-sqlite/src/connection/mod.rs | 23 +++++++++++++++++++++ sqlx-sqlite/src/lib.rs | 4 +++- tests/sqlite/sqlite.rs | 33 +++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 53c3156e9d..d7cd48fe25 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -508,6 +508,29 @@ impl LockedSqliteHandle<'_> { let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; ret == 0 } + + /// Calls `sqlite3_txn_state` on this handle. + pub fn transaction_state(&mut self) -> Result { + use libsqlite3_sys::{ + sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE, + }; + + let state = + match unsafe { sqlite3_txn_state(self.as_raw_handle().as_ptr(), std::ptr::null()) } { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => return Err(Error::Protocol("Invalid transaction state".into())), + }; + Ok(state) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SqliteTransactionState { + None, + Read, + Write, } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f8f5534879..398ecf59e8 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,7 +46,9 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; +pub use connection::{ + LockedSqliteHandle, SqliteConnection, SqliteOperation, SqliteTransactionState, UpdateHookResult, +}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb4c..11a582c370 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -960,3 +960,36 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + use sqlx_sqlite::SqliteTransactionState; + + async fn check_txn_state( + conn: &mut SqliteConnection, + expected: SqliteTransactionState, + ) -> Result<(), sqlx::Error> { + let state = conn.lock_handle().await?.transaction_state()?; + assert_eq!(state, expected); + Ok(()) + } + + let mut conn = new::().await?; + + check_txn_state(&mut conn, SqliteTransactionState::None).await?; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut *tx, SqliteTransactionState::None).await?; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + // Note: may result in database locked errors if tests are run in parallel + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + drop(tx); + + Ok(()) +} From 2b1c34852c68a6583e929355f3646a91083a4247 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 19:03:24 -0800 Subject: [PATCH 08/13] chore: improve comment on Connection::begin_with --- sqlx-core/src/connection.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index de0a05799d..dd9c974fdc 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -53,6 +53,9 @@ pub trait Connection: Send { /// Begin a new transaction with a custom statement. /// /// Returns a [`Transaction`] for controlling and tracking the new transaction. + /// + /// Returns an error if the connection is already in a transaction or if + /// `statement` does not put the connection into a transaction. fn begin_with( &mut self, statement: impl Into>, From 6a78180922ef593b4ec79d309b13d9ab29449e9e Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 19:04:33 -0800 Subject: [PATCH 09/13] feat: add default impl of `Connection::begin_with` This makes the new method a non-breaking change. --- sqlx-core/src/connection.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index dd9c974fdc..ba226bc814 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -61,7 +61,10 @@ pub trait Connection: Send { statement: impl Into>, ) -> BoxFuture<'_, Result, Error>> where - Self: Sized; + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) + } /// Execute the function inside a transaction. /// From 2705f206a69a40596afccaafc41cd2dbbcb31b1c Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:40:17 -0800 Subject: [PATCH 10/13] refactor: combine if statement + unwrap_or_else into one match --- sqlx-mysql/src/transaction.rs | 11 +++++++---- sqlx-postgres/src/transaction.rs | 11 +++++++---- sqlx-sqlite/src/connection/worker.rs | 15 ++++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 953735bf9a..11f56c0cb9 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -22,10 +22,13 @@ impl TransactionManager for MySqlTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; conn.execute(&*statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index ec01129d6f..f70961cc19 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -20,10 +20,13 @@ impl TransactionManager for PgTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; let rollback = Rollback::new(conn); rollback.conn.queue_simple_query(&statement)?; diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index ff908001aa..c8e6f0a268 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -184,17 +184,18 @@ impl ConnectionWorker { Command::Begin { tx, statement } => { let depth = conn.transaction_depth; - let statement = if depth == 0 { - statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)) - } else { - if statement.is_some() { + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { break; } continue; - } - - begin_ansi_transaction_sql(depth) + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), }; let res = From ca0843bc273ab0e5399ee6a4738478720caf066f Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:59:42 -0800 Subject: [PATCH 11/13] feat: use in-memory SQLite DB to avoid conflicts across tests run in parallel --- tests/sqlite/sqlite.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 11a582c370..232e58167d 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -974,21 +974,24 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { Ok(()) } - let mut conn = new::().await?; + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); check_txn_state(&mut conn, SqliteTransactionState::None).await?; let mut tx = conn.begin_with("BEGIN DEFERRED").await?; - check_txn_state(&mut *tx, SqliteTransactionState::None).await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await?; drop(tx); let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; - check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await?; drop(tx); - // Note: may result in database locked errors if tests are run in parallel let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; - check_txn_state(&mut *tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await?; drop(tx); Ok(()) From 9756385357abea751d6207d42659f6dfc82cf683 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:55:29 -0800 Subject: [PATCH 12/13] feedback: remove public wrapper for sqlite3_txn_state Move the wrapper directly into the test that uses it instead. --- Cargo.toml | 1 + sqlx-sqlite/src/connection/mod.rs | 23 ------------------ sqlx-sqlite/src/lib.rs | 4 +--- tests/sqlite/sqlite.rs | 39 +++++++++++++++++++++---------- 4 files changed, 29 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf0a867e1e..23e77bd1fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,6 +184,7 @@ rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" criterion = { version = "0.5.1", features = ["async_tokio"] } +libsqlite3-sys = { version = "0.30.1" } # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index d7cd48fe25..53c3156e9d 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -508,29 +508,6 @@ impl LockedSqliteHandle<'_> { let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; ret == 0 } - - /// Calls `sqlite3_txn_state` on this handle. - pub fn transaction_state(&mut self) -> Result { - use libsqlite3_sys::{ - sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE, - }; - - let state = - match unsafe { sqlite3_txn_state(self.as_raw_handle().as_ptr(), std::ptr::null()) } { - SQLITE_TXN_NONE => SqliteTransactionState::None, - SQLITE_TXN_READ => SqliteTransactionState::Read, - SQLITE_TXN_WRITE => SqliteTransactionState::Write, - _ => return Err(Error::Protocol("Invalid transaction state".into())), - }; - Ok(state) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum SqliteTransactionState { - None, - Read, - Write, } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index 398ecf59e8..f8f5534879 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,9 +46,7 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; -pub use connection::{ - LockedSqliteHandle, SqliteConnection, SqliteOperation, SqliteTransactionState, UpdateHookResult, -}; +pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; pub use database::Sqlite; pub use error::SqliteError; pub use options::{ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 232e58167d..55b2630f90 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -6,6 +6,7 @@ use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx_sqlite::LockedSqliteHandle; use sqlx_test::new; use std::sync::Arc; @@ -963,15 +964,9 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res #[sqlx_macros::test] async fn it_can_use_transaction_options() -> anyhow::Result<()> { - use sqlx_sqlite::SqliteTransactionState; - - async fn check_txn_state( - conn: &mut SqliteConnection, - expected: SqliteTransactionState, - ) -> Result<(), sqlx::Error> { - let state = conn.lock_handle().await?.transaction_state()?; + async fn check_txn_state(conn: &mut SqliteConnection, expected: SqliteTransactionState) { + let state = transaction_state(&mut conn.lock_handle().await.unwrap()); assert_eq!(state, expected); - Ok(()) } let mut conn = SqliteConnectOptions::new() @@ -980,19 +975,39 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { .await .unwrap(); - check_txn_state(&mut conn, SqliteTransactionState::None).await?; + check_txn_state(&mut conn, SqliteTransactionState::None).await; let mut tx = conn.begin_with("BEGIN DEFERRED").await?; - check_txn_state(&mut tx, SqliteTransactionState::None).await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await; drop(tx); let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; - check_txn_state(&mut tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; drop(tx); let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; - check_txn_state(&mut tx, SqliteTransactionState::Write).await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; drop(tx); Ok(()) } + +fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { + use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE}; + + let unchecked_state = + unsafe { sqlite3_txn_state(handle.as_raw_handle().as_ptr(), std::ptr::null()) }; + match unchecked_state { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => panic!("unknown txn state: {unchecked_state}"), + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SqliteTransactionState { + None, + Read, + Write, +} From 7a4987f230a739df24c81b43d7903d8472da5e86 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Tue, 3 Dec 2024 14:05:13 -0800 Subject: [PATCH 13/13] fix: cache Status on MySqlConnection --- sqlx-mysql/src/connection/executor.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 07c7979b08..169dee76b7 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -166,6 +166,8 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; + self.inner.status_flags = ok.status; + let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { @@ -208,6 +210,8 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; + self.inner.status_flags = eof.status; + r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0,