Make sinks, sources, and pipelines generic on their errors #66

Merged · 5 commits · Dec 2, 2024
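In outline, the PR replaces the concrete error wrappers inside `PipelineError` with type parameters supplied by the pipeline's source and sink, plus a shared `CommonSource` variant for failures every source can hit. Below is a minimal sketch of the resulting shape, condensed from the `pg_replicate/src/pipeline/mod.rs` diff further down; the `SourceError`/`SinkError` trait bounds are assumed here to be simple marker traits (their definitions are not part of this diff), and `CommonSourceError` is a stand-in stub.

use thiserror::Error;

// Assumed marker traits; concrete errors opt in, e.g.
// `impl SinkError for BigQuerySinkError {}` as in the BigQuery sink diff.
pub trait SourceError: std::error::Error + Send + Sync + 'static {}
pub trait SinkError: std::error::Error + Send + Sync + 'static {}

// Stand-in for the crate's shared source error (table-copy stream,
// CDC stream, and status-update failures).
#[derive(Debug, Error)]
#[error("common source error")]
pub struct CommonSourceError;

// The pipeline error is now generic over the source's and sink's own errors.
#[derive(Debug, Error)]
pub enum PipelineError<SrcErr: SourceError, SnkErr: SinkError> {
    #[error("source error: {0}")]
    Source(#[source] SrcErr),

    #[error("sink error: {0}")]
    Sink(#[source] SnkErr),

    #[error("source error: {0}")]
    CommonSource(#[from] CommonSourceError),
}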
67 changes: 50 additions & 17 deletions pg_replicate/src/pipeline/batching/data_pipeline.rs
@@ -10,7 +10,7 @@ use crate::{
pipeline::{
batching::stream::BatchTimeoutStream,
sinks::BatchSink,
sources::{postgres::CdcStreamError, Source, SourceError},
sources::{postgres::CdcStreamError, CommonSourceError, Source},
PipelineAction, PipelineError,
},
table::TableId,
@@ -35,18 +35,24 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
}
}

async fn copy_table_schemas(&mut self) -> Result<(), PipelineError> {
async fn copy_table_schemas(&mut self) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let table_schemas = self.source.get_table_schemas();
let table_schemas = table_schemas.clone();

if !table_schemas.is_empty() {
self.sink.write_table_schemas(table_schemas).await?;
self.sink
.write_table_schemas(table_schemas)
.await
.map_err(PipelineError::Sink)?;
}

Ok(())
}

async fn copy_tables(&mut self, copied_tables: &HashSet<TableId>) -> Result<(), PipelineError> {
async fn copy_tables(
&mut self,
copied_tables: &HashSet<TableId>,
) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let start = Instant::now();
let table_schemas = self.source.get_table_schemas();

@@ -60,12 +66,16 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
continue;
}

self.sink.truncate_table(table_schema.table_id).await?;
self.sink
.truncate_table(table_schema.table_id)
.await
.map_err(PipelineError::Sink)?;

let table_rows = self
.source
.get_table_copy_stream(&table_schema.table_name, &table_schema.column_schemas)
.await?;
.await
.map_err(PipelineError::Source)?;

let batch_timeout_stream =
BatchTimeoutStream::new(table_rows, self.batch_config.clone());
@@ -77,16 +87,23 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
//TODO: Avoid a vec copy
let mut rows = Vec::with_capacity(batch.len());
for row in batch {
rows.push(row.map_err(SourceError::TableCopyStream)?);
rows.push(row.map_err(CommonSourceError::TableCopyStream)?);
}
self.sink
.write_table_rows(rows, table_schema.table_id)
.await?;
.await
.map_err(PipelineError::Sink)?;
}

self.sink.table_copied(table_schema.table_id).await?;
self.sink
.table_copied(table_schema.table_id)
.await
.map_err(PipelineError::Sink)?;
}
self.source.commit_transaction().await?;
self.source
.commit_transaction()
.await
.map_err(PipelineError::Source)?;

let end = Instant::now();
let seconds = (end - start).as_secs();
@@ -95,10 +112,17 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
Ok(())
}

async fn copy_cdc_events(&mut self, last_lsn: PgLsn) -> Result<(), PipelineError> {
async fn copy_cdc_events(
&mut self,
last_lsn: PgLsn,
) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let mut last_lsn: u64 = last_lsn.into();
last_lsn += 1;
let cdc_events = self.source.get_cdc_stream(last_lsn.into()).await?;
let cdc_events = self
.source
.get_cdc_stream(last_lsn.into())
.await
.map_err(PipelineError::Source)?;

pin!(cdc_events);

@@ -117,13 +141,17 @@
{
continue;
}
let event = event.map_err(SourceError::CdcStream)?;
let event = event.map_err(CommonSourceError::CdcStream)?;
if let CdcEvent::KeepAliveRequested { reply } = event {
send_status_update = reply;
};
events.push(event);
}
let last_lsn = self.sink.write_cdc_events(events).await?;
let last_lsn = self
.sink
.write_cdc_events(events)
.await
.map_err(PipelineError::Sink)?;
if send_status_update {
info!("sending status update with lsn: {last_lsn}");
let inner = unsafe {
@@ -136,15 +164,20 @@ impl<Src: Source, Snk: BatchSink> BatchDataPipeline<Src, Snk> {
.as_mut()
.send_status_update(last_lsn)
.await
.map_err(|e| PipelineError::SourceError(SourceError::StatusUpdate(e)))?;
.map_err(CommonSourceError::StatusUpdate)?;
}
}

Ok(())
}

pub async fn start(&mut self) -> Result<(), PipelineError> {
let resumption_state = self.sink.get_resumption_state().await?;
pub async fn start(&mut self) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let resumption_state = self
.sink
.get_resumption_state()
.await
.map_err(PipelineError::Sink)?;

match self.action {
PipelineAction::TableCopiesOnly => {
self.copy_table_schemas().await?;
68 changes: 51 additions & 17 deletions pg_replicate/src/pipeline/data_pipeline.rs
@@ -4,7 +4,9 @@ use futures::StreamExt;
use tokio::pin;
use tokio_postgres::types::PgLsn;

use crate::{conversions::cdc_event::CdcEvent, pipeline::sources::SourceError, table::TableId};
use crate::{
conversions::cdc_event::CdcEvent, pipeline::sources::CommonSourceError, table::TableId,
};

use super::{sinks::Sink, sources::Source, PipelineAction, PipelineError};

@@ -23,18 +25,24 @@ impl<Src: Source, Snk: Sink> DataPipeline<Src, Snk> {
}
}

async fn copy_table_schemas(&mut self) -> Result<(), PipelineError> {
async fn copy_table_schemas(&mut self) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let table_schemas = self.source.get_table_schemas();
let table_schemas = table_schemas.clone();

if !table_schemas.is_empty() {
self.sink.write_table_schemas(table_schemas).await?;
self.sink
.write_table_schemas(table_schemas)
.await
.map_err(PipelineError::Sink)?;
}

Ok(())
}

async fn copy_tables(&mut self, copied_tables: &HashSet<TableId>) -> Result<(), PipelineError> {
async fn copy_tables(
&mut self,
copied_tables: &HashSet<TableId>,
) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let table_schemas = self.source.get_table_schemas();

let mut keys: Vec<u32> = table_schemas.keys().copied().collect();
@@ -46,58 +54,84 @@ impl<Src: Source, Snk: Sink> DataPipeline<Src, Snk> {
continue;
}

self.sink.truncate_table(table_schema.table_id).await?;
self.sink
.truncate_table(table_schema.table_id)
.await
.map_err(PipelineError::Sink)?;

let table_rows = self
.source
.get_table_copy_stream(&table_schema.table_name, &table_schema.column_schemas)
.await?;
.await
.map_err(PipelineError::Source)?;

pin!(table_rows);

while let Some(row) = table_rows.next().await {
let row = row.map_err(SourceError::TableCopyStream)?;
let row = row.map_err(CommonSourceError::TableCopyStream)?;
self.sink
.write_table_row(row, table_schema.table_id)
.await?;
.await
.map_err(PipelineError::Sink)?;
}

self.sink.table_copied(table_schema.table_id).await?;
self.sink
.table_copied(table_schema.table_id)
.await
.map_err(PipelineError::Sink)?;
}
self.source.commit_transaction().await?;
self.source
.commit_transaction()
.await
.map_err(PipelineError::Source)?;

Ok(())
}

async fn copy_cdc_events(&mut self, last_lsn: PgLsn) -> Result<(), PipelineError> {
async fn copy_cdc_events(
&mut self,
last_lsn: PgLsn,
) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let mut last_lsn: u64 = last_lsn.into();
last_lsn += 1;
let cdc_events = self.source.get_cdc_stream(last_lsn.into()).await?;
let cdc_events = self
.source
.get_cdc_stream(last_lsn.into())
.await
.map_err(PipelineError::Source)?;

pin!(cdc_events);

while let Some(cdc_event) = cdc_events.next().await {
let cdc_event = cdc_event.map_err(SourceError::CdcStream)?;
let cdc_event = cdc_event.map_err(CommonSourceError::CdcStream)?;
let send_status_update = if let CdcEvent::KeepAliveRequested { reply } = cdc_event {
reply
} else {
false
};
let last_lsn = self.sink.write_cdc_event(cdc_event).await?;
let last_lsn = self
.sink
.write_cdc_event(cdc_event)
.await
.map_err(PipelineError::Sink)?;
if send_status_update {
cdc_events
.as_mut()
.send_status_update(last_lsn)
.await
.map_err(|e| PipelineError::SourceError(SourceError::StatusUpdate(e)))?;
.map_err(CommonSourceError::StatusUpdate)?;
}
}

Ok(())
}

pub async fn start(&mut self) -> Result<(), PipelineError> {
let resumption_state = self.sink.get_resumption_state().await?;
pub async fn start(&mut self) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
let resumption_state = self
.sink
.get_resumption_state()
.await
.map_err(PipelineError::Sink)?;
match self.action {
PipelineAction::TableCopiesOnly => {
self.copy_table_schemas().await?;
21 changes: 12 additions & 9 deletions pg_replicate/src/pipeline/mod.rs
@@ -1,12 +1,12 @@
use std::collections::HashSet;

use sinks::SinkError;
use sources::SourceError;
use thiserror::Error;
use tokio_postgres::types::PgLsn;

use crate::table::TableId;

use self::{sinks::SinkError, sources::SourceError};

pub mod batching;
pub mod data_pipeline;
pub mod sinks;
@@ -19,16 +19,19 @@ pub enum PipelineAction {
Both,
}

pub struct PipelineResumptionState {
pub copied_tables: HashSet<TableId>,
pub last_lsn: PgLsn,
}

#[derive(Debug, Error)]
pub enum PipelineError {
pub enum PipelineError<SrcErr: SourceError, SnkErr: SinkError> {
#[error("source error: {0}")]
SourceError(#[from] SourceError),
Source(#[source] SrcErr),

#[error("sink error: {0}")]
SinkError(#[from] SinkError),
}
Sink(#[source] SnkErr),

pub struct PipelineResumptionState {
pub copied_tables: HashSet<TableId>,
pub last_lsn: PgLsn,
#[error("source error: {0}")]
CommonSource(#[from] sources::CommonSourceError),
}
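At the call sites above, the two generic variants are wrapped explicitly with `map_err(PipelineError::Source)` / `map_err(PipelineError::Sink)`, while `CommonSource` keeps `#[from]` so expressions like `row.map_err(CommonSourceError::TableCopyStream)?` still convert automatically. A condensed sketch of that pattern follows, assuming the crate's `Source` and `Sink` traits now expose associated `Error` types (the trait definitions themselves are outside this diff).

// Hypothetical helper illustrating the call-site pattern; the method names
// (`truncate_table`, `commit_transaction`) are the ones used in the pipeline
// diffs above.
async fn truncate_and_commit<Src: Source, Snk: Sink>(
    source: &mut Src,
    sink: &mut Snk,
    table_id: TableId,
) -> Result<(), PipelineError<Src::Error, Snk::Error>> {
    sink.truncate_table(table_id)
        .await
        .map_err(PipelineError::Sink)?; // sink failures wrapped explicitly

    source
        .commit_transaction()
        .await
        .map_err(PipelineError::Source)?; // source failures wrapped explicitly

    Ok(())
}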
15 changes: 9 additions & 6 deletions pg_replicate/src/pipeline/sinks/bigquery.rs
@@ -33,6 +33,8 @@ pub enum BigQuerySinkError {
CommitWithoutBegin,
}

impl SinkError for BigQuerySinkError {}

pub struct BigQueryBatchSink {
client: BigQueryClient,
dataset_id: String,
@@ -84,7 +86,8 @@ impl BigQueryBatchSink {

#[async_trait]
impl BatchSink for BigQueryBatchSink {
async fn get_resumption_state(&mut self) -> Result<PipelineResumptionState, SinkError> {
type Error = BigQuerySinkError;
async fn get_resumption_state(&mut self) -> Result<PipelineResumptionState, Self::Error> {
info!("getting resumption state from bigquery");
let copied_table_column_schemas = [ColumnSchema {
name: "table_id".to_string(),
@@ -140,7 +143,7 @@ impl BatchSink for BigQueryBatchSink {
async fn write_table_schemas(
&mut self,
table_schemas: HashMap<TableId, TableSchema>,
) -> Result<(), SinkError> {
) -> Result<(), Self::Error> {
for table_schema in table_schemas.values() {
let table_name = Self::table_name_in_bq(&table_schema.table_name);
self.client
@@ -161,7 +164,7 @@
&mut self,
mut table_rows: Vec<TableRow>,
table_id: TableId,
) -> Result<(), SinkError> {
) -> Result<(), Self::Error> {
let table_schema = self.get_table_schema(table_id)?;
let table_name = Self::table_name_in_bq(&table_schema.table_name);
let table_descriptor = table_schema.into();
@@ -177,7 +180,7 @@
Ok(())
}

async fn write_cdc_events(&mut self, events: Vec<CdcEvent>) -> Result<PgLsn, SinkError> {
async fn write_cdc_events(&mut self, events: Vec<CdcEvent>) -> Result<PgLsn, Self::Error> {
let mut table_name_to_table_rows = HashMap::new();
let mut new_last_lsn = PgLsn::from(0);
let mut final_lsn: Option<PgLsn> = None;
@@ -243,14 +246,14 @@ impl BatchSink for BigQueryBatchSink {
Ok(committed_lsn)
}

async fn table_copied(&mut self, table_id: TableId) -> Result<(), SinkError> {
async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error> {
self.client
.insert_into_copied_tables(&self.dataset_id, table_id)
.await?;
Ok(())
}

async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), SinkError> {
async fn truncate_table(&mut self, _table_id: TableId) -> Result<(), Self::Error> {
Ok(())
}
}
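For reference, here is the shape of the `BatchSink` trait implied by the impl above, with an associated error type in place of the old shared `SinkError` enum. This reconstruction is an assumption, since the trait's own file is not part of this excerpt; the method signatures are taken from the BigQuery impl shown in the diff.

use async_trait::async_trait;
use std::collections::HashMap;

#[async_trait]
pub trait BatchSink {
    // Each sink names its own error type; the pipeline wraps it into
    // PipelineError::Sink. SinkError is the marker trait sketched earlier.
    type Error: SinkError;

    async fn get_resumption_state(&mut self) -> Result<PipelineResumptionState, Self::Error>;
    async fn write_table_schemas(
        &mut self,
        table_schemas: HashMap<TableId, TableSchema>,
    ) -> Result<(), Self::Error>;
    async fn write_table_rows(
        &mut self,
        rows: Vec<TableRow>,
        table_id: TableId,
    ) -> Result<(), Self::Error>;
    async fn write_cdc_events(&mut self, events: Vec<CdcEvent>) -> Result<PgLsn, Self::Error>;
    async fn table_copied(&mut self, table_id: TableId) -> Result<(), Self::Error>;
    async fn truncate_table(&mut self, table_id: TableId) -> Result<(), Self::Error>;
}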