随笔-23  评论-58  文章-0  trackbacks-0
Wikipedia 上有一个 Java 版的 Viterbi(维特比)算法实现(http://en.wikipedia.org/wiki/Viterbi_algorithm),但对于长度为 3 的观察序列,它会标注出 4 个状态。
下面是本人写的 Viterbi(维特比)实现,没有这个问题:长度为 3 的观察序列只会标注出 3 个状态。
/**
 * Viterbi decoder for a discrete hidden Markov model.
 *
 * <p>Given start, transition and emission probabilities, it finds the single
 * most probable hidden-state sequence for an observation sequence and prints
 * exactly one state per observation.
 */
public class Viterbi
{
    /**
     * Runs the decoder on a small Rainy/Sunny example with the observations
     * walk, shop, clean.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        String[] states = {"Rainy", "Sunny"};
        String[] observations = {"walk", "shop", "clean"};
        double[] start_probability = {0.6, 0.4};
        double[][] transition_probability = {{0.7, 0.3}, {0.4, 0.6}};
        double[][] emission_probability = {{0.1, 0.4, 0.5}, {0.6, 0.3, 0.1}};
        forward_viterbi(observations, states, start_probability, transition_probability, emission_probability);
    }

    /**
     * Decodes the most probable state sequence and prints it to standard output.
     *
     * <p>Output format: one line {@code p=<probability of the best path>},
     * followed by one line per observation holding the name of the decoded
     * state. An empty observation sequence prints nothing.
     *
     * <p>NOTE: the observation strings themselves are never inspected; only the
     * length of {@code observations} is used. Column {@code t} of
     * {@code emission_probability} is taken as the emission probability of the
     * observation at position {@code t}, so the caller must lay the matrix out
     * in sequence order.
     *
     * @param observations           the observation sequence (only its length is used)
     * @param states                 names of the hidden states, indexed like the
     *                               probability arrays
     * @param start_probability      {@code start_probability[j]} is the probability of
     *                               starting in state {@code j}
     * @param transition_probability {@code transition_probability[j][i]} is the
     *                               probability of moving from state {@code j} to
     *                               state {@code i}
     * @param emission_probability   {@code emission_probability[i][t]} is the probability
     *                               that state {@code i} emits the observation at
     *                               position {@code t}
     * @throws IllegalArgumentException if there are observations but no states
     */
    public static void  forward_viterbi(String[] observations, String[] states,double[] start_probability, double[][] transition_probability, double[][] emission_probability)
    {
        int steps = observations.length;
        if (steps == 0)
        {
            // Nothing to label; the original code crashed here on r[0].
            return;
        }
        if (states.length == 0)
        {
            throw new IllegalArgumentException("states must not be empty");
        }

        double[] bestProbability = new double[1];
        int[] trace = decode(steps, states.length, start_probability, transition_probability, emission_probability, bestProbability);

        System.out.println("p=" + bestProbability[0]);
        for (int i = 0; i < trace.length; i++)
        {
            System.out.println(states[trace[i]]);
        }
    }

    /**
     * Runs the Viterbi recursion and back-tracking.
     *
     * @return the decoded state index for every time step; the probability of
     *         that path is stored in {@code bestProbability[0]}
     */
    private static int[] decode(int steps, int stateCount, double[] start_probability,
            double[][] transition_probability, double[][] emission_probability, double[] bestProbability)
    {
        // path[t][i]: best predecessor of state i at time t.
        int[][] path = new int[steps][stateCount];
        // r[t][i]: probability of the best path that ends in state i at time t.
        double[][] r = new double[steps][stateCount];

        // Initialisation: start probability times emission of the first observation.
        for (int j = 0; j < stateCount; j++)
        {
            r[0][j] = start_probability[j] * emission_probability[j][0];
            path[0][j] = 0;
        }

        // Recursion: extend the best path into every state, one time step at a time.
        for (int t = 1; t < steps; t++)
        {
            for (int i = 0; i < stateCount; i++)
            {
                double best = 0;
                int argBest = 0;
                for (int j = 0; j < stateCount; j++)
                {
                    double candidate = r[t - 1][j] * transition_probability[j][i] * emission_probability[i][t];
                    // Strict '>' keeps the lowest-indexed predecessor on ties.
                    if (candidate > best)
                    {
                        best = candidate;
                        argBest = j;
                    }
                }
                r[t][i] = best;
                path[t][i] = argBest;
            }
        }

        // Termination: pick the most probable final state.
        double p = 0;
        int m = 0;
        for (int i = 0; i < stateCount; i++)
        {
            if (r[steps - 1][i] > p)
            {
                p = r[steps - 1][i];
                m = i;
            }
        }
        bestProbability[0] = p;

        // Back-tracking: follow the stored predecessors from the last step to the first.
        int[] trace = new int[steps];
        trace[steps - 1] = m;
        for (int t = steps - 1; t > 0; t--)
        {
            m = path[t][m];
            trace[t - 1] = m;
        }
        return trace;
    }
}


posted on 2012-09-07 16:43 nianzai 阅读(1973) 评论(0)  编辑  收藏 所属分类: 机器学习

只有注册用户登录后才能发表评论。


网站导航: