写个神经网络的小程序

Posted on 2009-09-30 18:41 近似凯珊卓 阅读(461) 评论(0)  编辑  收藏 所属分类: 学习笔记
神经网络不会举一反三，郁闷……
  1 //////////////////////////////////////////////////////////////////////////
  2 // File Name: pnn.cpp
  3 // Author:    Ruoruo(du#in.tum.de)
  4 //////////////////////////////////////////////////////////////////////////
  5 #include "stdafx.h"
  6 #include "cv.h"
  7 #include "highgui.h"
  8 #include <ml.h>
  9 #include <time.h>
 10 #include <ctype.h>
 11 #include <vector>
 12 #include <math.h> 
 13 #include <iostream>
 14 using namespace std;
 15 
 16 static CvScalar colors[] = 
 17     {
 18         {{0,0,255}},
 19         {{0,128,255}},
 20         {{0,255,255}},
 21         {{0,255,0}},
 22         {{255,128,0}},
 23         {{255,255,0}},
 24         {{255,0,0}},
 25         {{255,0,255}}
 26     };
 27 
 28 int main( int argc, char** argv )
 29 {
 30     vector<float> point;
 31     vector<float> result;
 32 
 33     float p[10= { 1.32.7,
 34                     1.53.0,
 35                     1.72.8,
 36                     1.62.6,
 37                     1.22.9 };
 38     float res[5= { 0,0,0,0,0 };
 39     int i;
 40     for(i=0; i<10; i++)
 41     {
 42         point.push_back(p[i]);
 43         if(i<5) result.push_back(res[i]);
 44     }
 45 
 46     CvMat* input = cvCreateMat( 52, CV_32FC1 );
 47     cvInitMatHeader( input, 52, CV_32FC1, p );
 48     CvMat* output = cvCreateMat( 51, CV_32FC1 );
 49     cvInitMatHeader( output, 51, CV_32FC1, res );
 50     IplImage* img = cvCreateImage(cvSize(450450), IPL_DEPTH_8U, 3);
 51     img->origin = 1;
 52     for(i= 0; i<5; i++)
 53     {
 54         cvCircle(img, cvPoint((int)(p[i*2]*100), (int)(p[i*2+1]* 100)), 5, colors[(int)res[i]%8], 1, CV_AA, 0);
 55     }
 56     
 57     int layer_num[3= { 241 };
 58     CvMat* layer_size = cvCreateMatHeader( 13, CV_32S );
 59     cvInitMatHeader( layer_size, 13, CV_32S, layer_num );
 60     CvANN_MLP pnn;
 61     pnn.create( layer_size, CvANN_MLP::SIGMOID_SYM, 11 );
 62     CvANN_MLP_TrainParams params;
 63     params.term_crit = cvTermCriteria( CV_TERMCRIT_ITER | CV_TERMCRIT_EPS, \
 64         3000.0000001 );
 65     params.train_method = 0;
 66     params.bp_dw_scale = 0.1;
 67     params.bp_moment_scale = 0.1;
 68     cout<<"begin training"<<endl;
 69     pnn.train( input, output, 00params );
 70     cout<<"end training"<<endl;
 71     pnn.save( "pNN_DATA.xml" );
 72 
 73     //begin to test
 74     float testp[24= { 1.42.75,
 75                         4.30.2
 76                         4.24.3,
 77                         1.452.85,
 78                         4.20.4,
 79                         4.14.0,
 80                         4.30.5,
 81                         4.04.2,
 82                         1.52.7
 83                         4.14.2
 84                         1.62.7,
 85                         4.00.3 };
 86     /*float testp[24] = { 1.4, 2.75,
 87                         4.3, 0.2, 
 88                         4.2, 4.3,
 89                         1.45, 2.85,
 90                         4.25, 0.3,
 91                         4.25, 4.25,
 92                         3.5, 1.2,
 93                         3.0, 3.7,
 94                         4.0, 2.7, 
 95                         0.2, 0.2, 
 96                         2.8, 2.7,
 97                         2.7, 2.8 };*/
 98     CvMat* test_point = cvCreateMat( 12, CV_32FC1 );    
 99     CvMat* test_result = cvCreateMat( 11, CV_32FC1 );
100     CvFont font;
101     double hScale=0.5;
102     double vScale=0.5;
103     int lineWidth=1;
104     cvInitFont(&font, CV_FONT_HERSHEY_COMPLEX|CV_FONT_ITALIC, hScale,vScale,0,lineWidth);
105 
106     for(i= 0; i<12; i++)
107     {
108         cvSetReal2D( test_point, 00, testp[2*i] );
109         cvSetReal2D( test_point, 01, testp[2*i+1] );
110         pnn.predict(test_point, test_result);
111         cout<<cvmGet(test_result,0,0)<<endl;
112 
113         float delta = 1;
114         int best_class = 0;
115         int max_class = 0;
116         for(int ii=0; ii<result.size(); ii++){
117             if(fabs(cvmGet(test_result,0,0- (float)result[ii])<delta){
118                 delta = fabs(cvmGet(test_result,0,0- (float)result[ii]);
119                 best_class = result[ii];
120             }
121             if(result[ii]>=max_class)
122                 max_class = result[ii];
123         }
124 
125         point.push_back(testp[2*i]);
126         point.push_back(testp[2*i+1]);
127 
128         if( delta>0.06 ){
129             int new_result = max_class+1;
130             cvmSet( test_result,0,0,new_result );
131             result.push_back((float)new_result );
132         }
133         else{
134             cvmSet( test_result,0,0,best_class );
135             result.push_back((float)best_class );
136         }
137 
138         int new_point_size = point.size();
139         int new_result_size = result.size();
140             
141         CvMat* input = cvCreateMat( new_result_size, 2, CV_32FC1 );
142         CvMat* output = cvCreateMat( new_result_size, 1, CV_32FC1 );
143 
144         for(int ii=0; ii<new_result_size; ii++)
145         {
146             cvmSet( input, ii, 0, point[2*ii]);
147             cvmSet( input, ii, 1, point[2*ii+1]);
148             cvmSet( output, ii, 0, result[ii]);
149         }
150         //cout<<"begin training again"<<endl;
151         pnn.train( input, output, 00params );
152         //cout<<"end training"<<endl;
153 
154         cvCircle( img, cvPoint((int)(testp[i*2]*100), (int)(testp[i*2+1]* 100)), 0, colors[(int)cvmGet(test_result,0,0)%8], 10, CV_AA, 0 );
155 
156         char buffer[10];
157         _itoa(i,buffer,10);
158         string point_id(buffer);
159         cvPutText(img, point_id.c_str(), cvPoint(testp[2*i]*100,testp[2*i+1]*100), &font, cvScalar(255,255,255));
160 
161         cout<<i<<""<<"("<<testp[i*2]<<""<<testp[i*2+1]<<")"<<"\t"<<cvmGet(test_result,0,0)<<endl;
162     }
163 
164     cvNamedWindow( "Coordinates" , 1 ); 
165     cvShowImage( "Coordinates" ,img);
166 
167     cvWaitKey( 0 );
168 
169     cvDestroyWindow("Coordinates");
170     cvReleaseImage(&img);
171 
172     return 0;
173 }

只有注册用户登录后才能发表评论。


网站导航:
博客园   IT新闻   Chat2DB   C++博客   博问  
 

posts - 9, comments - 0, trackbacks - 0, articles - 0

Copyright © 近似凯珊卓