//Copyright (c) 2020 National Institute of Advanced Industrial Science and Technology (AIST), All Rights Reserved.
//Author: Yuuji Ichisugi

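/**

Operational check of the new CPT model of BESOM4.

(NOTE(review): the original Japanese notes below were mis-encoded; this is a
reconstruction. Items marked "(?)" could only be partially recovered.)

TODO
- All weights and values are specified with sliders.
  Weights lie in [0,1]; values are discrete.

- Mode that sets the values of the upper and lower nodes and computes the probability.
- For w_ujdk, the default weights make the node behave as a value-wise OR.
- For w_ciud, initial values for testing are set.
- The zero value of a unit is represented by index 0, and value i by index i.

Implemented and tested up to this point.
- A pattern node plays two roles: a (?) node and an upstream node.
  The structure of the weight tables and the way they are indexed need to be reconsidered!!

- Mode that sets only some node values and tests the marginal probability
  distribution of the remaining variables. (?)

- Set some node values and print the probability distribution of the other values
  for every combination of values. (?)
  The values of the other layer are computed as well. (?)

- Build weight data by hand for testing.

- Feature that dumps the weights and test data.
  To be used for testing the Julia version.


Bugs
- Reducing the number of nodes or units does not remove the extra sliders.


 */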

package qbc;
import lab.Lab;
import lab.Lab.LabCode;
import qbc.QBC.QBCMain;


/**
 * Interactive test bench for the CPT (conditional probability table) model of BESOM4.
 *
 * <p>Three groups of discrete nodes are involved. Each node's value is held as a vector
 * over its units, where index 0 stands for the zero value:
 * <ul>
 *   <li>{@code Xc[c]}: nodes whose weights {@code w_ciud} gate the Xu -> Xd connections,</li>
 *   <li>{@code Xu[u]}: upper (parent) nodes,</li>
 *   <li>{@code Xd[d]}: lower (child) nodes whose distribution P(Xd | Xc, Xu) is computed.</li>
 * </ul>
 * Node/unit counts, weights and node values are all read from Lab panel sliders.
 */
public class BESOM4 {
    /**
     * Adds this class to Lab's selectable classes, prints that list, and runs
     * {@link Test1} through a {@link LabCode} instance.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Lab.addSelectableClass(BESOM4.class);
        System.out.println(Lab.selectableClasses + "");

        LabCode labCode = new LabCode();
        labCode.main(Test1.class);
    }

    /**
     * Lab main program. It reads sizes, weights and parent values from the panel, then
     * repeatedly evaluates the CPT and displays the resulting distribution of Xd.
     */
    public static class Test1 extends Lab.MainCode {
        // NOTE(review): sigmoidA is not referenced anywhere in this class.
        public float sigmoidA = panel.getFloat("sigmoidA", 3, 1, 10);
        // Node counts and units per node. The "+ 1" adds unit index 0, which encodes
        // the zero value of a node, on top of the slider's count of non-zero values.
        public int cNodes = panel.getInt("cNodes", 1, 0, 10);
        public int cUnits = panel.getInt("cUnits", 3, 1, 10) + 1;
        public int uNodes = panel.getInt("uNodes", 2, 1, 10);
        public int uUnits = panel.getInt("uUnits", 3, 1, 10) + 1;
        public int dNodes = panel.getInt("dNodes", 1, 1, 10);
        public int dUnits = panel.getInt("dUnits", 3, 1, 10) + 1;
        // Note: weights for the zero value (unit index 0) are never used, but the arrays
        // still allocate that slot so that unit indices can be used directly.
        public float[][][][] w_ujdk = new float[uNodes][uUnits][dNodes][dUnits];
        public float[][][][] w_ciud = new float[cNodes][cUnits][uNodes][dNodes];
        // Current node values (one-hot for Xc/Xu; a distribution for Xd after calcCPT).
        public float[][] x_ci = new float[cNodes][cUnits];
        public float[][] x_uj = new float[uNodes][uUnits];
        public float[][] x_dk = new float[dNodes][dUnits];
        // Intermediate results of calcCPT().
        float[][][][] g_ujdk = new float[uNodes][uUnits][dNodes][dUnits];
        float[][] s_dk = new float[dNodes][dUnits];

        /**
         * Main loop. Each iteration re-reads the weights and parent values from the panel,
         * recomputes the CPT, and then either paints the distribution of every Xd or
         * prints P(Xd,...|Xc,...,Xu,...) for the Xd values selected on the panel.
         */
        public void main() {
            for (;;) {
                panel.speedControl("main", 100);
                // Set the weights.
                setWeightsManually();
                // Set the values of the parent variables.
                for (int c = 0; c < cNodes; c++) {
                    setOneHotVec(x_ci[c], panel.getInt("Xc["+ c+ "]", 0, 0, cUnits-1));
                }
                for (int u = 0; u < uNodes; u++) {
                    setOneHotVec(x_uj[u], panel.getInt("Xu["+ u+ "]", 0, 0, uUnits-1));
                }

                // Compute the CPT on every iteration: both branches below read x_dk, and
                // must see values for the current weights and parent values rather than
                // values left over from an earlier iteration.
                calcCPT();

                if (panel.flag("Infer Xd", true)) {
                    // Display the probability distribution of each Xd.
                    for (int d = 0; d < dNodes; d++) {
                        env.viewPanel.paint("Xd["+ d+ "]", x_dk[d]);
                    }
                } else {
                    // Display P(Xd,...|Xc,...,Xu,...) for the Xd values chosen on the panel,
                    // taken as the product of the per-node probabilities x_dk[d][k].
                    float p = 1;
                    for (int d = 0; d < dNodes; d++) {
                        int k = panel.getInt("Xd["+ d+ "]", 0, 0, dUnits-1);
                        p *= x_dk[d][k];
                    }
                    env.viewPanel.print1("P(Xd,...|Xc,...,Xu,...)", ""+ p);
                }
            }
        }

        /**
         * Overwrites {@code vec} with a one-hot vector: 1 at index {@code n}, 0 elsewhere.
         * If {@code n} is outside {@code [0, vec.length)}, every element becomes 0.
         *
         * @param vec vector to overwrite in place
         * @param n   index to set to 1
         */
        public void setOneHotVec(float[] vec, int n) {
            for (int i = 0; i < vec.length; i++) {
                vec[i] = i == n ? 1 : 0;
            }
        }

        /**
         * Computes the distribution of every Xd node from the current weights and parent
         * values, and writes it into {@code x_dk}.
         *
         * <ol>
         *   <li>{@code g_ujdk = w_ujdk * x_uj * prod_{c, i>=1} (1 - w_ciud * x_ci)}: how
         *       strongly unit j of Xu[u] drives unit k of Xd[d], attenuated by the gating
         *       weights of Xc.</li>
         *   <li>{@code s_dk = 1 - prod_{u, j>=1} (1 - g_ujdk)}: noisy-OR over all parent
         *       units.</li>
         *   <li>{@code x_dk = s_dk * prod_{i>=1, i!=k} (1 - s_di)} for k >= 1, and
         *       {@code x_d0 = 1 - sum_{k>=1} x_dk}.</li>
         * </ol>
         * Unit index 0 of {@code g_ujdk} and {@code s_dk} is never written.
         *
         * @return always 0
         */
        public float calcCPT() {
            // g_ujdk
            for (int u = 0; u < uNodes; u++) {
                for (int j = 1; j < uUnits; j++) {
                    for (int d = 0; d < dNodes; d++) {
                        for (int k = 1; k < dUnits; k++) {
                            float val = w_ujdk[u][j][d][k] * x_uj[u][j];
                            for (int c = 0; c < cNodes; c++) {
                                for (int i = 1; i < cUnits; i++) {
                                    val *= (1 - w_ciud[c][i][u][d] * x_ci[c][i]);
                                }
                            }
                            g_ujdk[u][j][d][k] = val;
                        }
                    }
                }
            }
            // s_dk = 1 - prod_{u,j} (1 - g_ujdk)
            for (int d = 0; d < dNodes; d++) {
                for (int k = 1; k < dUnits; k++) {
                    float val = 1;
                    for (int u = 0; u < uNodes; u++) {
                        for (int j = 1; j < uUnits; j++) {
                            val *= 1 - g_ujdk[u][j][d][k];
                        }
                    }
                    s_dk[d][k] = 1 - val;
                }
            }
            // x_dk: probability that unit k alone is on; index 0 takes the remainder.
            for (int d = 0; d < dNodes; d++) {
                float total = 0;
                for (int k = 1; k < dUnits; k++) {
                    float val = s_dk[d][k];
                    for (int i = 1; i < dUnits; i++) {
                        if (i != k) {
                            val *= (1 - s_dk[d][i]);
                        }
                    }
                    x_dk[d][k] = val;
                    total += val;
                }
                x_dk[d][0] = 1 - total;
            }
            return 0;
        }

        /**
         * Reads every weight for a non-zero unit from a panel slider in [0, 1].
         *
         * <p>Defaults: {@code w_ujdk[u][j][d][k]} is 1 when j == k and 0 otherwise, so
         * unit k of Xd is driven only by unit k of each Xu. {@code w_ciud} defaults to a
         * hard-coded gating pattern for testing. Xc unit 1 suppresses the connections
         * from Xu[0], unit 2 suppresses those from Xu[1], and unit 3 suppresses those from
         * every Xu.
         */
        public void setWeightsManually() {
            // w_ujdk
            for (int u = 0; u < uNodes; u++) {
                for (int j = 1; j < uUnits; j++) {
                    for (int d = 0; d < dNodes; d++) {
                        for (int k = 1; k < dUnits; k++) {
                            float defaultValue;
                            if (j == k) {
                                defaultValue = 1;
                            } else {
                                defaultValue = 0;
                            }
                            w_ujdk[u][j][d][k] =
                                    panel.getFloat("w_ujdk["+ u+ j+ d+ k+ "]",
                                            defaultValue, 0, 1);
                        }
                    }
                }
            }
            // w_ciud
            for (int c = 0; c < cNodes; c++) {
                for (int i = 1; i < cUnits; i++) {
                    for (int u = 0; u < uNodes; u++) {
                        for (int d = 0; d < dNodes; d++) {
                            float defaultValue;
                            // Embed a gating pattern for testing.
                            if (i == 1 && u == 0 ||
                                    i == 2 && u == 1 ||
                                    i == 3) {
                                defaultValue = 1;
                            } else {
                                defaultValue = 0;
                            }
                            w_ciud[c][i][u][d] =
                                    panel.getFloat("w_ciud["+ c+ i+ u+ d+ "]",
                                            defaultValue, 0, 1);
                        }
                    }
                }
            }
        }

    }
}
