OpenAuth.Net/OpenAuth.Repository/UnitWork.cs
2022-11-26 22:01:35 +08:00

334 lines
11 KiB
C#
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Infrastructure;
using Infrastructure.Database;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using OpenAuth.Repository.Core;
using OpenAuth.Repository.Interface;
using Z.EntityFramework.Plus;
namespace OpenAuth.Repository
{
public class UnitWork<TDbContext>: IUnitWork<TDbContext> where TDbContext: DbContext
{
private TDbContext _context;
public UnitWork(TDbContext context)
{
_context = context;
}
/// <summary>
/// EF默认情况下每调用一次SaveChanges()都会执行一个单独的事务
/// 本接口实现在一个事务中可以多次执行SaveChanges()方法
/// </summary>
public void ExecuteWithTransaction(Action action)
{
using (IDbContextTransaction transaction = _context.Database.BeginTransaction())
{
try
{
action();
transaction.Commit();
}
catch (Exception ex)
{
transaction.Rollback();
throw ex;
}
}
}
/// <summary>
/// ExecuteWithTransaction方法的异步方式
/// EF默认情况下每调用一次SaveChanges()都会执行一个单独的事务
/// 本接口实现在一个事务中可以多次执行SaveChanges()方法
/// </summary>
public async Task ExecuteWithTransactionAsync(Func<Task> action)
{
using (IDbContextTransaction transaction = _context.Database.BeginTransaction())
{
try
{
await action();
transaction.Commit();
}
catch (Exception ex)
{
transaction.Rollback();
throw ex;
}
}
}
/// <summary>
/// 返回DbContext,用于多线程等极端情况
/// </summary>
public DbContext GetDbContext()
{
return _context;
}
/// <summary>
/// 根据过滤条件,获取记录
/// </summary>
/// <param name="exp">The exp.</param>
public IQueryable<T> Find<T>(Expression<Func<T, bool>> exp = null) where T : class
{
return Filter(exp);
}
public bool Any<T>(Expression<Func<T, bool>> exp) where T : class
{
return _context.Set<T>().Any(exp);
}
/// <summary>
/// 查找单个
/// </summary>
public T FirstOrDefault<T>(Expression<Func<T, bool>> exp) where T:class
{
return _context.Set<T>().AsNoTracking().FirstOrDefault(exp);
}
/// <summary>
/// 得到分页记录
/// </summary>
/// <param name="pageindex">The pageindex.</param>
/// <param name="pagesize">The pagesize.</param>
/// <param name="orderby">排序,格式如:"Id"/"Id descending"</param>
public IQueryable<T> Find<T>(int pageindex, int pagesize, string orderby = "", Expression<Func<T, bool>> exp = null) where T : class
{
if (pageindex < 1) pageindex = 1;
if (string.IsNullOrEmpty(orderby))
orderby = "Id descending";
return Filter(exp).OrderBy(orderby).Skip(pagesize * (pageindex - 1)).Take(pagesize);
}
/// <summary>
/// 根据过滤条件获取记录数
/// </summary>
public int Count<T>(Expression<Func<T, bool>> exp = null) where T : class
{
return Filter(exp).Count();
}
/// <summary>
/// 新增对象如果Id为空则会自动创建默认Id
/// </summary>
public void Add<T>(T entity) where T : BaseEntity
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
_context.Set<T>().Add(entity);
}
/// <summary>
/// 批量新增对象如果对象Id为空则会自动创建默认Id
/// </summary>
public void BatchAdd<T>(T[] entities) where T : BaseEntity
{
foreach (var entity in entities)
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
}
_context.Set<T>().AddRange(entities);
}
public void Update<T>(T entity) where T:class
{
var entry = this._context.Entry(entity);
entry.State = EntityState.Modified;
//如果数据没有发生变化
if (!this._context.ChangeTracker.HasChanges())
{
entry.State = EntityState.Unchanged;
}
}
public void Delete<T>(T entity) where T:class
{
_context.Set<T>().Remove(entity);
}
/// <summary>
/// 实现按需要只更新部分更新
/// <para>如Update&lt;User&gt;(u =>u.Id==1,u =>new User{Name="ok"})</para>
/// <para>该方法内部自动调用了SaveChanges()需要ExecuteWithTransaction配合才能实现事务控制</para>
/// </summary>
/// <param name="where">更新条件</param>
/// <param name="entity">更新后的实体</param>
public void Update<T>(Expression<Func<T, bool>> where, Expression<Func<T, T>> entity) where T:class
{
_context.Set<T>().Where(where).Update(entity);
}
/// <summary>
/// 批量删除
/// <para>该方法内部自动调用了SaveChanges()需要ExecuteWithTransaction配合才能实现事务控制</para>
/// </summary>
public virtual void Delete<T>(Expression<Func<T, bool>> exp) where T : class
{
_context.Set<T>().Where(exp).Delete();
}
public void Save()
{
try
{
var entities = _context.ChangeTracker.Entries()
.Where(e => e.State == EntityState.Added
|| e.State == EntityState.Modified)
.Select(e => e.Entity);
foreach (var entity in entities)
{
var validationContext = new ValidationContext(entity);
Validator.ValidateObject(entity, validationContext, validateAllProperties: true);
}
_context.SaveChanges();
}
catch (ValidationException exc)
{
Console.WriteLine($"{nameof(Save)} validation exception: {exc?.Message}");
throw (exc.InnerException as Exception ?? exc);
}
catch (Exception ex) //DbUpdateException
{
throw (ex.InnerException as Exception ?? ex);
}
}
private IQueryable<T> Filter<T>(Expression<Func<T, bool>> exp) where T : class
{
var dbSet = _context.Set<T>().AsNoTracking().AsQueryable();
if (exp != null)
dbSet = dbSet.Where(exp);
return dbSet;
}
public int ExecuteSql(string sql)
{
if (string.IsNullOrEmpty(sql)) return 0;
return _context.Database.ExecuteSqlRaw(sql);
}
public IQueryable<T> FromSql<T>(string sql, params object[] parameters) where T : class
{
return _context.Set<T>().FromSqlRaw(sql, parameters);
}
[Obsolete("最新版同FromSql需要在DbContext中设置modelBuilder.Entity<XX>().HasNoKey();")]
public IQueryable<T> Query<T>(string sql, params object[] parameters) where T : class
{
return _context.Set<T>().FromSqlRaw(sql, parameters);
}
/// <summary>
/// 执行存储过程
/// </summary>
/// <param name="procName">存储过程名称</param>
/// <param name="sqlParams">存储过程参数</param>
public List<T> ExecProcedure<T>(string procName, params DbParameter[] sqlParams) where T : class
{
var connection = _context.Database.GetDbConnection();
using (var cmd = connection.CreateCommand())
{
_context.Database.OpenConnection();
cmd.CommandText = procName;
cmd.CommandType = CommandType.StoredProcedure;
cmd.Parameters.AddRange(sqlParams);
DbDataReader dr = cmd.ExecuteReader();
var datatable = new DataTable();
datatable.Load(dr);
return datatable.ToList<T>();
}
}
#region
/// <summary>
/// 异步执行sql
/// </summary>
/// <param name="sql"></param>
/// <returns></returns>
public async Task<int> ExecuteSqlRawAsync(string sql)
{
return await _context.Database.ExecuteSqlRawAsync(sql);
}
/// <summary>
/// 异步保存
/// </summary>
/// <returns></returns>
/// <exception cref="Exception"></exception>
public async Task<int> SaveAsync()
{
try
{
var entities = _context.ChangeTracker.Entries()
.Where(e => e.State == EntityState.Added
|| e.State == EntityState.Modified)
.Select(e => e.Entity);
foreach (var entity in entities)
{
var validationContext = new ValidationContext(entity);
Validator.ValidateObject(entity, validationContext, validateAllProperties: true);
}
return await _context.SaveChangesAsync();
}
catch (ValidationException exc)
{
Console.WriteLine($"{nameof(Save)} validation exception: {exc?.Message}");
throw (exc.InnerException as Exception ?? exc);
}
catch (Exception ex) //DbUpdateException
{
throw (ex.InnerException as Exception ?? ex);
}
}
/// <summary>
/// 根据过滤条件获取记录数
/// </summary>
public async Task<int> CountAsync<T>(Expression<Func<T, bool>> exp = null) where T : class
{
return await Filter(exp).CountAsync();
}
public async Task<bool> AnyAsync<T>(Expression<Func<T, bool>> exp) where T : class
{
return await _context.Set<T>().AnyAsync(exp);
}
/// <summary>
/// 查找单个,且不被上下文所跟踪
/// </summary>
public async Task<T> FirstOrDefaultAsync<T>(Expression<Func<T, bool>> exp) where T : class
{
return await _context.Set<T>().AsNoTracking().FirstOrDefaultAsync(exp);
}
#endregion
}
}