package com.hankcs.hanlp.model.crf.crfpp;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.crf.crfpp.TaggerImpl;
import com.sun.xml.bind.v2.runtime.reflect.opt.Const;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:WEB-INF/lib/hanlp-portable-1.6.8.jar:com/hankcs/hanlp/model/crf/crfpp/Encoder.class */
public class Encoder {
    public static int MODEL_VERSION = 100;

    /* loaded from: input_file:WEB-INF/lib/hanlp-portable-1.6.8.jar:com/hankcs/hanlp/model/crf/crfpp/Encoder$Algorithm.class */
    public enum Algorithm {
        CRF_L2,
        CRF_L1,
        MIRA;

        public static Algorithm fromString(String str) {
            String lowerCase = str.toLowerCase();
            if (lowerCase.equals("crf") || lowerCase.equals("crf-l2")) {
                return CRF_L2;
            }
            if (lowerCase.equals("crf-l1")) {
                return CRF_L1;
            }
            if (lowerCase.equals("mira")) {
                return MIRA;
            }
            throw new IllegalArgumentException("invalid algorithm: " + lowerCase);
        }
    }

    public boolean learn(String str, String str2, String str3, boolean z, int i, int i2, double d, double d2, int i3, int i4, Algorithm algorithm) {
        if (d <= Const.default_value_double) {
            System.err.println("eta must be > 0.0");
            return false;
        }
        if (d2 < Const.default_value_double) {
            System.err.println("C must be >= 0.0");
            return false;
        }
        if (i4 < 1) {
            System.err.println("shrinkingSize must be >= 1");
            return false;
        }
        if (i3 <= 0) {
            System.err.println("thread must be  > 0");
            return false;
        }
        EncoderFeatureIndex encoderFeatureIndex = new EncoderFeatureIndex(i3);
        List<TaggerImpl> arrayList = new ArrayList<>();
        if (!encoderFeatureIndex.open(str, str2)) {
            System.err.println("Fail to open " + str + " " + str2);
        }
        try {
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(str2), "UTF-8"));
            int i5 = 0;
            while (true) {
                TaggerImpl taggerImpl = new TaggerImpl(TaggerImpl.Mode.LEARN);
                taggerImpl.open(encoderFeatureIndex);
                TaggerImpl.ReadStatus read = taggerImpl.read(bufferedReader);
                if (read == TaggerImpl.ReadStatus.ERROR) {
                    System.err.println("error when reading " + str2);
                    return false;
                }
                if (taggerImpl.empty()) {
                    if (read == TaggerImpl.ReadStatus.EOF) {
                        bufferedReader.close();
                        encoderFeatureIndex.shrink(i2, arrayList);
                        double[] dArr = new double[encoderFeatureIndex.size()];
                        Arrays.fill(dArr, Const.default_value_double);
                        encoderFeatureIndex.setAlpha_(dArr);
                        System.out.println("Number of sentences: " + arrayList.size());
                        System.out.println("Number of features:  " + encoderFeatureIndex.size());
                        System.out.println("Number of thread(s): " + i3);
                        System.out.println("Freq:                " + i2);
                        System.out.println("eta:                 " + d);
                        System.out.println("C:                   " + d2);
                        System.out.println("shrinking size:      " + i4);
                        switch (algorithm) {
                            case CRF_L1:
                                if (!runCRF(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3, true)) {
                                    System.err.println("CRF_L1 execute error");
                                    return false;
                                }
                                break;
                            case CRF_L2:
                                if (!runCRF(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3, false)) {
                                    System.err.println("CRF_L2 execute error");
                                    return false;
                                }
                                break;
                            case MIRA:
                                if (!runMIRA(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3)) {
                                    System.err.println("MIRA execute error");
                                    return false;
                                }
                                break;
                        }
                        if (!encoderFeatureIndex.save(str3, z)) {
                            System.err.println("Failed to save model");
                        }
                        System.out.println("Done!");
                        return true;
                    }
                } else {
                    if (!taggerImpl.shrink()) {
                        System.err.println("fail to build feature index ");
                        return false;
                    }
                    taggerImpl.setThread_id_(i5 % i3);
                    arrayList.add(taggerImpl);
                    i5++;
                    if (i5 % 100 == 0) {
                        System.out.print(i5 + ".. ");
                    }
                }
            }
        } catch (IOException e) {
            System.err.println("train file " + str2 + " does not exist.");
            return false;
        }
    }

    private boolean runCRF(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i, double d, double d2, int i2, int i3, boolean z) {
        double d3 = 1.0E37d;
        int i4 = 0;
        LbfgsOptimizer lbfgsOptimizer = new LbfgsOptimizer();
        ArrayList arrayList = new ArrayList();
        for (int i5 = 0; i5 < i3; i5++) {
            CRFEncoderThread cRFEncoderThread = new CRFEncoderThread(dArr.length);
            cRFEncoderThread.start_i = i5;
            cRFEncoderThread.size = list.size();
            cRFEncoderThread.threadNum = i3;
            cRFEncoderThread.x = list;
            arrayList.add(cRFEncoderThread);
        }
        int i6 = 0;
        for (int i7 = 0; i7 < list.size(); i7++) {
            i6 += list.get(i7).size();
        }
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(i3);
        int i8 = 0;
        while (i8 < i) {
            encoderFeatureIndex.clear();
            try {
                newFixedThreadPool.invokeAll(arrayList);
                for (int i9 = 1; i9 < i3; i9++) {
                    ((CRFEncoderThread) arrayList.get(0)).obj += ((CRFEncoderThread) arrayList.get(i9)).obj;
                    ((CRFEncoderThread) arrayList.get(0)).err += ((CRFEncoderThread) arrayList.get(i9)).err;
                    ((CRFEncoderThread) arrayList.get(0)).zeroone += ((CRFEncoderThread) arrayList.get(i9)).zeroone;
                }
                for (int i10 = 1; i10 < i3; i10++) {
                    for (int i11 = 0; i11 < encoderFeatureIndex.size(); i11++) {
                        double[] dArr2 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        int i12 = i11;
                        dArr2[i12] = dArr2[i12] + ((CRFEncoderThread) arrayList.get(i10)).expected[i11];
                    }
                }
                int i13 = 0;
                if (z) {
                    for (int i14 = 0; i14 < encoderFeatureIndex.size(); i14++) {
                        ((CRFEncoderThread) arrayList.get(0)).obj += Math.abs(dArr[i14] / d);
                        if (dArr[i14] != Const.default_value_double) {
                            i13++;
                        }
                    }
                } else {
                    i13 = encoderFeatureIndex.size();
                    for (int i15 = 0; i15 < encoderFeatureIndex.size(); i15++) {
                        ((CRFEncoderThread) arrayList.get(0)).obj += (dArr[i15] * dArr[i15]) / (2.0d * d);
                        double[] dArr3 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        int i16 = i15;
                        dArr3[i16] = dArr3[i16] + (dArr[i15] / d);
                    }
                }
                for (int i17 = 1; i17 < i3; i17++) {
                    ((CRFEncoderThread) arrayList.get(i17)).expected = null;
                }
                double abs = i8 == 0 ? 1.0d : Math.abs(d3 - ((CRFEncoderThread) arrayList.get(0)).obj) / d3;
                StringBuilder sb = new StringBuilder();
                sb.append("iter=").append(i8);
                sb.append(" terr=").append((1.0d * ((CRFEncoderThread) arrayList.get(0)).err) / i6);
                sb.append(" serr=").append((1.0d * ((CRFEncoderThread) arrayList.get(0)).zeroone) / list.size());
                sb.append(" act=").append(i13);
                sb.append(" obj=").append(((CRFEncoderThread) arrayList.get(0)).obj);
                sb.append(" diff=").append(abs);
                System.out.println(sb.toString());
                d3 = ((CRFEncoderThread) arrayList.get(0)).obj;
                i4 = abs < d2 ? i4 + 1 : 0;
                if (i8 > i || i4 == 3) {
                    break;
                }
                if (lbfgsOptimizer.optimize(encoderFeatureIndex.size(), dArr, ((CRFEncoderThread) arrayList.get(0)).obj, ((CRFEncoderThread) arrayList.get(0)).expected, z, d) <= 0) {
                    return false;
                }
                i8++;
            } catch (Exception e) {
                e.printStackTrace();
                return false;
            }
        }
        newFixedThreadPool.shutdown();
        try {
            newFixedThreadPool.awaitTermination(-1L, TimeUnit.SECONDS);
            return true;
        } catch (Exception e2) {
            e2.printStackTrace();
            System.err.println("fail waiting executor to shutdown");
            return true;
        }
    }

    public boolean runMIRA(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i, double d, double d2, int i2, int i3) {
        Integer[] numArr = new Integer[list.size()];
        Arrays.fill((Object[]) numArr, (Object) 0);
        List asList = Arrays.asList(numArr);
        Double[] dArr2 = new Double[list.size()];
        Arrays.fill(dArr2, Double.valueOf(Const.default_value_double));
        List asList2 = Arrays.asList(dArr2);
        List<Double> asList3 = Arrays.asList(new Double[encoderFeatureIndex.size()]);
        if (i3 > 1) {
            System.err.println("WARN: MIRA does not support multi-threading");
        }
        int i4 = 0;
        int i5 = 0;
        for (int i6 = 0; i6 < list.size(); i6++) {
            i5 += list.get(i6).size();
        }
        for (int i7 = 0; i7 < i; i7++) {
            int i8 = 0;
            int i9 = 0;
            int i10 = 0;
            int i11 = 0;
            double d3 = 0.0d;
            for (int i12 = 0; i12 < list.size(); i12++) {
                if (((Integer) asList.get(i12)).intValue() < i2) {
                    i10++;
                    for (int i13 = 0; i13 < asList3.size(); i13++) {
                        asList3.set(i13, Double.valueOf(Const.default_value_double));
                    }
                    double collins = list.get(i12).collins(asList3);
                    int eval = list.get(i12).eval();
                    i9 += eval;
                    if (eval != 0) {
                        i8++;
                    }
                    if (eval == 0) {
                        asList.set(i12, Integer.valueOf(((Integer) asList.get(i12)).intValue() + 1));
                    } else {
                        asList.set(i12, 0);
                        double d4 = 0.0d;
                        for (int i14 = 0; i14 < asList3.size(); i14++) {
                            d4 += asList3.get(i14).doubleValue() * asList3.get(i14).doubleValue();
                        }
                        double max = Math.max(Const.default_value_double, (eval - collins) / d4);
                        if (((Double) asList2.get(i12)).doubleValue() + max > d) {
                            max = d - ((Double) asList2.get(i12)).doubleValue();
                            i11++;
                        } else {
                            d3 = Math.max(eval - collins, d3);
                        }
                        if (max > 1.0E-10d) {
                            asList2.set(i12, Double.valueOf(((Double) asList2.get(i12)).doubleValue() + max));
                            asList2.set(i12, Double.valueOf(Math.min(d, ((Double) asList2.get(i12)).doubleValue())));
                            for (int i15 = 0; i15 < asList3.size(); i15++) {
                                int i16 = i15;
                                dArr[i16] = dArr[i16] + (max * asList3.get(i15).doubleValue());
                            }
                        }
                    }
                }
            }
            double d5 = 0.0d;
            for (int i17 = 0; i17 < encoderFeatureIndex.size(); i17++) {
                d5 += dArr[i17] * dArr[i17];
            }
            StringBuilder sb = new StringBuilder();
            sb.append("iter=").append(i7);
            sb.append(" terr=").append((1.0d * i9) / i5);
            sb.append(" serr=").append((1.0d * i8) / list.size());
            sb.append(" act=").append(i10);
            sb.append(" uact=").append(i11);
            sb.append(" obj=").append(d5);
            sb.append(" kkt=").append(d3);
            System.out.println(sb.toString());
            if (d3 <= Const.default_value_double) {
                for (int i18 = 0; i18 < asList.size(); i18++) {
                    asList.set(i18, 0);
                }
                i4++;
            } else {
                i4 = 0;
            }
            if (i7 > i || i4 == 2) {
                return true;
            }
        }
        return true;
    }

    public static void main(String[] strArr) {
        if (strArr.length < 3) {
            System.err.println("incorrect No. of args");
            return;
        }
        String str = strArr[0];
        String str2 = strArr[1];
        String str3 = strArr[2];
        Encoder encoder = new Encoder();
        long time = new Date().getTime();
        if (encoder.learn(str, str2, str3, false, 100000, 1, 1.0E-4d, 1.0d, 1, 20, Algorithm.CRF_L2)) {
            System.out.println(new Date().getTime() - time);
        } else {
            System.err.println("error training model");
        }
    }
}
