Posted on 2009-09-30 18:41
近似凯珊卓 阅读(460) 评论(0) 编辑 收藏 所属分类：学习笔记
神经网络不会举一反三，郁闷……
//////////////////////////////////////////////////////////////////////////
// File Name: pnn.cpp
// Author: Ruoruo(du#in.tum.de)
//////////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "cv.h"
#include "highgui.h"
#include <ml.h>
#include <ctype.h>
#include <math.h>
#include <stdio.h>
#include <time.h>
#include <iostream>
#include <vector>
using namespace std;
15
16 static CvScalar colors[] =
17 {
18 {{0,0,255}},
19 {{0,128,255}},
20 {{0,255,255}},
21 {{0,255,0}},
22 {{255,128,0}},
23 {{255,255,0}},
24 {{255,0,0}},
25 {{255,0,255}}
26 };
27
28 int main( int argc, char** argv )
29 {
30 vector<float> point;
31 vector<float> result;
32
33 float p[10] = { 1.3, 2.7,
34 1.5, 3.0,
35 1.7, 2.8,
36 1.6, 2.6,
37 1.2, 2.9 };
38 float res[5] = { 0,0,0,0,0 };
39 int i;
40 for(i=0; i<10; i++)
41 {
42 point.push_back(p[i]);
43 if(i<5) result.push_back(res[i]);
44 }
45
46 CvMat* input = cvCreateMat( 5, 2, CV_32FC1 );
47 cvInitMatHeader( input, 5, 2, CV_32FC1, p );
48 CvMat* output = cvCreateMat( 5, 1, CV_32FC1 );
49 cvInitMatHeader( output, 5, 1, CV_32FC1, res );
50 IplImage* img = cvCreateImage(cvSize(450, 450), IPL_DEPTH_8U, 3);
51 img->origin = 1;
52 for(i= 0; i<5; i++)
53 {
54 cvCircle(img, cvPoint((int)(p[i*2]*100), (int)(p[i*2+1]* 100)), 5, colors[(int)res[i]%8], 1, CV_AA, 0);
55 }
56
57 int layer_num[3] = { 2, 4, 1 };
58 CvMat* layer_size = cvCreateMatHeader( 1, 3, CV_32S );
59 cvInitMatHeader( layer_size, 1, 3, CV_32S, layer_num );
60 CvANN_MLP pnn;
61 pnn.create( layer_size, CvANN_MLP::SIGMOID_SYM, 1, 1 );
62 CvANN_MLP_TrainParams params;
63 params.term_crit = cvTermCriteria( CV_TERMCRIT_ITER | CV_TERMCRIT_EPS, \
64 300, 0.0000001 );
65 params.train_method = 0;
66 params.bp_dw_scale = 0.1;
67 params.bp_moment_scale = 0.1;
68 cout<<"begin training"<<endl;
69 pnn.train( input, output, 0, 0, params );
70 cout<<"end training"<<endl;
71 pnn.save( "pNN_DATA.xml" );
72
73 //begin to test
74 float testp[24] = { 1.4, 2.75,
75 4.3, 0.2,
76 4.2, 4.3,
77 1.45, 2.85,
78 4.2, 0.4,
79 4.1, 4.0,
80 4.3, 0.5,
81 4.0, 4.2,
82 1.5, 2.7,
83 4.1, 4.2,
84 1.6, 2.7,
85 4.0, 0.3 };
86 /*float testp[24] = { 1.4, 2.75,
87 4.3, 0.2,
88 4.2, 4.3,
89 1.45, 2.85,
90 4.25, 0.3,
91 4.25, 4.25,
92 3.5, 1.2,
93 3.0, 3.7,
94 4.0, 2.7,
95 0.2, 0.2,
96 2.8, 2.7,
97 2.7, 2.8 };*/
98 CvMat* test_point = cvCreateMat( 1, 2, CV_32FC1 );
99 CvMat* test_result = cvCreateMat( 1, 1, CV_32FC1 );
100 CvFont font;
101 double hScale=0.5;
102 double vScale=0.5;
103 int lineWidth=1;
104 cvInitFont(&font, CV_FONT_HERSHEY_COMPLEX|CV_FONT_ITALIC, hScale,vScale,0,lineWidth);
105
106 for(i= 0; i<12; i++)
107 {
108 cvSetReal2D( test_point, 0, 0, testp[2*i] );
109 cvSetReal2D( test_point, 0, 1, testp[2*i+1] );
110 pnn.predict(test_point, test_result);
111 cout<<cvmGet(test_result,0,0)<<endl;
112
113 float delta = 1;
114 int best_class = 0;
115 int max_class = 0;
116 for(int ii=0; ii<result.size(); ii++){
117 if(fabs(cvmGet(test_result,0,0) - (float)result[ii])<delta){
118 delta = fabs(cvmGet(test_result,0,0) - (float)result[ii]);
119 best_class = result[ii];
120 }
121 if(result[ii]>=max_class)
122 max_class = result[ii];
123 }
124
125 point.push_back(testp[2*i]);
126 point.push_back(testp[2*i+1]);
127
128 if( delta>0.06 ){
129 int new_result = max_class+1;
130 cvmSet( test_result,0,0,new_result );
131 result.push_back((float)new_result );
132 }
133 else{
134 cvmSet( test_result,0,0,best_class );
135 result.push_back((float)best_class );
136 }
137
138 int new_point_size = point.size();
139 int new_result_size = result.size();
140
141 CvMat* input = cvCreateMat( new_result_size, 2, CV_32FC1 );
142 CvMat* output = cvCreateMat( new_result_size, 1, CV_32FC1 );
143
144 for(int ii=0; ii<new_result_size; ii++)
145 {
146 cvmSet( input, ii, 0, point[2*ii]);
147 cvmSet( input, ii, 1, point[2*ii+1]);
148 cvmSet( output, ii, 0, result[ii]);
149 }
150 //cout<<"begin training again"<<endl;
151 pnn.train( input, output, 0, 0, params );
152 //cout<<"end training"<<endl;
153
154 cvCircle( img, cvPoint((int)(testp[i*2]*100), (int)(testp[i*2+1]* 100)), 0, colors[(int)cvmGet(test_result,0,0)%8], 10, CV_AA, 0 );
155
156 char buffer[10];
157 _itoa(i,buffer,10);
158 string point_id(buffer);
159 cvPutText(img, point_id.c_str(), cvPoint(testp[2*i]*100,testp[2*i+1]*100), &font, cvScalar(255,255,255));
160
161 cout<<i<<": "<<"("<<testp[i*2]<<", "<<testp[i*2+1]<<")"<<"\t"<<cvmGet(test_result,0,0)<<endl;
162 }
163
164 cvNamedWindow( "Coordinates" , 1 );
165 cvShowImage( "Coordinates" ,img);
166
167 cvWaitKey( 0 );
168
169 cvDestroyWindow("Coordinates");
170 cvReleaseImage(&img);
171
172 return 0;
173 }