diff --git a/pgdog/src/backend/replication/logical/error.rs b/pgdog/src/backend/replication/logical/error.rs index 3f9bbf910..030548da3 100644 --- a/pgdog/src/backend/replication/logical/error.rs +++ b/pgdog/src/backend/replication/logical/error.rs @@ -81,9 +81,6 @@ pub enum Error { #[error("shard {0} has no replication slot")] NoReplicationSlot(usize), - #[error("parallel connection error")] - ParallelConnection, - #[error("no replicas available for table sync")] NoReplicas, diff --git a/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs b/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs index c0cb3ef6d..f28b94aca 100644 --- a/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs +++ b/pgdog/src/backend/replication/logical/publisher/parallel_sync.rs @@ -44,7 +44,7 @@ impl ParallelSync { .permit .acquire() .await - .map_err(|_| Error::ParallelConnection)?; + .map_err(|_| Error::DataSyncAborted)?; if self.tx.is_closed() { return Err(Error::DataSyncAborted); @@ -64,9 +64,7 @@ impl ParallelSync { } }; - self.tx - .send(result) - .map_err(|_| Error::ParallelConnection)?; + self.tx.send(result).map_err(|_| Error::DataSyncAborted)?; Ok::<(), Error>(()) }) diff --git a/pgdog/src/backend/replication/logical/subscriber/copy.rs b/pgdog/src/backend/replication/logical/subscriber/copy.rs index 262db9d58..869a63fd3 100644 --- a/pgdog/src/backend/replication/logical/subscriber/copy.rs +++ b/pgdog/src/backend/replication/logical/subscriber/copy.rs @@ -1,12 +1,13 @@ //! Shard COPY stream from one source //! between N shards. +use futures::future::try_join_all; use pg_query::{parse_raw, NodeEnum}; use pgdog_config::QueryParserEngine; use tracing::debug; use crate::{ - backend::{replication::subscriber::ParallelConnection, Cluster, ConnectReason}, + backend::{Cluster, ConnectReason, Server}, config::Role, frontend::router::parser::{CopyParser, Shard}, net::{CopyData, CopyDone, ErrorResponse, FromBytes, Protocol, Query, ToBytes}, @@ -24,7 +25,7 @@ pub struct CopySubscriber { copy: CopyParser, cluster: Cluster, buffer: Vec, - connections: Vec, + connections: Vec, stmt: CopyStatement, bytes_sharded: usize, } @@ -85,7 +86,7 @@ impl CopySubscriber { .1 .standalone(ConnectReason::Replication) .await?; - servers.push(ParallelConnection::new(primary)?); + servers.push(primary); } self.connections = servers; @@ -95,38 +96,38 @@ impl CopySubscriber { /// Disconnect from all shards. pub async fn disconnect(&mut self) -> Result<(), Error> { - for conn in std::mem::take(&mut self.connections) { - conn.reattach().await?; - } + self.connections.clear(); Ok(()) } /// Start COPY on all shards. pub async fn start_copy(&mut self) -> Result<(), Error> { - let stmt = Query::new(self.stmt.copy_in()); - if self.connections.is_empty() { self.connect().await?; } - for server in &mut self.connections { - debug!("{} [{}]", stmt.query(), server.addr()); + let stmt = Query::new(self.stmt.copy_in()); - server.send_one(&stmt.clone().into()).await?; - server.flush().await?; + // Start COPY IN on all shards concurrently. + try_join_all(self.connections.iter_mut().map(|server| { + let msg: crate::net::ProtocolMessage = stmt.clone().into(); + debug!("{} [{}]", stmt.query(), server.addr()); - let msg = server.read().await?; - match msg.code() { - 'G' => (), - 'E' => { - return Err(Error::PgError(Box::new(ErrorResponse::from_bytes( - msg.to_bytes()?, - )?))) + async move { + server.send_one(&msg).await?; + server.flush().await?; + let reply = server.read().await?; + match reply.code() { + 'G' => Ok(()), + 'E' => Err(Error::PgError(Box::new(ErrorResponse::from_bytes( + reply.to_bytes()?, + )?))), + c => Err(Error::OutOfSync(c)), } - c => return Err(Error::OutOfSync(c)), } - } + })) + .await?; Ok(()) } @@ -135,20 +136,20 @@ impl CopySubscriber { pub async fn copy_done(&mut self) -> Result<(), Error> { self.flush().await?; - for server in &mut self.connections { + // Finalise COPY on all shards concurrently. + try_join_all(self.connections.iter_mut().map(|server| async move { server.send_one(&CopyDone.into()).await?; server.flush().await?; - let command_complete = server.read().await?; - match command_complete.code() { + let cc = server.read().await?; + match cc.code() { 'E' => { - let error = ErrorResponse::from_bytes(command_complete.to_bytes()?)?; + let error = ErrorResponse::from_bytes(cc.to_bytes()?)?; if error.code == "08P01" && error.message == "insufficient data left in message" { return Err(Error::BinaryFormatMismatch(Box::new(error))); - } else { - return Err(Error::PgError(Box::new(error))); } + return Err(Error::PgError(Box::new(error))); } 'C' => (), c => return Err(Error::OutOfSync(c)), @@ -158,7 +159,9 @@ impl CopySubscriber { if rfq.code() != 'Z' { return Err(Error::OutOfSync(rfq.code())); } - } + Ok(()) + })) + .await?; Ok(()) } @@ -174,12 +177,19 @@ impl CopySubscriber { } async fn flush(&mut self) -> Result<(usize, usize), Error> { + if self.buffer.is_empty() { + return Ok((0, 0)); + } + let result = self.copy.shard(&self.buffer)?; self.buffer.clear(); let rows = result.len(); let bytes = result.iter().map(|row| row.len()).sum::(); + self.bytes_sharded += bytes; + // Route each row to the right shard(s). send_one is a buffered write + // so this loop does no I/O — no concurrency needed here. for row in &result { for (shard, server) in self.connections.iter_mut().enumerate() { match row.shard() { @@ -198,7 +208,8 @@ impl CopySubscriber { } } - self.bytes_sharded += result.iter().map(|c| c.len()).sum::(); + // Flush all shards concurrently — this is the actual socket write. + try_join_all(self.connections.iter_mut().map(|s| s.flush())).await?; Ok((rows, bytes)) } diff --git a/pgdog/src/backend/replication/logical/subscriber/mod.rs b/pgdog/src/backend/replication/logical/subscriber/mod.rs index 9df30191b..7b5e7e0b2 100644 --- a/pgdog/src/backend/replication/logical/subscriber/mod.rs +++ b/pgdog/src/backend/replication/logical/subscriber/mod.rs @@ -1,6 +1,5 @@ pub mod context; pub mod copy; -pub mod parallel_connection; pub mod stream; #[cfg(test)] @@ -8,5 +7,4 @@ mod tests; pub use context::StreamContext; pub use copy::CopySubscriber; -pub use parallel_connection::ParallelConnection; pub use stream::StreamSubscriber; diff --git a/pgdog/src/backend/replication/logical/subscriber/parallel_connection.rs b/pgdog/src/backend/replication/logical/subscriber/parallel_connection.rs deleted file mode 100644 index 4c4b15713..000000000 --- a/pgdog/src/backend/replication/logical/subscriber/parallel_connection.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! Postgres server connection running in its own task. -//! -//! This allows to queue up messages across multiple instances -//! of this connection without blocking and while maintaining protocol integrity. -//! -use tokio::select; -use tokio::spawn; -use tokio::sync::{ - mpsc::{channel, Receiver, Sender}, - Notify, -}; - -use crate::backend::pool::Address; -use crate::{ - backend::Server, - frontend::ClientRequest, - net::{Message, ProtocolMessage}, -}; - -use std::sync::Arc; - -use super::super::Error; - -// What we can send. -enum ParallelMessage { - // Protocol message, e.g. Bind, Execute, Sync. - ProtocolMessage(ProtocolMessage), - // Flush the socket. - Flush, -} - -// What we can receive. -enum ParallelReply { - // Message, e.g. RowDescription, DataRow, CommandComplete, etc. - Message(Message), - // The task gives back the server connection to the owner. - // Preserve connections between parallel executions. - Server(Box), -} - -// Parallel Postgres server connection. -#[derive(Debug)] -pub struct ParallelConnection { - tx: Sender, - rx: Receiver, - stop: Arc, - address: Address, -} - -impl ParallelConnection { - // Queue up message to server. - pub async fn send_one(&mut self, message: &ProtocolMessage) -> Result<(), Error> { - self.tx - .send(ParallelMessage::ProtocolMessage(message.clone())) - .await - .map_err(|_| Error::ParallelConnection)?; - - Ok(()) - } - - // Queue up the contents of the buffer. - pub async fn send(&mut self, client_request: &ClientRequest) -> Result<(), Error> { - for message in client_request.messages.iter() { - self.tx - .send(ParallelMessage::ProtocolMessage(message.clone())) - .await - .map_err(|_| Error::ParallelConnection)?; - self.tx - .send(ParallelMessage::Flush) - .await - .map_err(|_| Error::ParallelConnection)?; - } - - Ok(()) - } - - // Wait for a message from the server. - pub async fn read(&mut self) -> Result { - let reply = self.rx.recv().await.ok_or(Error::ParallelConnection)?; - match reply { - ParallelReply::Message(message) => Ok(message), - ParallelReply::Server(_) => Err(Error::ParallelConnection), - } - } - - // Request server connection performs socket flush. - pub async fn flush(&mut self) -> Result<(), Error> { - self.tx - .send(ParallelMessage::Flush) - .await - .map_err(|_| Error::ParallelConnection)?; - - Ok(()) - } - - /// Server address. - pub fn addr(&self) -> &Address { - &self.address - } - - // Move server connection into its own Tokio task. - pub fn new(server: Server) -> Result { - // Ideally we don't hardcode these. PgDog - // can use a lot of memory if this is high. - let (tx1, rx1) = channel(4096); - let (tx2, rx2) = channel(4096); - let stop = Arc::new(Notify::new()); - let address = server.addr().clone(); - - let listener = Listener { - stop: stop.clone(), - rx: rx1, - tx: tx2, - server: Some(Box::new(server)), - }; - - spawn(async move { - listener.run().await?; - - Ok::<(), Error>(()) - }); - - Ok(Self { - address, - tx: tx1, - rx: rx2, - stop, - }) - } - - // Get the connection back from the async task. This will - // only work if the connection is idle (ReadyForQuery received, no more traffic expected). - pub async fn reattach(mut self) -> Result { - self.stop.notify_one(); - let server = self.rx.recv().await.ok_or(Error::ParallelConnection)?; - match server { - ParallelReply::Server(server) => Ok(*server), - _ => Err(Error::ParallelConnection), - } - } -} - -// Stop the background task and kill the connection. -// Prevents leaks in case the connection is not "reattached". -impl Drop for ParallelConnection { - fn drop(&mut self) { - self.stop.notify_one(); - } -} - -// Background task performing the actual work of talking to Postgres. -struct Listener { - rx: Receiver, - tx: Sender, - server: Option>, - stop: Arc, -} - -impl Listener { - // Send message to Postgres. - async fn send(&mut self, message: ProtocolMessage) -> Result<(), Error> { - if let Some(ref mut server) = self.server { - server.send_one(&message).await?; - } - - Ok(()) - } - - // Flush socket. - async fn flush(&mut self) -> Result<(), Error> { - if let Some(ref mut server) = self.server { - server.flush().await?; - } - - Ok(()) - } - - // Return server to parent task. - async fn return_server(&mut self) -> Result<(), Error> { - if let Some(server) = self.server.take() { - if self.tx.is_closed() { - drop(server); - } else { - let _ = self.tx.send(ParallelReply::Server(server)).await; - } - } - - Ok(()) - } - - // Run the background task. - async fn run(mut self) -> Result<(), Error> { - loop { - select! { - message = self.rx.recv() => { - if let Some(message) = message { - match message { - ParallelMessage::ProtocolMessage(message) => self.send(message).await?, - ParallelMessage::Flush => self.flush().await?, - } - } else { - self.return_server().await?; - break; - } - } - - reply = self.server.as_mut().unwrap().read() => { - let reply = reply?; - self.tx.send(ParallelReply::Message(reply)).await.map_err(|_| Error::ParallelConnection)?; - } - - _ = self.stop.notified() => { - self.return_server().await?; - break; - } - } - } - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use crate::{ - backend::server::test::test_server, - net::{Parse, Protocol, Sync}, - }; - - use super::*; - - #[tokio::test] - async fn test_parallel_connection() { - let server = test_server().await; - let mut parallel = ParallelConnection::new(server).unwrap(); - - parallel - .send( - &vec![ - Parse::named("test", "SELECT $1::bigint").into(), - Sync.into(), - ] - .into(), - ) - .await - .unwrap(); - - for c in ['1', 'Z'] { - let msg = parallel.read().await.unwrap(); - assert_eq!(msg.code(), c); - } - - let server = parallel.reattach().await.unwrap(); - assert!(server.in_sync()); - } -}