using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using Infrastructure;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using OpenAuth.Repository.Core;
using OpenAuth.Repository.Interface;
using Z.EntityFramework.Plus;
namespace OpenAuth.Repository
{
    /// <summary>
    /// Unit-of-work over a single <see cref="OpenAuthDBContext"/>: generic query, add/update/delete,
    /// raw SQL and save operations for any entity type mapped by the context.
    /// </summary>
    public class UnitWork : IUnitWork
    {
        private readonly OpenAuthDBContext _context;

        public UnitWork(OpenAuthDBContext context)
        {
            _context = context;
        }

        /// <summary>
        /// Runs <paramref name="action"/> inside one database transaction.
        /// By default EF runs every SaveChanges() call in its own transaction; this method lets
        /// SaveChanges() be called several times so that all calls commit or roll back together.
        /// </summary>
        /// <param name="action">Work to execute; any exception rolls the transaction back and is rethrown.</param>
        /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
        public void ExecuteWithTransaction(Action action)
        {
            // Fail before opening a transaction that would only be rolled back.
            if (action == null)
            {
                throw new ArgumentNullException(nameof(action));
            }

            using (IDbContextTransaction transaction = _context.Database.BeginTransaction())
            {
                try
                {
                    action();
                    transaction.Commit();
                }
                catch (Exception)
                {
                    transaction.Rollback();
                    // "throw;" keeps the original stack trace ("throw ex;" would reset it).
                    throw;
                }
            }
        }

        /// <summary>
        /// Returns the underlying DbContext, for edge cases such as multi-threaded scenarios
        /// that the unit-of-work API does not cover.
        /// </summary>
        public OpenAuthDBContext GetDbContext()
        {
            return _context;
        }

        /// <summary>
        /// Gets the records that match a filter, without change tracking.
        /// </summary>
        /// <param name="exp">Filter predicate; when null, all records are returned.</param>
        public IQueryable<T> Find<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return Filter(exp);
        }

