`
slowman
  • 浏览: 37737 次
  • 性别: Icon_minigender_1
  • 来自: 武汉
社区版块
存档分类
最新评论

用JAVA进行神经网络建模及泛化能力测试

    博客分类:
  • AI
阅读更多

 

作者:桂子山下一棵草   email: slowguy@qq.com 

 

 

题目:

                     表一澳大利亚野兔眼睛晶状体重量与年龄的对应关系

 

 

编号

年龄()

重量(mg)

年龄()

重量(mg)

年龄()

重量(mg)

年龄()

重量(mg)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

15

15

15

18

28

29

37

37

44

50

50

60

61

64

65

65

72

75

21.66

22.75

22.3

31.25

44.79

40.55

50.25

46.88

52.03

63.47

61.13

81

73.09

79.09

79.51

65.31

71.9

86.1

75

82

85

91

91

97

98

125

142

142

147

147

150

159

165

183

192

195

94.6

92.5

105

101.7

102.9

110

104.3

134.9

130.68

140.58

155.3

152.2

144.5

142.15

139.81

153.22

145.72

161.1

218

218

219

224

225

227

232

232

237

246

258

276

285

300

301

305

312

317

174.18

173.03

173.54

178.86

177.68

173.73

159.98

161.29

187.07

176.13

183.4

186.26

189.66

186.09

186.7

186.8

195.1

216.41

338

347

354

357

375

394

513

535

554

591

648

660

705

723

756

768

860

203.23

188.38

189.7

195.31

202.63

224.82

203.3

209.7

233.9

234.7

244.3

231

242.4

230.77

242.57

232.12

246.7

 

澳大利亚野兔眼睛晶状体的重量为年龄的函数。利用BP算法,设计一个多层感知器,为表中的数据集提供一个非线性逼近,并测试其泛化能力。

算法源码:

 

package com.lwm.cn.althom;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.Random;

public class BackProp {
	private int randomPrecision = 8; // 生成double型随机数的精度,默认为6位小数

	private int input_dimension; // 输入向量的维度

	private int output_dimension; // 输出向量的维数

	private int mid_dimension; // 隐层结点的个数

	private double[][] V; // 输入层到隐层的权值矩阵

	private double[][] W; // 隐层到输出层的权值矩阵

	private double[] inputArray; // 输入层向量

	private double[] midArray; // 隐层输出向量

	private double[] outputArray; // 输出层向量

	private double[] teacherArray; // 期望层向量

	private double mid_Threshold; // 隐层阈值

	private double out_Threshold; // 输出层阈值

	private double[] midError; // 隐层的误差

	private double[] outError; // 输出层的误差

	private double totalError = 0.0;

	private double outPrecision; // 要达到的精度

	private double learnRate; // 学习的速率

	private int trainTotal = 3000; // 学习1000次

	private boolean isQualify = false; // 用于判断是不是达到精度要求

	private ArrayList<SampleNode> trainArray = new ArrayList<SampleNode>(100); // 存入训练集

	private ArrayList<SampleNode> testArray = new ArrayList<SampleNode>(100); // 存放测试集

	private BufferedWriter bw = null; // 用于将学习和测试过程写于文件

	Date startTime;

	// SampleNode sample;

	// Math.random()
	public BackProp(double[][] v, double[][] w, int input_dimension,
			int output_dimension, int mid_dimension) {
		super();
		V = v;
		W = w;
		this.input_dimension = input_dimension;
		this.output_dimension = output_dimension;
		this.mid_dimension = mid_dimension;
	}

	/**
	 * 默认构造函数 ,对于本次实验,输入向量只有一个,输出也只有一个. 隐层结点的个数默认为4
	 * 
	 */
	public BackProp() {
		input_dimension = 1;
		output_dimension = 1;
		mid_dimension = 8;

		inputArray = new double[input_dimension];
		teacherArray = new double[output_dimension];
		midArray = new double[mid_dimension];
		outputArray = new double[output_dimension];

		V = new double[input_dimension][mid_dimension];
		W = new double[mid_dimension][input_dimension];

		midError = new double[mid_dimension];
		outError = new double[output_dimension];
	}

	/**
	 * 初始化函数 我认为一个完整的BP算法应该具备通用性,可以任意设置输入结点个数和隐层的层数及每一层的结点个数
	 * 初始化权值矩阵V和W,每个元素的值均为0-1之间的六位小数
	 */

	public void init()
	{
		// 记录程序开始时间及结束时间,以开始时间命名一个文件,用来保存学习和测试结果.
		startTime = new Date();
		SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
		String timeStr = sdf.format(startTime);
		String filePathName = "E:" + File.separator + timeStr + ".txt";
		try
		{
			bw = new BufferedWriter(new FileWriter(filePathName));
			bw.write("程序开始时间:" + timeStr + "\n");
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		mid_Threshold = MathExtend.round(Math.random(), randomPrecision); // 初始化隐层的阈值
		out_Threshold = MathExtend.round(Math.random(), randomPrecision); // 初始化输出层的阈值
		// 初始化V矩阵
		for (int i = 0; i < input_dimension; i++)
			for (int j = 0; j < mid_dimension; j++)
				V[i][j] = MathExtend.round(Math.random(), randomPrecision);

		// 初始化W矩阵
		for (int i = 0; i < mid_dimension; i++)
			for (int j = 0; j < output_dimension; j++)
				W[i][j] = MathExtend.round(Math.random(), randomPrecision);

		// 置总的误差为0,学习率为0-1之间的小数,网络训练后达到的精度为一正小数
		totalError = 0.0;
		learnRate = MathExtend.round(Math.random(), randomPrecision);
		// learnRate = 0.12;
		outPrecision = MathExtend.round(Math.random(), randomPrecision);

		try
		{
			StringBuilder sb = new StringBuilder();
			sb.append("本次实验随机生成的学习率: " + learnRate);
			sb.append("\n");
			sb.append("期望达到的精度为: " + outPrecision);
			sb.append("\n");
			bw.write(sb.toString());
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

		getTrainData(); // 取得训练集
		getTestData(); // 取得测试集
		normalized(); // 归一化
	}

	/**
	 * @author Administrator 输入层向隐层,隐层向输出层的传播
	 * 
	 */
	public void finish()
	{
		// Date endDate = new Date() ;

		try
		{
			bw.close();
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	public void forword()
	{
		int i, j;
		double temp_sum ; // 用于向量的内积
		// 输出层到隐层
		for (i = 0; i < mid_dimension; i++)
		{
			temp_sum = 0.0  ; //初始化为0
			for (j = 0; j < input_dimension; j++)
				temp_sum += V[j][i] * inputArray[j];
			temp_sum = temp_sum - mid_Threshold;
			midArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
		}

		
		// 隐层到输出层
		for (i = 0; i < output_dimension; i++)
		{
			temp_sum = 0.0; // 初始化
			for (j = 0; j < mid_dimension; j++)
				temp_sum = W[j][i] * midArray[j];
			temp_sum = temp_sum - out_Threshold;
			outputArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
		}
		// 计算误差,累加起来,
		temp_sum = 0.0;
		for (i = 0; i < output_dimension; i++)
		{
			temp_sum = teacherArray[i] - outputArray[i]; // 注意中,本设计中output_dimension=1的
			totalError += temp_sum * temp_sum / 2;
		}
		// printResult();
	}

	private void printResult()
	{
		/*
		 * StringBuilder sb = new StringBuilder() ;
		 * sb.append("输入数据:"+inputArray[0]); sb.append("
		 * 实际输出数据:"+outputArray[0]); sb.append(" 期望输出数据为:"+teacherArray[0]) ;
		 * sb.append("\\n") ; try { bw.write(sb.toString()); } catch
		 * (IOException e) { // TODO Auto-generated catch block
		 * e.printStackTrace(); }
		 */
		System.out.print("输入数据:" + inputArray[0]);
		System.out.print("   实际输出数据:" + outputArray[0]);
		System.out.println("   期望输出数据为:" + teacherArray[0]);
	}

	/**
	 * 反向调整权值矩阵
	 */
	public void adjustWeight()
	{
		double temp_sum = 0.0;
		int i, j;
		// 计算各层的误差信号  输出层
		for (i = 0; i < output_dimension; i++)
		{
			outError[i] = (teacherArray[i] - outputArray[i])
					* (1 - outputArray[i]) * outputArray[i];
		}
//    隐层误差
		for (i = 0; i < mid_dimension; i++)
		{
			temp_sum=0.0d ;
			for (j = 0; j < output_dimension; j++)
				temp_sum += outError[j] * W[i][j];
			midError[i] = temp_sum * (1 - midArray[i]) * midArray[i];
		}

		// 调整W权值矩阵
		for (i = 0; i < mid_dimension; i++)
		{
			for (j = 0; j < output_dimension; j++)
				W[i][j] += learnRate * outError[j] * midArray[i];
		}
		// 调整V权值矩阵

		for (i = 0; i < input_dimension; i++)
			for (j = 0; j < mid_dimension; j++)
				V[i][j] += learnRate * midError[j] * inputArray[i];

	}

	public void getTrainData()
	{
		String filePathName = "E:" + File.separator + "traindata.txt";
		BufferedReader br = null;
		try
		{
			br = new BufferedReader(new FileReader(filePathName));
		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		String s = null;
		SampleNode sNode = null;
		try
		{
			while ((s = br.readLine()) != null)
			{
				String data[] = s.trim().split("[\\s]+");
				if (data == null || data.length != 2)
				{

					System.out.println("traindata文件数据有问题!");
					return;
				}
				double in = Double.parseDouble(data[0]);
				double hope = Double.parseDouble(data[1]);
				sNode = new SampleNode(in, hope);
				trainArray.add(sNode);

				// trainArray.
			}
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		trainArray.trimToSize();
	}

	public void getTestData()
	{
		String fileName = "E:" + File.separator + "testdata.txt";
		BufferedReader br = null;
		try
		{
			br = new BufferedReader(new FileReader(fileName));
		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			System.out.println("testdata.txt文件不存在");
			e.printStackTrace();
		}

		String s = null;
		SampleNode sNode = null;
		try
		{
			while ((s = br.readLine()) != null)
			{
				String data[] = s.trim().split("[\\s]+");
				if (data == null || data.length != 2)
				{

					System.out.println("testdata文件数据有问题!");
					return;
				}
				double in = Double.parseDouble(data[0]);
				double hope = Double.parseDouble(data[1]);
				sNode = new SampleNode(in, hope);
				testArray.add(sNode);

				// trainArray.
			}
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		testArray.trimToSize();

	}

	/**
	 * 对输入数据进行归一化处理,将输入数据限制在[0,1]区间内
	 * 
	 */

	private void normalized()
	{
		if (trainArray == null || trainArray.size() == 0 || testArray == null
				|| testArray.size() == 0)
		{
			System.out.println("测试数据或者训练数据有问题!");
			return;
		}
		SampleNode sNode = null;
		// 训练数据归一化
		int size = trainArray.size();
		int i = 0;
		while (i < size)
		{
			sNode = trainArray.get(i);
			double in = sNode.in;
			double hope = sNode.hope;
			in /= 1000.0; // 归一
			hope /= 250.0;
			sNode.in = in;
			sNode.hope = hope;
			trainArray.set(i, sNode);
			i++;
		}

		size = testArray.size();
		i = 0;
		// 测试数据归一化
		while (i < size)
		{
			sNode = testArray.get(i);
			double in = sNode.in;
			double hope = sNode.hope;
			in /= 1000.0; // 归一
			hope /= 250.0;
			sNode.in = in;
			sNode.hope = hope;
			trainArray.set(i, sNode);
			i++;
		}

	}

	public void startTrain()
	{
		if (trainArray == null || trainArray.size() == 0)
			return;
		System.out.println("训练开始");
		System.out.println("当前学习速率:" + learnRate);
		System.out.println("期望精度为:" + outPrecision);
		int trainConunter = 0;
		while (trainConunter++ < trainTotal)
		{
			System.out.println("第" + trainConunter + "次训练开始:");
			for (SampleNode sNode : trainArray)
			{
				// 说明:在本设计中inputArray,和teacherArray虽然都是数组,但均只有一个元素.
				// 本人为了综合虑,才将设为数组的.
				inputArray[0] = sNode.in;
				teacherArray[0] = sNode.hope;
				forword(); // 学习一次
				printResult();
			} // 至此,所有训练集全部学习完毕,下面应该进行权值调整.
			/*System.out.println("此次学习后,总的误差为:" + totalError);
			StringBuilder sb = new StringBuilder();
			sb.append("第" + trainConunter);
			sb.append("次学习后,总的误差为:" + totalError);
			sb.append("\n");*/
			try
			{
			//	bw.write(sb.toString());
				bw.write(Double.toString(totalError)+"\n") ;
			} catch (IOException e)
			{
				e.printStackTrace();
			}
			adjustWeight(); // 集体主义原则来调整权值

			if (totalError <= outPrecision)
			{
				isQualify = true; // 置标志位为真,表示达到要求

				break;
			}
			totalError = 0.0; // 误差初化
		}

		Date endTime = new Date();
		SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
		String endtimeStr = sdf.format(endTime);
		long gap = endTime.getTime() - this.startTime.getTime();
		StringBuilder sb = new StringBuilder();
		try
		{
			sb.append("训练结束时间为:" + endtimeStr);
			sb.append("\n");
			sb.append("总的学习时间为:" + gap);
			sb.append("微秒\n");
			sb.append("********************************************\n");
			bw.write(sb.toString());
		} catch (IOException e1)
		{
			// TODO Auto-generated catch block
			e1.printStackTrace();
		}

		if (!isQualify)
		{
			System.out.println("达到训练次数,训练结束!");
			try
			{
				bw.write("训练次数:" + trainTotal + "次\n");
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		} else
		{
			try
			{
				bw.write("达到精度要求,学习完毕!\n");
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			System.out.println("达到要求的精度,训练结束!");
		}
	}

	public void startTest()
	{

		if (testArray == null || testArray.isEmpty() == true)
			return;
		
		for (SampleNode sNode : testArray)
		{
			StringBuilder sb = new StringBuilder();
			inputArray[0] = sNode.in;
			teacherArray[0] = sNode.hope;
			forword();
			sb.append("输入测试数据: " + inputArray[0]);
			sb.append("   实际输出:" + outputArray[0]);
			sb.append("   期望输出:" + teacherArray[0]);
			sb.append("\n");
			try
			{
				bw.write(sb.toString());
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			printResult();
		}
	}

}

 

 

 

测试输出结果如下图:

 

 

 

程序运行一次的收敛图如下图:



 

<!--EndFragment-->

  • 大小: 187.8 KB
  • 大小: 3.8 KB
1
0
分享到:
评论
2 楼 slowman 2012-01-18  
是的,不过这只是实现了算法的最基本要求。
1 楼 loveq369 2011-04-11  
自己写的吗?厉害哦。

相关推荐

Global site tag (gtag.js) - Google Analytics