tensorflow+java 内存泄漏修复

前段时间java程序,内存泄漏比较严重,平均3-5天就要重启一下,赶快分析原因。从公司的监控看到,主要是对外内存泄漏,因为堆内存泄漏不是很严重。所以决定优先处理前者。因为该项目是java开发的,主要任务时加载tensorflow1.*的模型,并实时预测。其实主要JNI调用c++接口,所以很大可能是在接口调用时泄漏了,看代码:

	Tensor res =null;
    	try {
			List<String> colname=IntColname;
			Runner rlt = sess.runner();
			for(int i=0; i<intvalue[0].length; i++) {
				int[][] index=new int[intvalue.length][1];
				for(int m=0; m<intvalue.length; m++) {
					index[m][0]=intvalue[m][i];		    		
				}
				Tensor indexTensor = Tensor.create(index); 
				rlt.feed(colname.get(i), indexTensor);
			}
			
			colname=FloatColname;
			
			for(int i=0; i<floatvalue[0].length; i++) {
				float[][] index=new float[floatvalue.length][1];
				for(int m=0; m<floatvalue.length; m++) {
					index[m][0]=floatvalue[m][i];	   		
				}
				Tensor indexTensor = Tensor.create(index); 
				rlt.feed(colname.get(i), indexTensor);
				temp.add(indexTensor);
			}
			 	
		    res=rlt.fetch("output").run().get(0);
			float[][] finalRlt = new float[intvalue.length][1];
			res.copyTo(finalRlt);
			
			List<Float> result=new ArrayList<Float>();
			for(int i=0; i<finalRlt.length; i++) {
				result.add(finalRlt[i][0]);
			}
			return result;
		} catch (Exception e) {
			logger.error("",e);
		}finally {
			if (res != null) {
                 res.close();
            }
		}


虽然res调用了close方法,但是 indexTensor 却没有调用,因此调整代码。

	Tensor res =null;
    	List<Tensor> temp=new ArrayList<>();
    	try {
			List<String> colname=IntColname;
			Runner rlt = sess.runner();
			for(int i=0; i<intvalue[0].length; i++) {
				int[][] index=new int[intvalue.length][1];
				for(int m=0; m<intvalue.length; m++) {
					index[m][0]=intvalue[m][i];		    		
				}
				Tensor indexTensor = Tensor.create(index); 
				rlt.feed(colname.get(i), indexTensor);
				temp.add(indexTensor);
			}
			
			colname=FloatColname;
			
			for(int i=0; i<floatvalue[0].length; i++) {
				float[][] index=new float[floatvalue.length][1];
				for(int m=0; m<floatvalue.length; m++) {
					index[m][0]=floatvalue[m][i];	   		
				}
				Tensor indexTensor = Tensor.create(index); 
				rlt.feed(colname.get(i), indexTensor);
				temp.add(indexTensor);
			}
			 	
		    res=rlt.fetch("output").run().get(0);
			float[][] finalRlt = new float[intvalue.length][1];
			res.copyTo(finalRlt);
			
			List<Float> result=new ArrayList<Float>();
			for(int i=0; i<finalRlt.length; i++) {
				result.add(finalRlt[i][0]);
			}
			return result;
		} catch (Exception e) {
			logger.error("",e);
		}finally {
			if (res != null) {
            	res.close();
            }
			for(int i=0; i<temp.size(); i++) {
				Tensor t=temp.get(i);
				if(t!=null) {
					t.close();
				}
			}
		}


主要是增加temp队列回收临时变量。上线后发现泄漏不是很严重了,说明很有效果。不过还有,后面是java堆内存泄漏,发现每次泄漏的时间正是模型切换的时间,因此大概率是模型切换的代码有问题,上代码。

public static void clearCache() {
		logger.info("===model clear===");
		sess=null;
	}

这里的session就是模型解析好的session会话,由下面的代码生成。

byte[] graphBytes = IOUtils.toByteArray(inputStream);
Graph graph=new Graph();
graph.importGraphDef(graphBytes);    
sess = new Session(graph);

立刻明白,上面的代码有问题,调整为如下:

public static void clearCache() {
		logger.info("===model clear===");
		if(sess!=null) {
			logger.info("destory session");
			sess.close();
			sess=null;
		}
	}

少了一个close,造成session关联的对象无法释放,至此内存泄漏问题算是解决了。

上一篇:tensorflow 迭代周期长,每个epoch时间变慢


下一篇:TensorFlow笔记-变量,图,会话