        /// <summary>
        /// Returns true if any record matches <paramref name="exp"/>.
        /// </summary>
        public bool Any<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return _context.Set<T>().Any(exp);
        }

        /// <summary>
        /// Finds a single record (or default) matching <paramref name="exp"/>, without change tracking.
        /// </summary>
        public T FirstOrDefault<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return _context.Set<T>().AsNoTracking().FirstOrDefault(exp);
        }

        /// <summary>
        /// Gets one page of records, without change tracking.
        /// </summary>
        /// <param name="pageindex">1-based page number; values below 1 are treated as 1.</param>
        /// <param name="pagesize">Number of records per page.</param>
        /// <param name="orderby">
        /// Sort expression, e.g. "Id" or "Id descending". When null or empty, "Id descending" is used
        /// (NOTE(review): that default assumes T exposes an Id property).
        /// </param>
        /// <param name="exp">Filter predicate; when null, all records are paged.</param>
        public IQueryable<T> Find<T>(int pageindex, int pagesize, string orderby = "", Expression<Func<T, bool>> exp = null) where T : class
        {
            if (pageindex < 1)
            {
                pageindex = 1;
            }

            if (string.IsNullOrEmpty(orderby))
            {
                orderby = "Id descending";
            }

            return Filter(exp).OrderBy(orderby).Skip(pagesize * (pageindex - 1)).Take(pagesize);
        }

        /// <summary>
        /// Counts the records that match a filter.
        /// </summary>
        /// <param name="exp">Filter predicate; when null, all records are counted.</param>
        public int Count<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return Filter(exp).Count();
        }

        /// <summary>
        /// Adds an entity to the context. If its key is empty, a default key is generated first.
        /// Nothing is written until Save()/SaveAsync() is called.
        /// </summary>
        public void Add<T>(T entity) where T : BaseEntity
        {
            if (entity.KeyIsNull())
            {
                entity.GenerateDefaultKeyVal();
            }

            _context.Set<T>().Add(entity);
        }

        /// <summary>
        /// Adds several entities to the context. Every entity whose key is empty gets a default key first.
        /// Nothing is written until Save()/SaveAsync() is called.
        /// </summary>
        public void BatchAdd<T>(T[] entities) where T : BaseEntity
        {
            foreach (var entity in entities)
            {
                if (entity.KeyIsNull())
                {
                    entity.GenerateDefaultKeyVal();
                }
            }

            _context.Set<T>().AddRange(entities);
        }

        /// <summary>
        /// Attaches <paramref name="entity"/> and marks it Modified, so all of its properties are written
        /// on the next Save()/SaveAsync().
        /// </summary>
        public void Update<T>(T entity) where T : class
        {
            var entry = _context.Entry(entity);
            entry.State = EntityState.Modified;

            // Intended to revert to Unchanged when the data has not changed.
            // NOTE(review): HasChanges() is true whenever any tracked entry is Added/Modified/Deleted,
            // which includes the entry just marked Modified above, so this branch looks unreachable — confirm.
            if (!_context.ChangeTracker.HasChanges())
            {
                entry.State = EntityState.Unchanged;
            }
        }

        /// <summary>
        /// Marks <paramref name="entity"/> for deletion on the next Save()/SaveAsync().
        /// </summary>
        public void Delete<T>(T entity) where T : class
        {
            _context.Set<T>().Remove(entity);
        }

        /// <summary>
        /// Updates only the given columns of the rows matching <paramref name="where"/>, for example
        /// <c>Update&lt;User&gt;(u =&gt; u.Id == 1, u =&gt; new User { Name = "ok" })</c>.
        /// The statement runs immediately through Z.EntityFramework.Plus instead of waiting for Save(),
        /// so wrap the call in ExecuteWithTransaction(...) when it must be atomic with other changes.
        /// </summary>
        /// <param name="where">Selects the rows to update.</param>
        /// <param name="entity">Member-init expression holding the new column values.</param>
        public void Update<T>(Expression<Func<T, bool>> where, Expression<Func<T, T>> entity) where T : class
        {
            _context.Set<T>().Where(where).Update(entity);
        }

        /// <summary>
        /// Deletes every row matching <paramref name="exp"/>.
        /// The statement runs immediately through Z.EntityFramework.Plus instead of waiting for Save(),
        /// so wrap the call in ExecuteWithTransaction(...) when it must be atomic with other changes.
        /// </summary>
        public virtual void Delete<T>(Expression<Func<T, bool>> exp) where T : class
        {
            _context.Set<T>().Where(exp).Delete();
        }

        /// <summary>
        /// Validates every Added/Modified tracked entity against its DataAnnotations attributes,
        /// then writes all pending changes.
        /// </summary>
        /// <remarks>
        /// Any failure is rethrown unwrapped: if the caught exception has an InnerException
        /// (e.g. the provider error inside a DbUpdateException), that inner exception is thrown instead,
        /// with its original stack trace preserved.
        /// </remarks>
        public void Save()
        {
            try
            {
                ValidateTrackedEntities();
                _context.SaveChanges();
            }
            catch (ValidationException exc)
            {
                Console.WriteLine($"{nameof(Save)} validation exception: {exc.Message}");
                RethrowUnwrapped(exc);
                throw; // not reached: RethrowUnwrapped always throws
            }
            catch (Exception ex) //DbUpdateException
            {
                RethrowUnwrapped(ex);
                throw; // not reached: RethrowUnwrapped always throws
            }
        }

        /// <summary>
        /// Returns the records of <typeparamref name="T"/> without change tracking,
        /// filtered by <paramref name="exp"/> when it is not null.
        /// </summary>
        private IQueryable<T> Filter<T>(Expression<Func<T, bool>> exp) where T : class
        {
            IQueryable<T> query = _context.Set<T>().AsNoTracking();
            if (exp != null)
            {
                query = query.Where(exp);
            }

            return query;
        }

        /// <summary>
        /// Runs DataAnnotations validation (all properties) on every tracked entity in the Added or Modified state.
        /// </summary>
        /// <exception cref="ValidationException">Thrown for the first entity that fails validation.</exception>
        private void ValidateTrackedEntities()
        {
            var entities = _context.ChangeTracker.Entries()
                .Where(e => e.State == EntityState.Added || e.State == EntityState.Modified)
                .Select(e => e.Entity);

            foreach (var entity in entities)
            {
                Validator.ValidateObject(entity, new ValidationContext(entity), validateAllProperties: true);
            }
        }

        /// <summary>
        /// Throws the InnerException of <paramref name="ex"/> (or <paramref name="ex"/> itself when it has none).
        /// Unlike <c>throw inner;</c>, this keeps the original stack trace of the rethrown exception.
        /// </summary>
        private static void RethrowUnwrapped(Exception ex)
        {
            ExceptionDispatchInfo.Capture(ex.InnerException ?? ex).Throw();
        }

        /// <summary>
        /// Executes a raw SQL command and returns the number of affected rows.
        /// No parameters are bound: never concatenate untrusted input into <paramref name="sql"/>.
        /// </summary>
        public int ExecuteSql(string sql)
        {
            return _context.Database.ExecuteSqlRaw(sql);
        }

        /// <summary>
        /// Queries <typeparamref name="T"/> with raw SQL. Pass user-supplied values through
        /// <paramref name="parameters"/> instead of concatenating them into <paramref name="sql"/>.
        /// </summary>
        public IQueryable<T> FromSql<T>(string sql, params object[] parameters) where T : class
        {
            return _context.Set<T>().FromSqlRaw(sql, parameters);
        }

        /// <summary>
        /// Queries a query type with raw SQL. Pass user-supplied values through
        /// <paramref name="parameters"/> instead of concatenating them into <paramref name="sql"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): DbContext.Query&lt;T&gt;() is obsolete since EF Core 3.0 (removed in 5.0);
        /// consider Set&lt;T&gt;() with a keyless entity type when upgrading.
        /// </remarks>
        public IQueryable<T> Query<T>(string sql, params object[] parameters) where T : class
        {
            return _context.Query<T>().FromSqlRaw(sql, parameters);
        }

        #region Async implementation

        /// <summary>
        /// Executes a raw SQL command asynchronously and returns the number of affected rows.
        /// No parameters are bound: never concatenate untrusted input into <paramref name="sql"/>.
        /// </summary>
        public async Task<int> ExecuteSqlRawAsync(string sql)
        {
            return await _context.Database.ExecuteSqlRawAsync(sql);
        }

        /// <summary>
        /// Asynchronously validates every Added/Modified tracked entity, then writes all pending changes.
        /// Failures are rethrown unwrapped, as in Save().
        /// </summary>
        /// <returns>The number of state entries written to the database.</returns>
        public async Task<int> SaveAsync()
        {
            try
            {
                ValidateTrackedEntities();
                return await _context.SaveChangesAsync();
            }
            catch (ValidationException exc)
            {
                Console.WriteLine($"{nameof(SaveAsync)} validation exception: {exc.Message}");
                RethrowUnwrapped(exc);
                throw; // not reached: RethrowUnwrapped always throws
            }
            catch (Exception ex) //DbUpdateException
            {
                RethrowUnwrapped(ex);
                throw; // not reached: RethrowUnwrapped always throws
            }
        }

        /// <summary>
        /// Asynchronously counts the records that match a filter.
        /// </summary>
        /// <param name="exp">Filter predicate; when null, all records are counted.</param>
        public async Task<int> CountAsync<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return await Filter(exp).CountAsync();
        }

        /// <summary>
        /// Asynchronously returns true if any record matches <paramref name="exp"/>.
        /// </summary>
        public async Task<bool> AnyAsync<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return await _context.Set<T>().AnyAsync(exp);
        }

        /// <summary>
        /// Asynchronously finds a single record (or default) matching <paramref name="exp"/>,
        /// without change tracking by the context.
        /// </summary>
        public async Task<T> FirstOrDefaultAsync<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return await _context.Set<T>().AsNoTracking().FirstOrDefaultAsync(exp);
        }

        #endregion
    }
}