Skip to content

Commit

Permalink
Support transactions across global and local contexts
Browse files Browse the repository at this point in the history
Enhanced transaction handling by introducing a global transaction manager check within the `DocumentSet` methods. The methods now consider current global or local transactions, ensuring consistent behavior when interacting with MongoDB collections.
  • Loading branch information
kerem-acer committed Nov 11, 2024
1 parent e92490f commit 5eb0755
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
25 changes: 10 additions & 15 deletions src/Core/MongoVault.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,15 @@ public DocumentSet<TDocument> Set<TDocument>(bool ignoreQueryFilter = false)
return new DocumentSet<TDocument>(this, ignoreQueryFilter);
}

public bool IsInTransaction => _transaction is not null;
public bool IsInTransaction => _transaction is not null || GlobalTransactionManager?.CurrentTransaction is not null;

public IMongoVaultTransaction? CurrentTransaction => _transaction ?? GlobalTransactionManager?.CurrentTransaction;

public IMongoVaultTransaction BeginTransaction()
{
if (_transaction is not null)
if (IsInTransaction)
{
throw new InvalidOperationException("Transaction already started");
}

if (GlobalTransactionManager?.CurrentTransaction is not null)
{
throw new InvalidOperationException("BeginTransaction cannot be called inside a global transaction.");
throw new InvalidOperationException("BeginTransaction cannot be called inside a transaction.");
}

_transaction = new MongoVaultTransaction(this, MongoDatabase.Client.StartSession());
Expand Down Expand Up @@ -86,10 +83,8 @@ public virtual async Task<int> SaveAsync(CancellationToken cancellationToken = d
return 0;
}

var transaction = GlobalTransactionManager?.CurrentTransaction ?? _transaction;

var session = transaction is not null
? transaction.Session
var session = CurrentTransaction is not null
? CurrentTransaction.Session
: await MongoDatabase.Client.StartSessionAsync(cancellationToken: cancellationToken);

if (!session.IsInTransaction)
Expand Down Expand Up @@ -121,7 +116,7 @@ public virtual async Task<int> SaveAsync(CancellationToken cancellationToken = d
await interceptor.SavedChangesAsync(interceptorContext, affected, cancellationToken);
}

if (transaction is null)
if (CurrentTransaction is null)
{
await session.CommitTransactionAsync(cancellationToken);
}
Expand All @@ -133,7 +128,7 @@ public virtual async Task<int> SaveAsync(CancellationToken cancellationToken = d
await interceptor.SaveChangesFailedAsync(e, interceptorContext, cancellationToken);
}

if (transaction is null)
if (CurrentTransaction is null)
{
await session.AbortTransactionAsync(cancellationToken);
}
Expand All @@ -142,7 +137,7 @@ public virtual async Task<int> SaveAsync(CancellationToken cancellationToken = d
}
finally
{
if (transaction is null)
if (CurrentTransaction is null)
{
session.Dispose();
}
Expand Down
31 changes: 24 additions & 7 deletions src/Core/Set/DocumentSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,52 @@ public DocumentSet(MongoVault vault, bool ignoreQueryFilter = false)

public IFindFluent<TDocument, TDocument> Find(Expression<Func<TDocument, bool>> filter)
{
return _collection.Find(TransformQueryFilterExpression(filter));
var transformedFilter = TransformQueryFilterExpression(filter);

return _vault.CurrentTransaction is not null ?
_collection.Find(_vault.CurrentTransaction.Session, transformedFilter)
: _collection.Find(transformedFilter);
}

public IFindFluent<TDocument, TDocument> Find(FilterDefinition<TDocument> filter)
{
var queryFilter = TransformQueryFilterExpression();
var finalFilter = queryFilter is not null ? filter & queryFilter : filter;

return _collection.Find(finalFilter);
return _vault.CurrentTransaction is not null ?
_collection.Find(_vault.CurrentTransaction.Session, finalFilter)
: _collection.Find(finalFilter);
}

public IFindFluent<TDocument, TDocument> Find()
{
var filter = TransformQueryFilterExpression();
var filter = TransformQueryFilterExpression() ?? Builders<TDocument>.Filter.Empty;

return _collection.Find(filter ?? Builders<TDocument>.Filter.Empty);
return _vault.CurrentTransaction is not null ?
_collection.Find(_vault.CurrentTransaction.Session, filter)
: _collection.Find(filter);
}

public IQueryable<TDocument> AsQueryable()
{
var filter = TransformQueryFilterExpression();

var queryable = _collection.AsQueryable();
var queryable = _vault.CurrentTransaction is not null ?
_collection.AsQueryable(_vault.CurrentTransaction.Session)
: _collection.AsQueryable();

return filter is not null ? queryable.Where(filter) : queryable;
}

public IAggregateFluent<TDocument> Aggregate()
{
var filter = TransformQueryFilterExpression() ?? Builders<TDocument>.Filter.Empty;

var aggregate = _vault.CurrentTransaction is not null ?
_collection.Aggregate(_vault.CurrentTransaction.Session)
: _collection.Aggregate();

return _collection.Aggregate().Match(filter);
return aggregate.Match(filter);
}

public void Add(TDocument document)
Expand Down Expand Up @@ -110,8 +124,11 @@ public void UpdateByKey(object key, UpdateDefinition<TDocument> update)
public async Task<TDocument?> GetByKeyAsync(object key, CancellationToken cancellationToken = default)
{
var filter = BuildKeyFilter(key);
var find = _vault.CurrentTransaction is not null ?
_collection.Find(_vault.CurrentTransaction.Session, filter)
: _collection.Find(filter);

return await _collection.Find(filter).FirstOrDefaultAsync(cancellationToken);
return await find.FirstOrDefaultAsync(cancellationToken);
}

public DocumentSet<TDocument> IgnoreQueryFilter()
Expand Down

0 comments on commit 5eb0755

Please sign in to comment.