using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using Infrastructure;
using Infrastructure.Database;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using OpenAuth.Repository.Core;
using OpenAuth.Repository.Interface;
using Z.EntityFramework.Plus;

namespace OpenAuth.Repository
{
    /// <summary>
    /// Unit of work over a single EF Core <see cref="DbContext"/>: query helpers, change-tracked
    /// add/update/delete, expression-based batch update/delete, raw SQL, stored procedures and
    /// explicit transactions.
    /// </summary>
    /// <typeparam name="TDbContext">The concrete context type this unit of work wraps.</typeparam>
    public class UnitWork<TDbContext> : IUnitWork<TDbContext> where TDbContext : DbContext
    {
        private readonly TDbContext _context;

        public UnitWork(TDbContext context)
        {
            _context = context;
        }

        /// <summary>
        /// By default EF runs every SaveChanges() call in its own transaction. This method lets
        /// <paramref name="action"/> call SaveChanges() several times inside one transaction.
        /// The transaction is committed if <paramref name="action"/> completes, and rolled back
        /// if it throws.
        /// </summary>
        /// <param name="action">Work to run inside the transaction.</param>
        public void ExecuteWithTransaction(Action action)
        {
            using (IDbContextTransaction transaction = _context.Database.BeginTransaction())
            {
                try
                {
                    action();
                    transaction.Commit();
                }
                catch
                {
                    transaction.Rollback();
                    // `throw;` keeps the original stack trace (`throw ex;` would reset it).
                    throw;
                }
            }
        }

        /// <summary>
        /// Asynchronous form of <see cref="ExecuteWithTransaction"/>. By default EF runs every
        /// SaveChanges() call in its own transaction. This method lets <paramref name="action"/>
        /// call SaveChanges() several times inside one transaction.
        /// </summary>
        /// <param name="action">Asynchronous work to run inside the transaction.</param>
        public async Task ExecuteWithTransactionAsync(Func<Task> action)
        {
            using (IDbContextTransaction transaction = await _context.Database.BeginTransactionAsync())
            {
                try
                {
                    await action();
                    await transaction.CommitAsync();
                }
                catch
                {
                    await transaction.RollbackAsync();
                    throw;
                }
            }
        }

        /// <summary>
        /// Returns the underlying DbContext, for multithreading and other edge cases.
        /// </summary>
        public DbContext GetDbContext()
        {
            return _context;
        }

        /// <summary>
        /// Gets the records matching the filter. Results are not tracked (AsNoTracking).
        /// </summary>
        /// <param name="exp">Filter condition; <c>null</c> returns all records.</param>
        public IQueryable<T> Find<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return Filter(exp);
        }

