OpenAuth.Net/OpenAuth.Repository/BaseRepository.cs
2020-12-29 23:52:06 +08:00

321 lines
10 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Infrastructure;
using Microsoft.EntityFrameworkCore;
using OpenAuth.Repository.Core;
using OpenAuth.Repository.Interface;
using Z.EntityFramework.Plus;
namespace OpenAuth.Repository
{
public class BaseRepository<T,TDbContext> : IRepository<T,TDbContext> where T : BaseEntity where TDbContext: DbContext
{
private TDbContext _context;
public BaseRepository(TDbContext context)
{
_context = context;
}
/// <summary>
/// 根据过滤条件,获取记录
/// </summary>
/// <param name="exp">The exp.</param>
public IQueryable<T> Find(Expression<Func<T, bool>> exp = null)
{
return Filter(exp);
}
public bool Any(Expression<Func<T, bool>> exp)
{
return _context.Set<T>().Any(exp);
}
/// <summary>
/// 查找单个,且不被上下文所跟踪
/// </summary>
public T FirstOrDefault(Expression<Func<T, bool>> exp)
{
return _context.Set<T>().AsNoTracking().FirstOrDefault(exp);
}
/// <summary>
/// 得到分页记录
/// </summary>
/// <param name="pageindex">The pageindex.</param>
/// <param name="pagesize">The pagesize.</param>
/// <param name="orderby">排序,格式如:"Id"/"Id descending"</param>
public IQueryable<T> Find(int pageindex, int pagesize, string orderby = "",
Expression<Func<T, bool>> exp = null)
{
if (pageindex < 1) pageindex = 1;
if (string.IsNullOrEmpty(orderby))
orderby = "Id descending";
return Filter(exp).OrderBy(orderby).Skip(pagesize * (pageindex - 1)).Take(pagesize);
}
/// <summary>
/// 根据过滤条件获取记录数
/// </summary>
public int Count(Expression<Func<T, bool>> exp = null)
{
return Filter(exp).Count();
}
public void Add(T entity)
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
_context.Set<T>().Add(entity);
Save();
_context.Entry(entity).State = EntityState.Detached;
}
/// <summary>
/// 批量添加
/// </summary>
/// <param name="entities">The entities.</param>
public void BatchAdd(T[] entities)
{
foreach (var entity in entities)
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
}
_context.Set<T>().AddRange(entities);
Save();
}
public void Update(T entity)
{
var entry = this._context.Entry(entity);
entry.State = EntityState.Modified;
//如果数据没有发生变化
if (!this._context.ChangeTracker.HasChanges())
{
return;
}
Save();
}
public void Delete(T entity)
{
_context.Set<T>().Remove(entity);
Save();
}
/// <summary>
/// 实现按需要只更新部分更新
/// <para>如Update(u =>u.Id==1,u =>new User{Name="ok"});</para>
/// </summary>
/// <param name="where">The where.</param>
/// <param name="entity">The entity.</param>
public void Update(Expression<Func<T, bool>> where, Expression<Func<T, T>> entity)
{
_context.Set<T>().Where(where).Update(entity);
}
public virtual void Delete(Expression<Func<T, bool>> exp)
{
_context.Set<T>().Where(exp).Delete();
}
public void Save()
{
try
{
var entities = _context.ChangeTracker.Entries()
.Where(e => e.State == EntityState.Added
|| e.State == EntityState.Modified)
.Select(e => e.Entity);
foreach (var entity in entities)
{
var validationContext = new ValidationContext(entity);
Validator.ValidateObject(entity, validationContext, validateAllProperties: true);
}
_context.SaveChanges();
}
catch (ValidationException exc)
{
Console.WriteLine($"{nameof(Save)} validation exception: {exc?.Message}");
throw (exc.InnerException as Exception ?? exc);
}
catch (Exception ex) //DbUpdateException
{
throw (ex.InnerException as Exception ?? ex);
}
}
private IQueryable<T> Filter(Expression<Func<T, bool>> exp)
{
var dbSet = _context.Set<T>().AsNoTracking().AsQueryable();
if (exp != null)
dbSet = dbSet.Where(exp);
return dbSet;
}
public int ExecuteSqlRaw(string sql)
{
return _context.Database.ExecuteSqlRaw(sql);
}
/// <summary>
/// 使用SQL脚本查询
/// </summary>
/// <typeparam name="T"> T为数据库实体</typeparam>
/// <returns></returns>
public IQueryable<T> FromSql(string sql, params object[] parameters)
{
return _context.Set<T>().FromSqlRaw(sql, parameters);
}
/// <summary>
/// 使用SQL脚本查询
/// </summary>
/// <typeparam name="T"> T为非数据库实体需要在DbContext中增加对应的DbQuery</typeparam>
/// <returns></returns>
public IQueryable<T> Query(string sql, params object[] parameters)
{
return _context.Query<T>().FromSqlRaw(sql, parameters);
}
#region
/// <summary>
/// 异步执行sql
/// </summary>
/// <param name="sql"></param>
/// <returns></returns>
public async Task<int> ExecuteSqlRawAsync(string sql)
{
return await _context.Database.ExecuteSqlRawAsync(sql);
}
public async Task<int> AddAsync(T entity)
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
_context.Set<T>().Add(entity);
return await SaveAsync();
//_context.Entry(entity).State = EntityState.Detached;
}
public async Task<int> BatchAddAsync(T[] entities)
{
foreach (var entity in entities)
{
if (entity.KeyIsNull())
{
entity.GenerateDefaultKeyVal();
}
}
await _context.Set<T>().AddRangeAsync(entities);
return await SaveAsync();
}
/// <summary>
/// 异步更新
/// </summary>
/// <param name="entity"></param>
/// <returns></returns>
public async Task<int> UpdateAsync(T entity)
{
var entry = this._context.Entry(entity);
entry.State = EntityState.Modified;
//如果数据没有发生变化
if (!this._context.ChangeTracker.HasChanges())
{
return 0;
}
return await SaveAsync();
}
/// <summary>
/// 异步删除
/// </summary>
/// <param name="entity"></param>
/// <returns></returns>
public async Task<int> DeleteAsync(T entity)
{
_context.Set<T>().Remove(entity);
return await SaveAsync();
}
/// <summary>
/// 异步保存
/// </summary>
/// <returns></returns>
/// <exception cref="Exception"></exception>
public async Task<int> SaveAsync()
{
try
{
var entities = _context.ChangeTracker.Entries()
.Where(e => e.State == EntityState.Added
|| e.State == EntityState.Modified)
.Select(e => e.Entity);
foreach (var entity in entities)
{
var validationContext = new ValidationContext(entity);
Validator.ValidateObject(entity, validationContext, validateAllProperties: true);
}
return await _context.SaveChangesAsync();
}
catch (ValidationException exc)
{
Console.WriteLine($"{nameof(Save)} validation exception: {exc?.Message}");
throw (exc.InnerException as Exception ?? exc);
}
catch (Exception ex) //DbUpdateException
{
throw (ex.InnerException as Exception ?? ex);
}
}
/// <summary>
/// 根据过滤条件获取记录数
/// </summary>
public async Task<int> CountAsync(Expression<Func<T, bool>> exp = null)
{
return await Filter(exp).CountAsync();
}
public async Task<bool> AnyAsync(Expression<Func<T, bool>> exp)
{
return await _context.Set<T>().AnyAsync(exp);
}
/// <summary>
/// 查找单个,且不被上下文所跟踪
/// </summary>
public async Task<T> FirstOrDefaultAsync(Expression<Func<T, bool>> exp)
{
return await _context.Set<T>().AsNoTracking().FirstOrDefaultAsync(exp);
}
#endregion
}
}