fix batch

This commit is contained in:
Looly 2019-09-22 00:05:48 +08:00
parent c9f77e1746
commit 0e453b099a
8 changed files with 163 additions and 88 deletions

View File

@ -18,6 +18,8 @@
* 【db】 StatementUtil增加setParam方法
* 【db】 Entity.fieldList改为有序实现
* 【crypto】 AES、DES增加对ZeroPadding的支持issue#551@Github
* 【db】 优化批量插入代码减少类型判断导致的性能问题issue#I12B4Z@Gitee
* 【db】 优化SQL日志格式和日志显示
### Bug修复
* 【core】 修复DateUtil.offset导致的时区错误问题issue#I1294O@Gitee

View File

@ -257,6 +257,6 @@ public final class DbUtil {
* @since 4.1.7
*/
public static void setShowSqlGlobal(boolean isShowSql, boolean isFormatSql, boolean isShowParams, Level level) {
SqlLog.INSTASNCE.init(isShowSql, isFormatSql, isShowParams, level);
SqlLog.INSTANCE.init(isShowSql, isFormatSql, isShowParams, level);
}
}

View File

@ -1,13 +1,5 @@
package cn.hutool.db;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Collection;
import java.util.List;
import javax.sql.DataSource;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
@ -25,6 +17,13 @@ import cn.hutool.db.sql.SqlExecutor;
import cn.hutool.db.sql.SqlUtil;
import cn.hutool.db.sql.Wrapper;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Collection;
import java.util.List;
/**
* SQL执行类<br>
* 此执行类只接受方言参数不需要数据源只有在执行方法时需要数据库连接对象<br>
@ -155,7 +154,7 @@ public class SqlConnRunner{
if(ArrayUtil.isEmpty(records)){
return new int[]{0};
}
//单条单独处理
if(1 == records.length) {
return new int[] { insert(conn, records[0])};

View File

@ -1,20 +1,7 @@
package cn.hutool.db;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import cn.hutool.core.collection.ArrayIter;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
@ -22,6 +9,11 @@ import cn.hutool.db.sql.SqlBuilder;
import cn.hutool.db.sql.SqlLog;
import cn.hutool.db.sql.SqlUtil;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.*;
import java.util.*;
/**
* Statement和PreparedStatement工具类
*
@ -38,6 +30,9 @@ public class StatementUtil {
* @throws SQLException SQL执行异常
*/
public static PreparedStatement fillParams(PreparedStatement ps, Object... params) throws SQLException {
if (ArrayUtil.isEmpty(params)) {
return ps;
}
return fillParams(ps, new ArrayIter<>(params));
}
@ -51,13 +46,28 @@ public class StatementUtil {
* @throws SQLException SQL执行异常
*/
public static PreparedStatement fillParams(PreparedStatement ps, Iterable<?> params) throws SQLException {
if (ArrayUtil.isEmpty(params)) {
return fillParams(ps, params, null);
}
/**
* 填充SQL的参数<br>
* 对于日期对象特殊处理传入java.util.Date默认按照Timestamp处理
*
* @param ps PreparedStatement
* @param params SQL参数
* @param nullTypeCache null参数的类型缓存避免循环中重复获取类型
* @return {@link PreparedStatement}
* @throws SQLException SQL执行异常
* @since 4.6.7
*/
public static PreparedStatement fillParams(PreparedStatement ps, Iterable<?> params, Map<Integer, Integer> nullTypeCache) throws SQLException {
if (null == params) {
return ps;// 无参数
}
int paramIndex = 1;//第一个参数从1计数
for (Object param : params) {
setParam(ps, paramIndex++, param);
setParam(ps, paramIndex++, param, nullTypeCache);
}
return ps;
}
@ -103,7 +113,7 @@ public class StatementUtil {
Assert.notBlank(sql, "Sql String must be not blank!");
sql = sql.trim();
SqlLog.INSTASNCE.log(sql, params);
SqlLog.INSTANCE.log(sql, ArrayUtil.isEmpty(params) ? null : params);
PreparedStatement ps;
if (StrUtil.startWithIgnoreCase(sql, "insert")) {
// 插入默认返回主键
@ -142,7 +152,7 @@ public class StatementUtil {
Assert.notBlank(sql, "Sql String must be not blank!");
sql = sql.trim();
SqlLog.INSTASNCE.log(sql, paramsBatch);
SqlLog.INSTANCE.log(sql, paramsBatch);
PreparedStatement ps = conn.prepareStatement(sql);
for (Object[] params : paramsBatch) {
StatementUtil.fillParams(ps, params);
@ -151,6 +161,32 @@ public class StatementUtil {
return ps;
}
/**
* 创建批量操作的{@link PreparedStatement}
*
* @param conn 数据库连接
* @param sql SQL语句使用"?"做为占位符
* @param fields 字段列表用于获取对应值
* @param entities "?"对应参数批次列表
* @return {@link PreparedStatement}
* @throws SQLException SQL异常
* @since 4.6.7
*/
public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, List<String> fields, Entity... entities) throws SQLException {
Assert.notBlank(sql, "Sql String must be not blank!");
sql = sql.trim();
SqlLog.INSTANCE.logForBatch(sql);
PreparedStatement ps = conn.prepareStatement(sql);
//null参数的类型缓存避免循环中重复获取类型
final Map<Integer, Integer> nullTypeMap = new HashMap<>();
for (Entity entity : entities) {
StatementUtil.fillParams(ps, CollectionUtil.valuesOfKeys(entity, fields), nullTypeMap);
ps.addBatch();
}
return ps;
}
/**
* 创建{@link CallableStatement}
*
@ -165,7 +201,7 @@ public class StatementUtil {
Assert.notBlank(sql, "Sql String must be not blank!");
sql = sql.trim();
SqlLog.INSTASNCE.log(sql, params);
SqlLog.INSTANCE.log(sql, params);
final CallableStatement call = conn.prepareCall(sql);
fillParams(call, params);
return call;
@ -247,9 +283,31 @@ public class StatementUtil {
* @since 4.6.7
*/
public static void setParam(PreparedStatement ps, int paramIndex, Object param) throws SQLException {
setParam(ps, paramIndex, param, null);
}
//--------------------------------------------------------------------------------------------- Private method start
/**
* {@link PreparedStatement} 设置单个参数
*
* @param ps {@link PreparedStatement}
* @param paramIndex 参数位置从1开始
* @param param 参数不能为{@code null}
* @param nullTypeCache 用于缓存参数为null位置的类型避免重复获取
* @throws SQLException SQL异常
* @since 4.6.7
*/
private static void setParam(PreparedStatement ps, int paramIndex, Object param, Map<Integer, Integer> nullTypeCache) throws SQLException {
if (null == param) {
ps.setNull(paramIndex, getTypeOfNull(ps, paramIndex));
return;
Integer type = (null == nullTypeCache) ? null : nullTypeCache.get(paramIndex);
if (null == type) {
type = getTypeOfNull(ps, paramIndex);
if (null != nullTypeCache) {
nullTypeCache.put(paramIndex, type);
}
}
ps.setNull(paramIndex, type);
}
// 日期特殊处理默认按照时间戳传入避免毫秒丢失
@ -282,4 +340,5 @@ public class StatementUtil {
// 其它参数类型
ps.setObject(paramIndex, param);
}
//--------------------------------------------------------------------------------------------- Private method end
}

View File

@ -3,6 +3,9 @@ package cn.hutool.db.dialect.impl;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.Assert;
@ -53,15 +56,9 @@ public class AnsiSqlDialect implements Dialect {
if (ArrayUtil.isEmpty(entities)) {
throw new DbRuntimeException("Entities for batch insert is empty !");
}
// 批量
// 批量根据第一行数据结构生成SQL占位符
final SqlBuilder insert = SqlBuilder.create(wrapper).insert(entities[0], this.dialectName());
final PreparedStatement ps = StatementUtil.prepareStatement(conn, insert.build());
for (Entity entity : entities) {
StatementUtil.fillParams(ps, CollectionUtil.valuesOfKeys(entity, insert.getFields()));
ps.addBatch();
}
return ps;
return StatementUtil.prepareStatementForBatch(conn, insert.build(), insert.getFields(), entities);
}
@Override

View File

@ -55,7 +55,7 @@ public class SqlBuilder implements Builder<String>{
* @author Looly
*
*/
public static enum Join {
public enum Join {
/** 如果表中有至少一个匹配,则返回行 */
INNER,
/** 即使右表中没有匹配,也从左表返回所有的行 */
@ -69,9 +69,9 @@ public class SqlBuilder implements Builder<String>{
final private StringBuilder sql = new StringBuilder();
/** 字段列表(仅用于插入和更新) */
final private List<String> fields = new ArrayList<String>();
final private List<String> fields = new ArrayList<>();
/** 占位符对应的值列表 */
final private List<Object> paramValues = new ArrayList<Object>();
final private List<Object> paramValues = new ArrayList<>();
/** 包装器 */
private Wrapper wrapper;
@ -527,22 +527,22 @@ public class SqlBuilder implements Builder<String>{
/**
* 获得插入或更新的数据库字段列表
*
* @return 插入或更新的数据库字段列表
*/
public List<String> getFields() {
return this.fields;
}
/**
* 获得插入或更新的数据库字段列表
*
*
* @return 插入或更新的数据库字段列表
*/
public String[] getFieldArray() {
return this.fields.toArray(new String[this.fields.size()]);
}
/**
* 获得插入或更新的数据库字段列表
*
* @return 插入或更新的数据库字段列表
*/
public List<String> getFields() {
return this.fields;
}
/**
* 获得占位符对应的值列表<br>
*

View File

@ -11,7 +11,7 @@ import cn.hutool.log.level.Level;
* @since 4.1.0
*/
public enum SqlLog {
INSTASNCE;
INSTANCE;
private final static Log log = LogFactory.get();
@ -40,16 +40,38 @@ public enum SqlLog {
/**
* 打印SQL日志
*
*
* @param sql SQL语句
* @since 4.6.7
*/
public void log(String sql) {
log(sql, null);
}
/**
* 打印批量 SQL日志
*
* @param sql SQL语句
* @since 4.6.7
*/
public void logForBatch(String sql) {
if (this.showSql) {
log.log(this.level, "\n[Batch SQL] -> {}", this.formatSql ? SqlFormatter.format(sql) : sql);
}
}
/**
* 打印SQL日志
*
* @param sql SQL语句
* @param paramValues 参数可为null
*/
public void log(String sql, Object paramValues) {
if (this.showSql) {
if (this.showParams) {
log.log(this.level, "\nSQL -> {}\nParams -> {}", this.formatSql ? SqlFormatter.format(sql) : sql, paramValues);
if (null != paramValues && this.showParams) {
log.log(this.level, "\n[SQL] -> {}\nParams -> {}", this.formatSql ? SqlFormatter.format(sql) : sql, paramValues);
} else {
log.log(this.level, "\nSQL -> {}", this.formatSql ? SqlFormatter.format(sql) : sql);
log.log(this.level, "\n[SQL] -> {}", this.formatSql ? SqlFormatter.format(sql) : sql);
}
}
}

View File

@ -1,29 +1,24 @@
package cn.hutool.db;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Console;
import cn.hutool.db.handler.EntityListHandler;
import cn.hutool.db.pojo.User;
import cn.hutool.db.sql.Condition;
import cn.hutool.db.sql.Condition.LikeType;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.SQLException;
import java.util.List;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Console;
import cn.hutool.db.ActiveEntity;
import cn.hutool.db.Db;
import cn.hutool.db.Entity;
import cn.hutool.db.handler.EntityListHandler;
import cn.hutool.db.pojo.User;
import cn.hutool.db.sql.Condition;
import cn.hutool.db.sql.Condition.LikeType;
/**
* 增删改查测试
*
* @author looly
*
* @author looly
*/
public class CRUDTest {
@ -34,13 +29,13 @@ public class CRUDTest {
List<Entity> results = db.findAll(Entity.create("user").set("age", "is null"));
Assert.assertEquals(0, results.size());
}
@Test
public void findIsNullTest2() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("age", "= null"));
Assert.assertEquals(0, results.size());
}
@Test
public void findIsNullTest3() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("age", null));
@ -52,13 +47,13 @@ public class CRUDTest {
List<Entity> results = db.findAll(Entity.create("user").set("age", "between '18' and '40'"));
Assert.assertEquals(1, results.size());
}
@Test
public void findByBigIntegerTest() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("age", new BigInteger("12")));
Assert.assertEquals(2, results.size());
}
@Test
public void findByBigDecimalTest() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("age", new BigDecimal("12")));
@ -70,31 +65,31 @@ public class CRUDTest {
List<Entity> results = db.findAll(Entity.create("user").set("name", "like \"%三%\""));
Assert.assertEquals(2, results.size());
}
@Test
public void findLikeTest2() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("name", new Condition("name", "", LikeType.Contains)));
Assert.assertEquals(2, results.size());
}
@Test
public void findLikeTest3() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("name", new Condition("name", null, LikeType.Contains)));
Assert.assertEquals(0, results.size());
}
@Test
public void findInTest() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("id", "in 1,2,3"));
Assert.assertEquals(2, results.size());
}
@Test
public void findInTest2() throws SQLException {
List<Entity> results = db.findAll(Entity.create("user").set("id", new Condition("id", new long[] {1,2,3})));
List<Entity> results = db.findAll(Entity.create("user").set("id", new Condition("id", new long[]{1, 2, 3})));
Assert.assertEquals(2, results.size());
}
@Test
public void findAllTest() throws SQLException {
List<Entity> results = db.findAll("user");
@ -106,19 +101,19 @@ public class CRUDTest {
List<Entity> find = db.find(CollUtil.newArrayList("name AS name2"), Entity.create("user"), new EntityListHandler());
Assert.assertFalse(find.isEmpty());
}
@Test
public void findActiveTest() throws SQLException {
public void findActiveTest() {
ActiveEntity entity = new ActiveEntity(db, "user");
entity.setFieldNames("name AS name2").load();
Assert.assertEquals("user", entity.getTableName());
Assert.assertFalse(entity.isEmpty());
}
/**
* 对增删改查做单元测试
*
* @throws SQLException
*
* @throws SQLException SQL异常
*/
@Test
@Ignore
@ -159,6 +154,7 @@ public class CRUDTest {
user2.setGender(false);
Entity data1 = Entity.parse(user1);
data1.put("name", null);
Entity data2 = Entity.parse(user2);
Console.log(data1);