        /// <summary>
        /// Returns whether any record matches <paramref name="exp"/>.
        /// </summary>
        public bool Any<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return _context.Set<T>().Any(exp);
        }

        /// <summary>
        /// Finds a single record. The result is not tracked by the context.
        /// </summary>
        public T FirstOrDefault<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return _context.Set<T>().AsNoTracking().FirstOrDefault(exp);
        }

        /// <summary>
        /// Gets one page of records. Results are not tracked.
        /// </summary>
        /// <param name="pageindex">1-based page index; values below 1 are treated as 1.</param>
        /// <param name="pagesize">Number of records per page.</param>
        /// <param name="orderby">Sort clause, e.g. "Id" or "Id descending"; empty means "Id descending".</param>
        /// <param name="exp">Optional filter condition.</param>
        public IQueryable<T> Find<T>(int pageindex, int pagesize, string orderby = "", Expression<Func<T, bool>> exp = null) where T : class
        {
            if (pageindex < 1) pageindex = 1;
            if (string.IsNullOrEmpty(orderby)) orderby = "Id descending";

            return Filter(exp).OrderBy(orderby).Skip(pagesize * (pageindex - 1)).Take(pagesize);
        }

        /// <summary>
        /// Gets the number of records matching the filter.
        /// </summary>
        /// <param name="exp">Filter condition; <c>null</c> counts all records.</param>
        public int Count<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return Filter(exp).Count();
        }

        /// <summary>
        /// Adds an entity to the context. If its key is empty, a default key value is generated
        /// first. Nothing is written until Save()/SaveAsync().
        /// </summary>
        public void Add<T>(T entity) where T : BaseEntity
        {
            if (entity.KeyIsNull())
            {
                entity.GenerateDefaultKeyVal();
            }

            _context.Set<T>().Add(entity);
        }

        /// <summary>
        /// Adds several entities to the context. Any entity with an empty key gets a generated
        /// default key value first.
        /// </summary>
        public void BatchAdd<T>(T[] entities) where T : BaseEntity
        {
            foreach (var entity in entities)
            {
                if (entity.KeyIsNull())
                {
                    entity.GenerateDefaultKeyVal();
                }
            }

            _context.Set<T>().AddRange(entities);
        }

        /// <summary>
        /// Attaches <paramref name="entity"/> as Modified, so the next Save() writes all of its
        /// non-key columns.
        /// </summary>
        public void Update<T>(T entity) where T : class
        {
            // No "unchanged data" check is possible here: once the entry is Modified,
            // ChangeTracker.HasChanges() is always true.
            _context.Entry(entity).State = EntityState.Modified;
        }

        /// <summary>
        /// Marks <paramref name="entity"/> for deletion on the next Save().
        /// </summary>
        public void Delete<T>(T entity) where T : class
        {
            _context.Set<T>().Remove(entity);
        }

        /// <summary>
        /// Updates only the given fields of the rows that match a condition, e.g.
        /// <c>Update&lt;User&gt;(u => u.Id == 1, u => new User { Name = "ok" })</c>.
        /// Runs immediately and does not wait for Save(). Wrap it in ExecuteWithTransaction when it
        /// must be part of a transaction.
        /// </summary>
        /// <param name="where">Rows to update.</param>
        /// <param name="entity">Expression producing the new field values.</param>
        public void Update<T>(Expression<Func<T, bool>> where, Expression<Func<T, T>> entity) where T : class
        {
            _context.Set<T>().Where(where).Update(entity);
        }

        /// <summary>
        /// Batch-deletes the rows matching <paramref name="exp"/>.
        /// Runs immediately and does not wait for Save(). Wrap it in ExecuteWithTransaction when it
        /// must be part of a transaction.
        /// </summary>
        public virtual void Delete<T>(Expression<Func<T, bool>> exp) where T : class
        {
            _context.Set<T>().Where(exp).Delete();
        }

        /// <summary>
        /// Validates Added/Modified entities with DataAnnotations, then calls SaveChanges().
        /// </summary>
        /// <remarks>
        /// Failures are rethrown unwrapped. If the caught exception has an InnerException (for
        /// example a DbUpdateException), the inner exception is thrown; otherwise the exception
        /// itself is thrown. Either way the rethrown exception keeps its original stack trace.
        /// </remarks>
        public void Save()
        {
            try
            {
                ValidateTrackedEntities();
                _context.SaveChanges();
            }
            catch (ValidationException exc)
            {
                Console.WriteLine($"{nameof(Save)} validation exception: {exc.Message}");
                RethrowUnwrapped(exc);
                throw; // unreachable; RethrowUnwrapped always throws
            }
            catch (Exception ex) //DbUpdateException
            {
                RethrowUnwrapped(ex);
                throw; // unreachable; RethrowUnwrapped always throws
            }
        }

        /// <summary>
        /// Applies <see cref="Validator.ValidateObject(object, ValidationContext, bool)"/> to every
        /// tracked entity in the Added or Modified state.
        /// </summary>
        /// <exception cref="ValidationException">The first entity that fails validation.</exception>
        private void ValidateTrackedEntities()
        {
            var pending = _context.ChangeTracker.Entries()
                .Where(e => e.State == EntityState.Added || e.State == EntityState.Modified)
                .Select(e => e.Entity);

            foreach (var entity in pending)
            {
                Validator.ValidateObject(entity, new ValidationContext(entity), validateAllProperties: true);
            }
        }

        /// <summary>
        /// Throws <paramref name="ex"/>'s InnerException, or <paramref name="ex"/> itself when there
        /// is none. The original stack trace of the thrown exception is preserved. Never returns
        /// normally.
        /// </summary>
        private static void RethrowUnwrapped(Exception ex)
        {
            ExceptionDispatchInfo.Capture(ex.InnerException ?? ex).Throw();
        }

        /// <summary>
        /// Base query for the finders: an untracked set of <typeparamref name="T"/>, filtered by
        /// <paramref name="exp"/> when one is given.
        /// </summary>
        private IQueryable<T> Filter<T>(Expression<Func<T, bool>> exp) where T : class
        {
            var dbSet = _context.Set<T>().AsNoTracking().AsQueryable();
            if (exp != null)
                dbSet = dbSet.Where(exp);
            return dbSet;
        }

        /// <summary>
        /// Executes raw SQL and returns the number of affected rows (0 for empty SQL).
        /// The text is sent verbatim; never build it from untrusted input.
        /// </summary>
        public int ExecuteSql(string sql)
        {
            if (string.IsNullOrEmpty(sql))
                return 0;
            return _context.Database.ExecuteSqlRaw(sql);
        }

        /// <summary>
        /// Runs a raw SQL query mapped to <typeparamref name="T"/>. Pass values through
        /// <paramref name="parameters"/> rather than concatenating them into <paramref name="sql"/>.
        /// </summary>
        public IQueryable<T> FromSql<T>(string sql, params object[] parameters) where T : class
        {
            return _context.Set<T>().FromSqlRaw(sql, parameters);
        }

        /// <summary>
        /// Same as <see cref="FromSql{T}(string, object[])"/>; kept for backward compatibility.
        /// </summary>
        [Obsolete("最新版同FromSql,需要在DbContext中设置modelBuilder.Entity().HasNoKey();")]
        public IQueryable<T> Query<T>(string sql, params object[] parameters) where T : class
        {
            return _context.Set<T>().FromSqlRaw(sql, parameters);
        }

        /// <summary>
        /// Executes a stored procedure and maps its first result set to a list of <typeparamref name="T"/>.
        /// </summary>
        /// <param name="procName">Stored procedure name.</param>
        /// <param name="sqlParams">Stored procedure parameters; may be empty.</param>
        public List<T> ExecProcedure<T>(string procName, params DbParameter[] sqlParams) where T : class
        {
            var connection = _context.Database.GetDbConnection();
            using (var cmd = connection.CreateCommand())
            {
                _context.Database.OpenConnection();
                try
                {
                    cmd.CommandText = procName;
                    cmd.CommandType = CommandType.StoredProcedure;
                    // Enlist in the ambient EF transaction (e.g. inside ExecuteWithTransaction).
                    // Some providers reject commands on a connection with a pending local
                    // transaction unless Transaction is set.
                    cmd.Transaction = _context.Database.CurrentTransaction?.GetDbTransaction();
                    if (sqlParams != null)
                    {
                        cmd.Parameters.AddRange(sqlParams);
                    }

                    var datatable = new DataTable();
                    using (DbDataReader dr = cmd.ExecuteReader())
                    {
                        datatable.Load(dr);
                    }
                    return datatable.ToList<T>();
                }
                finally
                {
                    // Pairs with OpenConnection() so the connection is not left open for the
                    // rest of the context's lifetime.
                    _context.Database.CloseConnection();
                }
            }
        }

        #region Async implementations

        /// <summary>
        /// Executes raw SQL asynchronously and returns the number of affected rows (0 for empty
        /// SQL, matching <see cref="ExecuteSql"/>). Never build <paramref name="sql"/> from untrusted input.
        /// </summary>
        public async Task<int> ExecuteSqlRawAsync(string sql)
        {
            if (string.IsNullOrEmpty(sql))
                return 0;
            return await _context.Database.ExecuteSqlRawAsync(sql);
        }

        /// <summary>
        /// Asynchronous form of <see cref="Save"/>: validates Added/Modified entities, then calls
        /// SaveChangesAsync() and returns the number of state entries written.
        /// </summary>
        /// <remarks>
        /// Failures are rethrown unwrapped (InnerException when present), with the original stack
        /// trace preserved.
        /// </remarks>
        public async Task<int> SaveAsync()
        {
            try
            {
                ValidateTrackedEntities();
                return await _context.SaveChangesAsync();
            }
            catch (ValidationException exc)
            {
                Console.WriteLine($"{nameof(SaveAsync)} validation exception: {exc.Message}");
                RethrowUnwrapped(exc);
                throw; // unreachable; RethrowUnwrapped always throws
            }
            catch (Exception ex) //DbUpdateException
            {
                RethrowUnwrapped(ex);
                throw; // unreachable; RethrowUnwrapped always throws
            }
        }

        /// <summary>
        /// Gets the number of records matching the filter.
        /// </summary>
        /// <param name="exp">Filter condition; <c>null</c> counts all records.</param>
        public async Task<int> CountAsync<T>(Expression<Func<T, bool>> exp = null) where T : class
        {
            return await Filter(exp).CountAsync();
        }

        /// <summary>
        /// Returns whether any record matches <paramref name="exp"/>.
        /// </summary>
        public async Task<bool> AnyAsync<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return await _context.Set<T>().AnyAsync(exp);
        }

        /// <summary>
        /// Finds a single record. The result is not tracked by the context.
        /// </summary>
        public async Task<T> FirstOrDefaultAsync<T>(Expression<Func<T, bool>> exp) where T : class
        {
            return await _context.Set<T>().AsNoTracking().FirstOrDefaultAsync(exp);
        }

        #endregion
    }
}