Recently the business side needed to import more than one million rows of data into the system database in a single operation. Your first thought may well be: with that much data, why implement the import in application code at all instead of loading it straight into the database with SQL?
For the export side of large data volumes, see the companion post: Java实现大批量数据导入导出(100W以上) -(二)导出 (Part 2: Export).
I. Why it has to be done in code
Here is why a direct SQL import would not work and the import had to go through the application:
1. First, the feature started out as a page-based upload, and at the beginning the business side guaranteed at most 30,000 rows per import;
2. Second, the imported content needs validation, for example whether store numbers and product numbers exist in the system, which requires program logic;
3. Finally, the business side imports codes only, while the database must also store the corresponding names to make later queries convenient; a SQL import cannot do that either.
Given these three points, loading the data directly with SQL was off the table, so there was nothing for it but to implement the import properly in code.
II. Technical difficulties of a code implementation
1. Reading this much data in one go is certain to overflow server memory;
2. Sending it to the save interface in a single call puts heavy pressure on the network;
3. Inserting it all with one huge batch SQL statement also stresses the database, and if other business operations touch the same table concurrently, deadlocks are likely.
III. Approach
Given the difficulties listed above, my approach is:
1. Rather than read the whole import file at once, stream the upload to server disk first, then read it back from disk in batches (with multi-threaded reads), which prevents the out-of-memory problem;
2. Call the insert interface once per batch read from disk;
3. Insert into the database batch by batch.
IV. Implementation
1. Stream the uploaded file to server disk
Omitted; any standard Java file upload will do, so the code is not shown here.
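Although omitted above, the core of a streaming upload is just a chunked copy from the request's input stream to disk, so the file never has to fit in memory. A minimal sketch (the class and method names here are my own, not from the project):

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;

public class StreamUpload {
    /**
     * Copies an upload stream to disk in fixed-size chunks, so memory
     * use stays constant regardless of file size.
     */
    public static long saveToDisk(InputStream in, Path target) throws IOException {
        long total = 0;
        byte[] buf = new byte[8192];
        try (OutputStream out = Files.newOutputStream(target)) {
            int n;
            while ((n = in.read(buf)) != -1) {
                out.write(buf, 0, n);
                total += n;
            }
        }
        return total;
    }
}
```

In a Spring controller this would typically be fed from `multipartFile.getInputStream()`.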
2. Read from disk in batches with multiple threads
The batch file reader:
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

/**
 * Reads a file in batches.
 *
 * @author WangXueXing create at 19-3-14 18:47
 * @version 1.0.0
 */
public class BatchReadFile {
    private final Logger LOGGER = LoggerFactory.getLogger(BatchReadFile.class);
    /**
     * Charset UTF-8
     */
    public static final String CHARSET_UTF8 = "UTF-8";
    /**
     * Charset GBK
     */
    public static final String CHARSET_GBK = "GBK";
    /**
     * Charset gb2312
     */
    public static final String CHARSET_GB2312 = "gb2312";
    /**
     * Field separator within a line: comma
     */
    public static final String SEPARATOR_COMMA = ",";

    private int bufSize = 1024;
    // line separator
    private byte key = "\n".getBytes()[0];
    // current line number
    private long lineNum = 0;
    // file encoding, gb2312 by default
    private String encode = CHARSET_GB2312;
    // listener carrying the concrete business logic
    private ReaderFileListener readerListener;

    public void setEncode(String encode) {
        this.encode = encode;
    }

    public void setReaderListener(ReaderFileListener readerListener) {
        this.readerListener = readerListener;
    }

    /**
     * Finds the exact start position: the first line break at or after a rough position.
     * @param file file to scan
     * @param position rough start position in bytes
     * @return position of the first '\n' at or after {@code position}
     * @throws Exception on I/O failure
     */
    public long getStartNum(File file, long position) throws Exception {
        long startNum = position;
        FileChannel fcin = new RandomAccessFile(file, "r").getChannel();
        fcin.position(position);
        try {
            int cache = 1024;
            ByteBuffer rBuffer = ByteBuffer.allocate(cache);
            // bytes read in this pass
            byte[] bs = new byte[cache];
            // carry-over from previous passes
            byte[] tempBs = new byte[0];
            while (fcin.read(rBuffer) != -1) {
                int rSize = rBuffer.position();
                rBuffer.rewind();
                rBuffer.get(bs);
                rBuffer.clear();
                byte[] newStrByte = bs;
                // prepend any carry-over from the previous pass
                if (null != tempBs) {
                    int tL = tempBs.length;
                    newStrByte = new byte[rSize + tL];
                    System.arraycopy(tempBs, 0, newStrByte, 0, tL);
                    System.arraycopy(bs, 0, newStrByte, tL, rSize);
                }
                // first line break after the start position
                int endIndex = indexOf(newStrByte, 0);
                if (endIndex != -1) {
                    return startNum + endIndex;
                }
                tempBs = substring(newStrByte, 0, newStrByte.length);
                startNum += 1024;
            }
        } finally {
            fcin.close();
        }
        return position;
    }

    /**
     * Reads the file line by line from the given start position up to the end
     * position; if {@code end} is negative, reads through to the end of the file.
     * @param fullPath full path of the file
     * @param start start position in bytes
     * @param end end position in bytes, or a negative value for end-of-file
     * @throws Exception on I/O failure
     */
    public void readFileByLine(String fullPath, long start, long end) throws Exception {
        File fin = new File(fullPath);
        if (!fin.exists()) {
            throw new FileNotFoundException("File not found: " + fullPath);
        }
        FileChannel fileChannel = new RandomAccessFile(fin, "r").getChannel();
        fileChannel.position(start);
        try {
            ByteBuffer rBuffer = ByteBuffer.allocate(bufSize);
            // bytes read in this pass
            byte[] bs = new byte[bufSize];
            // carry-over from previous passes
            byte[] tempBs = new byte[0];
            String line;
            // current read position
            long nowCur = start;
            while (fileChannel.read(rBuffer) != -1) {
                int rSize = rBuffer.position();
                rBuffer.rewind();
                rBuffer.get(bs);
                rBuffer.clear();
                byte[] newStrByte;
                // drop the header row on the first pass
                if (nowCur == start) {
                    int firstLineIndex = indexOf(bs, 0);
                    int newByteLenth = bs.length - firstLineIndex - 1;
                    newStrByte = new byte[newByteLenth];
                    System.arraycopy(bs, firstLineIndex + 1, newStrByte, 0, newByteLenth);
                } else {
                    newStrByte = bs;
                }
                // prepend any carry-over from the previous pass
                if (null != tempBs && tempBs.length != 0) {
                    int tL = tempBs.length;
                    newStrByte = new byte[rSize + tL];
                    System.arraycopy(tempBs, 0, newStrByte, 0, tL);
                    System.arraycopy(bs, 0, newStrByte, tL, rSize);
                }
                // whether we have reached the configured end position
                boolean isEnd = false;
                nowCur += bufSize;
                // if we have read past the end position, truncate to it
                if (end > 0 && nowCur > end) {
                    // buffered length minus the bytes read past the end position
                    int l = newStrByte.length - (int) (nowCur - end);
                    newStrByte = substring(newStrByte, 0, l);
                    isEnd = true;
                }
                int fromIndex = 0;
                int endIndex = 0;
                // emit one line at a time, with key (default '\n') as the terminator
                while ((endIndex = indexOf(newStrByte, fromIndex)) != -1) {
                    byte[] bLine = substring(newStrByte, fromIndex, endIndex);
                    line = new String(bLine, 0, bLine.length, encode);
                    lineNum++;
                    // hand the line to the listener; processing is up to the caller
                    readerListener.outLine(line.trim(), lineNum, false);
                    fromIndex = endIndex + 1;
                }
                // keep the unfinished tail for the next pass
                tempBs = substring(newStrByte, fromIndex, newStrByte.length);
                if (isEnd) {
                    break;
                }
            }
            // emit whatever remains as the final line, flagged as the last one
            String lineStr = new String(tempBs, 0, tempBs.length, encode);
            readerListener.outLine(lineStr.trim(), lineNum, true);
        } finally {
            fileChannel.close();
            fin.deleteOnExit();
        }
    }

    /**
     * Finds the first line break in a byte[] at or after the given index.
     *
     * @param src bytes to search
     * @param fromIndex index to start from
     * @return index of the line break, or -1 if none
     * @throws Exception never thrown; kept for signature compatibility
     */
    private int indexOf(byte[] src, int fromIndex) throws Exception {
        for (int i = fromIndex; i < src.length; i++) {
            if (src[i] == key) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Copies src[fromIndex, endIndex) into a new byte[].
     *
     * @param src source bytes
     * @param fromIndex start index, inclusive
     * @param endIndex end index, exclusive
     * @return the copied slice
     * @throws Exception never thrown; kept for signature compatibility
     */
    private byte[] substring(byte[] src, int fromIndex, int endIndex) throws Exception {
        int size = endIndex - fromIndex;
        byte[] ret = new byte[size];
        System.arraycopy(src, fromIndex, ret, 0, size);
        return ret;
    }
}
The code above is the key piece: it uses FileChannel and ByteBuffer to read data from disk in batches.
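The subtlety here is that a rough byte offset (file size divided by thread count) usually lands mid-line, which is exactly what getStartNum corrects for. The same idea can be shown in a few lines with RandomAccessFile; this is a simplified variant that returns the position just past the next line break, and the names are my own:

```java
import java.io.IOException;
import java.io.RandomAccessFile;

public class SplitPoints {
    /**
     * Advances a rough byte offset to just past the next '\n', so that
     * each reader thread starts at a line boundary.
     */
    public static long alignToNextLine(RandomAccessFile raf, long roughStart) throws IOException {
        if (roughStart == 0) {
            return 0; // the first slice always starts at the beginning
        }
        raf.seek(roughStart);
        int b;
        while ((b = raf.read()) != -1) {
            if (b == '\n') {
                return raf.getFilePointer();
            }
        }
        return raf.length(); // no newline after roughStart: empty slice
    }
}
```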
Driving the batch reader from threads:
/**
 * Reads one slice of the file on its own thread.
 *
 * @author WangXueXing create at 19-3-14 18:51
 * @version 1.0.0
 */
public class ReadFileThread extends Thread {
    private ReaderFileListener processDataListeners;
    private String filePath;
    private long start;
    private long end;
    private Thread preThread;

    public ReadFileThread(ReaderFileListener processDataListeners,
                          long start, long end,
                          String file) {
        this(processDataListeners, start, end, file, null);
    }

    public ReadFileThread(ReaderFileListener processDataListeners,
                          long start, long end,
                          String file,
                          Thread preThread) {
        this.setName(this.getName() + "-ReadFileThread");
        this.start = start;
        this.end = end;
        this.filePath = file;
        this.processDataListeners = processDataListeners;
        this.preThread = preThread;
    }

    @Override
    public void run() {
        BatchReadFile readFile = new BatchReadFile();
        readFile.setReaderListener(processDataListeners);
        readFile.setEncode(processDataListeners.getEncode());
        try {
            readFile.readFileByLine(filePath, start, end + 1);
            // wait for the previous slice's thread before finishing
            if (this.preThread != null) {
                this.preThread.join();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
The read listener:
import java.util.ArrayList;
import java.util.List;

/**
 * Base listener for file reading; subclasses supply the per-batch processing.
 *
 * @author WangXueXing create at 19-3-14 18:52
 * @version 1.0.0
 */
public abstract class ReaderFileListener<T> {
    // rows per batch, 1000 by default
    private int readColNum = 1000;
    /**
     * file encoding
     */
    private String encode;
    /**
     * rows buffered for the current batch
     */
    private List<String> rowList = new ArrayList<>();
    /**
     * extra parameters for subclasses
     */
    private T otherParams;

    /**
     * Buffers each line as it is read, flushing a batch every readColNum lines
     * and once more when reading finishes.
     * @param lineStr the line just read
     * @param lineNum line number
     * @param over whether reading has finished
     * @throws Exception on processing failure
     */
    public void outLine(String lineStr, long lineNum, boolean over) throws Exception {
        if (null != lineStr && !lineStr.trim().equals("")) {
            rowList.add(lineStr);
        }
        if (over || lineNum % readColNum == 0) {
            output(rowList);
            rowList = new ArrayList<>();
        }
    }

    /**
     * Processes one batch of rows.
     *
     * @param stringList the batch to process
     * @throws Exception on processing failure
     */
    public abstract void output(List<String> stringList) throws Exception;

    /**
     * Sets the number of rows per batch.
     * @param readColNum rows per batch
     */
    protected void setReadColNum(int readColNum) {
        this.readColNum = readColNum;
    }

    public String getEncode() {
        return encode;
    }

    public void setEncode(String encode) {
        this.encode = encode;
    }

    public T getOtherParams() {
        return otherParams;
    }

    public void setOtherParams(T otherParams) {
        this.otherParams = otherParams;
    }

    public List<String> getRowList() {
        return rowList;
    }

    public void setRowList(List<String> rowList) {
        this.rowList = rowList;
    }
}
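Stripped of the encoding and generics plumbing, the batching contract of outLine/output is: buffer rows, flush every N, and flush the remainder at end of input. A self-contained sketch of just that contract (the names are mine; note it flushes on buffer size rather than on line count, which keeps batches even when blank lines are skipped):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public class LineBatcher {
    private final int batchSize;
    private final Consumer<List<String>> sink;
    private List<String> buffer = new ArrayList<>();

    public LineBatcher(int batchSize, Consumer<List<String>> sink) {
        this.batchSize = batchSize;
        this.sink = sink;
    }

    /** Buffers one line; flushes to the sink every batchSize rows and at end of input. */
    public void outLine(String line, boolean last) {
        if (line != null && !line.trim().isEmpty()) {
            buffer.add(line);
        }
        if (last || buffer.size() >= batchSize) {
            sink.accept(buffer);
            buffer = new ArrayList<>();
        }
    }
}
```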
Implementing the listener and calling the insert interface once per batch:
import com.today.api.finance.ImportServiceClient;
import com.today.api.finance.request.ImportRequest;
import com.today.api.finance.response.ImportResponse;
import com.today.api.finance.service.ImportService;
import com.today.common.Constants;
import com.today.domain.StaffSimpInfo;
import com.today.util.EmailUtil;
import com.today.util.UserSessionHelper;
import com.today.util.readfile.ReadFile;
import com.today.util.readfile.ReadFileThread;
import com.today.util.readfile.ReaderFileListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.io.File;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.FutureTask;
import java.util.stream.Collectors;

/**
 * Report import service implementation.
 *
 * @author WangXueXing create at 19-3-19 13:43
 * @version 1.0.0
 */
@Service
public class ImportReportServiceImpl extends ReaderFileListener<ImportRequest> {
    private final Logger LOGGER = LoggerFactory.getLogger(ImportReportServiceImpl.class);
    @Value("${READ_COL_NUM_ONCE}")
    private String readColNum;
    @Value("${REPORT_IMPORT_RECEIVER}")
    private String reportImportReceiver;
    /**
     * financial report import interface
     */
    private ImportService service = new ImportServiceClient();

    /**
     * Reads the uploaded file and drives the batched import.
     * @param file the file on server disk
     * @param importRequest request carrying the report type and other context
     */
    public void readTxt(File file, ImportRequest importRequest) throws Exception {
        this.setOtherParams(importRequest);
        ReadFile readFile = new ReadFile();
        try (FileInputStream fis = new FileInputStream(file)) {
            int available = fis.available();
            long maxThreadNum = 3L;
            // rough slice size per thread
            long i = available / maxThreadNum;
            this.setRowList(new ArrayList<>());
            StaffSimpInfo staffSimpInfo = ((StaffSimpInfo) UserSessionHelper.getCurrentUserInfo().getData());
            String finalReportReceiver = getEmail(staffSimpInfo.getEmail(), reportImportReceiver);
            this.setReadColNum(Integer.parseInt(readColNum));
            this.setEncode(ReadFile.CHARSET_GB2312);
            // a dedicated thread so that, when maxThreadNum > 1, all read threads are managed in one place
            new Thread(() -> {
                Thread preThread = null;
                FutureTask futureTask = null;
                try {
                    for (long j = 0; j < maxThreadNum; j++) {
                        // exact start position: the first line break after the rough slice boundary
                        long startNum = j == 0 ? 0 : readFile.getStartNum(file, i * j);
                        long endNum = j + 1 < maxThreadNum ? readFile.getStartNum(file, i * (j + 1)) : -2L;
                        // the concrete listener implementation does the per-batch work
                        preThread = new ReadFileThread(this, startNum, endNum, file.getPath(), preThread);
                        futureTask = new FutureTask(preThread, new Object());
                        futureTask.run();
                    }
                    if (futureTask.get() != null) {
                        EmailUtil.sendEmail(EmailUtil.REPORT_IMPORT_EMAIL_PREFIX, finalReportReceiver, "Report import succeeded", "Report import succeeded"); // TODO final wording pending
                    }
                } catch (Exception e) {
                    futureTask.cancel(true);
                    try {
                        EmailUtil.sendEmail(EmailUtil.REPORT_IMPORT_EMAIL_PREFIX, finalReportReceiver, "Report import failed", e.getMessage());
                    } catch (Exception e1) {
                        // ignore
                        LOGGER.error("Failed to send email", e1);
                    }
                    LOGGER.error("Import of report type " + importRequest.getReportType() + " failed", e);
                } finally {
                    futureTask.cancel(true);
                }
            }).start();
        }
    }

    private String getEmail(String infoEmail, String reportImportReceiver) {
        if (StringUtils.isEmpty(infoEmail)) {
            return reportImportReceiver;
        }
        return infoEmail;
    }

    /**
     * Calls the import interface once per batch.
     * @param stringList one batch of rows
     * @throws Exception if the import interface reports failure
     */
    @Override
    public void output(List<String> stringList) throws Exception {
        ImportRequest importRequest = this.getOtherParams();
        List<List<String>> dataList = stringList.stream()
                .map(x -> Arrays.stream(x.split(ReadFile.SEPARATOR_COMMA)).map(String::trim).collect(Collectors.toList()))
                .collect(Collectors.toList());
        LOGGER.info("Uploading batch: {}", dataList);
        importRequest.setDataList(dataList);
        ImportResponse importResponse = service.batchImport(importRequest);
        LOGGER.info("Import response code: {}", importResponse.getCode());
        // on failure, log the details and abort
        if (!Constants.SUCESS_CODE.equals(importResponse.getCode())) {
            String msg = "Import of report type " + importRequest.getReportType()
                    + " failed, code: " + importResponse.getCode()
                    + ", message: " + importResponse.getMessage();
            LOGGER.error(msg);
            throw new RuntimeException(msg);
        }
        importRequest.setData(importResponse.data);
    }
}
Note:
The line

long maxThreadNum = 3L;

in ImportReportServiceImpl sets the number of threads that read the file from disk in batches. I set it to 3; don't set it too high, or multiple threads reading into memory at once will also overflow server memory.
Once all batches have been read and every call to the insert interface has succeeded, a success email is sent to the importer; if any batch fails, a failure email is sent immediately.
Batch-inserting the data into the database:
/**
 * Batch-inserts offline third-party bill rows.
 * @param dataList rows to insert
 */
def insertNonOnlinePayment(dataList: List[NonOnlineSourceData]): Unit = {
  if (dataList.nonEmpty) {
    CheckAccountDataSource.mysqlData.withConnection { conn =>
      val sql =
        s"""INSERT INTO t_pay_source_data
            (store_code,
             store_name,
             source_date,
             order_type,
             trade_type,
             third_party_payment_no,
             business_type,
             business_amount,
             trade_time,
             created_at,
             updated_at)
            VALUES (?,?,?,?,?,?,?,?,?,NOW(),NOW())"""
      conn.setAutoCommit(false)
      val stmt = conn.prepareStatement(sql)
      var i = 0
      dataList.foreach { x =>
        stmt.setString(1, x.storeCode)
        stmt.setString(2, x.storeName)
        stmt.setString(3, x.sourceDate)
        stmt.setInt(4, x.orderType)
        stmt.setInt(5, x.tradeType)
        stmt.setString(6, x.tradeNo)
        stmt.setInt(7, x.businessType)
        stmt.setBigDecimal(8, x.businessAmount.underlying())
        stmt.setString(9, x.tradeTime.getOrElse(null))
        stmt.addBatch()
        i += 1
        // commit every 5000 rows; executeBatch clears the batch, so the statement can be reused
        if (i % 5000 == 0) {
          stmt.executeBatch()
          conn.commit()
        }
      }
      stmt.executeBatch()
      conn.commit()
    }
  }
}
The code above commits a batch insert every 5000 rows, which keeps any single commit from putting too much pressure on the database.
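The commit cadence generalizes: split the rows into commit-sized chunks and run one executeBatch/commit per chunk. The partitioning itself is easy to isolate and test (a generic helper of my own, not from the article):

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkInsert {
    /**
     * Splits rows into commit-sized chunks, e.g. 5000 rows per transaction;
     * each chunk then maps to one executeBatch() followed by one commit().
     */
    public static <T> List<List<T>> chunks(List<T> rows, int chunkSize) {
        List<List<T>> out = new ArrayList<>();
        for (int i = 0; i < rows.size(); i += chunkSize) {
            out.add(rows.subList(i, Math.min(i + chunkSize, rows.size())));
        }
        return out;
    }
}
```

With MySQL, adding `rewriteBatchedStatements=true` to the JDBC URL can make such batches markedly faster, since the driver rewrites them into multi-row INSERT statements.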
That's all; if you have a better approach, please leave a comment.