// -*- C++ -*-

// Copyright 2006-2007 Deutsches Forschungszentrum fuer Kuenstliche Intelligenz
// or its licensors, as applicable.
//
// You may not use this file except under the terms of the accompanying license.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Project:  ocr-bpnet - neural network classifier
// File: bpnet.cc
// Purpose: neural network classifier
// Responsible: Hagen Kaprykowsky (kapry@iupr.net)
// Reviewer: Yves Rangoni (rangoni@iupr.net)
// Primary Repository:
// Web Sites: www.iupr.org, www.dfki.de

#include <float.h>
#include "bpnet.h"
#include "ocr-utils.h"
#include "confusion-matrix.h"

#include "narray-ops.h"
#include "logger.h"

using namespace ocropus;
using namespace colib;
using namespace narray_ops;

namespace iupr_bpnet {

// Lower bound applied to every score returned by BpnetClassifier::score().
#define MIN_SCORE 1e-6
// Input dimensions whose standard deviation does not exceed this value are
// only mean-centred during normalization (no division by the stdev).
#define epsilon_stdev 1e-04

    // log the confusion matrix for the training set for each epoch
    Logger logger_confusion_train("confusion_train");
    // log the confusion matrix for the testing set for each epoch
    Logger logger_confusion_test("confusion_test");

    // log error rates for each epoch
    Logger logger_errors("error_rates");

    // Plot the train/test error curves every freq_error_rates epochs during
    // training (no per-epoch plots when <= 0; the final plot after training
    // is always produced).
    param_int freq_error_rates("freq_error_rates",1,"frequency of error_rates graphs plotting");

    // Check that no element of v is NaN.
    // T is any array type offering length1d() and at1d(int) (flat access).
    // Returns false on the first NaN found, true otherwise (also for empty v).
    template<class T>
    bool valid(T &v) {
        const int total = v.length1d();
        int pos = 0;
        while(pos < total) {
            if(isnan(v.at1d(pos)))
                return false;
            ++pos;
        }
        return true;
    }

    // Shuffle the samples in v in place (Fisher-Yates driven by rand()) and
    // apply the same permutation to the labels in c, so that v[i] and c[i]
    // stay paired. v and c must have the same length. Empty or single-sample
    // sets are left unchanged (the old code indexed v[0] unconditionally).
    void shuffle_feat(objlist<floatarray> &v,intarray &c) {
        int n = v.length();
        CHECK_ARG(c.length()==n);
        floatarray v_tmp;   // copy() sizes it to match the sample being moved
        for(int i=0;i<n-1;i++) {
            int target = rand()%(n-i)+i;
            if(target==i) continue;   // swapping with itself: nothing to do
            copy(v_tmp,v[target]);
            copy(v[target],v[i]);
            copy(v[i],v_tmp);
            int c_tmp = c[target];
            c[target] = c[i];
            c[i] = c_tmp;
        }
    }

    // Dump all elements of input, in flat (1-d) order, to stream: one value
    // per line, printed with 10 significant digits.
    void write(FILE *stream,floatarray &input) {
        const int count = input.length1d();
        for(int k=0;k<count;k++)
            fprintf(stream,"%.10g\n",double(input.at1d(k)));
    }
    // Fill output, in flat (1-d) order, with whitespace-separated numbers read
    // from stream. output must already have its final size.
    // Throws "bad file" (const char*) when a value cannot be parsed.
    void read(floatarray &output,FILE *stream) {
        for(int i=0;i<output.length1d();i++) {
            float value;
            if(fscanf(stream,"%g",&value)!=1)
                throw "bad file";   // (the former `return` after this was unreachable)
            output.at1d(i) = value;
        }
    }

    // Parameters of the tabulated sigmoid (formerly macros); not to be used
    // outside this scope. The table has NSIGMOID entries covering
    // |x| in [0,SIGMOID_RANGE): entry i is slow_sigmoid(i*SIGMOID_RANGE/NSIGMOID).
    static const int NSIGMOID=10000;
    static const float SIGMOID_RANGE=15.0;

    // table of the sigmoid function; filled by init_sigmoid_table(), which
    // sigmoid() calls when its `inited` flag is still false
    static floatarray sigmoid_table(NSIGMOID);

    // Exact logistic function 1/(1+exp(-x)), evaluated with exp() each time;
    // used to fill the lookup table of the fast sigmoid().
    static float slow_sigmoid(float x) {
        const double denominator = 1.0 + exp(-x);
        return 1.0 / denominator;
    }

    // Initialize the sigmoid lookup table: entry i holds the exact sigmoid of
    // i*SIGMOID_RANGE/NSIGMOID, so the table samples [0,SIGMOID_RANGE).
    static void init_sigmoid_table() {
        int i = 0;
        while(i<NSIGMOID) {
            float x = i*SIGMOID_RANGE/NSIGMOID;
            sigmoid_table(i) = slow_sigmoid(x);
            i++;
        }
    }

    // Fast sigmoid via table lookup. sigmoid(|x|) is taken from the table entry
    // just below |x| (no interpolation), and sigmoid(-x) = 1-sigmoid(x) is used
    // for negative x. Saturates to exactly 1.0 (0.0 for negative x) once
    // |x| >= SIGMOID_RANGE. The table is (re)built when `inited` is false, and
    // the flag is then set.
    static inline float sigmoid(float x,bool &inited) {
        if(!inited) {
            init_sigmoid_table();
            inited = true;
        }
        const bool negative = (x<0.0);
        const float magnitude = negative ? -x : x;
        float value;
        if(magnitude>=SIGMOID_RANGE) {
            value = 1.0;
        } else {
            int slot = int(NSIGMOID*magnitude/SIGMOID_RANGE);
            // guard against float rounding pushing slot onto NSIGMOID
            value = (slot<NSIGMOID) ? sigmoid_table(slot) : 1.0f;
        }
        return negative ? 1.0-value : value;
    }

    // Pseudo-random float in [low,high], drawn from one call to rand()
    // (rand() replaced the obsolete drand48).
    static float random_range(float low,float high) {
        const float unit = float(rand())/float(RAND_MAX);
        const float span = high-low;
        return unit*span + low;
    }

    // Forward step through one fully connected sigmoid layer:
    //   activations_output(o) = sigmoid(offsets(o) + sum_k weights(o,k)*activations_input(k))
    // for o in [0,noutput) and k in [0,ninput); weights is indexed (output,input).
    void bp_propagate_activations(floatarray &activations_input,int ninput,
                                  floatarray &activations_output,int noutput,
                                  floatarray &weights,floatarray &offsets,
                                  bool &sigmoid_inited) {
        for(int o=0;o<noutput;o++) {
            float net = offsets(o);
            for(int k=0;k<ninput;k++)
                net += weights(o,k)*activations_input(k);
            activations_output(o) = sigmoid(net,sigmoid_inited);
        }
    }

    // Back-propagate deltas from an upper layer (noutput units, deltas in
    // delta_output) to the lower layer (ninput units with activations
    // activations_input):
    //   deltas_input(u) = max(a_u*(1-a_u), 1e-5) * sum_k delta_output(k)*weights(k,u)
    // The sigmoid derivative is clamped at 1e-5 — presumably so that saturated
    // units still receive some gradient.
    void bp_propagate_deltas(floatarray &deltas_input,int noutput,
                             floatarray &activations_input,
                             floatarray &delta_output,int ninput,
                             floatarray &weights) {
        for(int u=0;u<ninput;u++) {
            const float a = activations_input(u);
            float slope = a*(1.0-a);
            if(slope<1e-5)
                slope = 1e-5;
            float back = 0.0;
            for(int k=0;k<noutput;k++)
                back += delta_output(k)*weights(k,u);
            deltas_input(u) = slope*back;
        }
    }

    // Gradient step for one layer:
    //   weights(o,k) += eta*delta_output(o)*activation_input(k)
    //   offsets(o)   += eta*delta_output(o)
    // for o in [0,noutput) and k in [0,ninput).
    void bp_update_weights(floatarray &offsets,floatarray &delta_output,
                            int noutput,floatarray &activation_input,int ninput,
                            floatarray &weights,float eta) {
        for(int o=0;o<noutput;o++) {
            const float step = eta*delta_output(o);
            for(int k=0;k<ninput;k++)
                weights(o,k) += step*activation_input(k);
            offsets(o) += step;
        }
    }

    // Compute, for each input dimension, the mean (m_x) and the population
    // standard deviation (stdev) over all training vectors. Then rescale the
    // vectors in place to zero mean and unit variance. Dimensions with
    // stdev <= epsilon_stdev are only mean-centred.
    // m_x and stdev must have equal length; that length is the number of
    // dimensions processed. vectors must be non-empty.
    void normalize_input_train(objlist<floatarray> &vectors,doublearray
                               &stdev,doublearray &m_x) {
        CHECK_ARG(stdev.length()==m_x.length());
        const int count = vectors.length();
        CHECK_CONDITION(count>0);
        const int dims = m_x.length();

        for(int d=0;d<dims;d++) {
            // mean of dimension d
            double sum = 0.0;
            for(int s=0;s<count;s++)
                sum += vectors[s](d);
            m_x(d) = sum/count;

            // empirical (population) variance of dimension d
            double sq = 0.0;
            for(int s=0;s<count;s++) {
                double diff = vectors[s](d) - m_x(d);
                sq += diff*diff;
            }
            double variance = sq/count;
            stdev(d) = sqrt(variance<0. ? 0. : variance);

            // rescale dimension d of every sample
            const bool scale = stdev(d)>epsilon_stdev;
            for(int s=0;s<count;s++) {
                float &x = vectors[s](d);
                if(scale)
                    x = (x-m_x(d))/stdev(d);
                else    // (near) constant dimension: only centre it
                    x = x-m_x(d);
            }
        }
        ASSERT(valid(m_x));
        ASSERT(valid(stdev));
    } // end normalize_input

    // Normalize vectors in place with previously computed statistics:
    // (x-m_x)/stdev per dimension, or just x-m_x where stdev <= epsilon_stdev.
    // m_x and stdev are left unchanged.
    void normalize_input_retrain(objlist<floatarray> &vectors,doublearray
                                 &stdev,doublearray &m_x) {
        CHECK_ARG(stdev.length()==m_x.length());
        const int dims = m_x.length();
        const int count = vectors.length();
        for(int d=0;d<dims;d++) {
            const bool scale = stdev(d)>epsilon_stdev;
            for(int s=0;s<count;s++) {
                float &x = vectors[s](d);
                x = scale ? (x-m_x(d))/stdev(d) : x-m_x(d);
            }
        }
    }


// NSIGMOID and SIGMOID_RANGE are file-static constants rather than macros,
// so there is nothing to #undef here.

    class BpnetClassifier : public IClassifier {
    public:
        // training samples collected by add(); vectors[i] has label classes(i)
        objlist<floatarray> vectors;
        narray<int> classes;
        // activations of the input, hidden and output layers
        narray<float> input;
        narray<float> hidden;
        narray<float> output;
        // per-unit deltas computed by backward() (allocated by init_backprop)
        narray<float> hidden_deltas;
        narray<float> output_deltas;
        // network parameters; weight matrices are indexed (to-unit, from-unit)
        narray<float> weights_hidden_input;
        narray<float> hidden_offsets;
        narray<float> weights_output_hidden;
        narray<float> output_offsets;
        // output minus one-hot target of the current sample (compute_error)
        narray<float> error;
        // per-dimension input normalization statistics
        narray<double> stdev;
        narray<double> m_x;
        bool init;              // network created and trained or loaded
        bool sigmoid_inited;    // this object has built the sigmoid table
        bool training;          // between start_training() and start_classifying()
        bool norm;              // normalize inputs (mean 0, stdev 1)
        bool shuffle;           // shuffle training samples before training
        int ninput;             // -1 until set or taken from the first sample
        int nhidden;
        int noutput;            // -1: derived from the labels before training
        int epochs;
        float learningrate;
        float testportion;      // fraction of the samples held out for testing
        // confusion matrices of the current epoch and of the best epoch
        autodel<ConfusionMatrix> confusion_train;
        autodel<ConfusionMatrix> confusion_test;
        autodel<ConfusionMatrix> best_confusion_train;
        autodel<ConfusionMatrix> best_confusion_test;
        bool filedump;          // save the best network to buf while training
        stdio fp;               // NOTE(review): appears unused in this file
        char buf[4096];         // file path used when filedump is set
        // copy of the best network seen during training (used with filedump)
        narray<float> weights_hidden_input_best;
        narray<float> hidden_offsets_best;
        narray<float> weights_output_hidden_best;
        narray<float> output_offsets_best;

        // variable learning rate (enabled by vlr): see train() for how
        // eta_plus, eta_minus and eta_tol adapt learningrate per epoch
        float eta_plus,eta_minus,eta_tol;
        bool vlr;

        // Reset configuration and state to defaults; shared by both
        // constructors. -1 marks sizes/epochs/learning rate as "not set yet".
        void common_presets() {
            // state flags
            init = sigmoid_inited = training = false;
            // preprocessing options
            norm = shuffle = true;
            // network geometry and schedule
            ninput = nhidden = noutput = epochs = -1;
            learningrate = -1.0f;
            testportion = 0.0f;
            // variable learning rate, disabled by default
            vlr = false;
            eta_plus = 1.10;
            eta_minus = 0.6;
            eta_tol = 1.02;
        }

        BpnetClassifier() {
            common_presets();
            filedump = false;
        }

        BpnetClassifier(const char* path_tmp_mlp) {
            common_presets();
            strncpy(buf,path_tmp_mlp,sizeof(buf));
            filedump = true;
        }

        // Set a configuration parameter by name. Integer parameters are
        // converted with int() (truncation); flags are true for any non-zero
        // value. Any other name throws "unknown parameter name".
        void param(const char *name,double value) {
            // network geometry and training schedule
            if(strcmp(name,"ninput")==0)       { ninput = int(value); return; }
            if(strcmp(name,"nhidden")==0)      { nhidden = int(value); return; }
            if(strcmp(name,"noutput")==0)      { noutput = int(value); return; }
            if(strcmp(name,"epochs")==0)       { epochs = int(value); return; }
            if(strcmp(name,"learningrate")==0) { learningrate = float(value); return; }
            if(strcmp(name,"testportion")==0)  { testportion = float(value); return; }
            // preprocessing and output flags
            if(strcmp(name,"normalize")==0)    { norm = bool(value); return; }
            if(strcmp(name,"shuffle")==0)      { shuffle = bool(value); return; }
            if(strcmp(name,"filedump")==0)     { filedump = bool(value); return; }
            // variable learning rate
            if(strcmp(name,"vlr")==0)          { vlr = bool(value); return; }
            if(strcmp(name,"eta_plus")==0)     { eta_plus = float(value); return; }
            if(strcmp(name,"eta_minus")==0)    { eta_minus = float(value); return; }
            if(strcmp(name,"eta_tol")==0)      { eta_tol = float(value); return; }
            throw "unknown parameter name";
        }

        // Store a copy of training sample v with class label c. Only allowed in
        // training mode. The first sample fixes ninput (unless it was set
        // explicitly); every later sample must have that length.
        void add(floatarray &v,int c) {
            CHECK_CONDITION(training);
            ASSERT(valid(v));
            if(ninput<0) {
                ninput = v.dim(0);
            } else {
                CHECK_CONDITION(v.dim(0)==ninput);
            }
            floatarray &stored = vectors.push();
            copy(stored,v);
            classes.push(c);
            ASSERT(vectors.length()==classes.length());
        }

        // Enter training mode: samples may be added until start_classifying().
        void start_training()  {
            training = true;
        }

        void start_classifying() {
            if(training) {
                if(noutput == -1)
                    noutput = max(classes) + 1;
                create();
                if(norm) {
                    printf("Normalizing... ");fflush(stdout);logger_errors("Normalizing");

                    if(init) {
                        normalize_input_retrain(vectors,stdev,m_x);
                    } else {
                        normalize_input_train(vectors,stdev,m_x);
                    }
                }
                printf("end\n");
                if(shuffle) {
                    printf("Shuffling... ");fflush(stdout);logger_errors("Shuffling");
                    shuffle_feat(vectors,classes);
                    printf("end\n");
                }
                init_backprop(0.001);
                train();
                dealloc_train();
                init = true;
            }
            training = false;
        }

        // Like start_classifying(), and additionally returns the training and
        // test confusion matrices recorded for the best epoch.
        void start_classifying( intarray &confusion_matrix_train_out,
                                intarray &confusion_matrix_test_out) {
            start_classifying();
            ConfusionMatrix &train_cm = *best_confusion_train;
            ConfusionMatrix &test_cm = *best_confusion_test;
            train_cm.get(confusion_matrix_train_out);
            test_cm.get(confusion_matrix_test_out);
        }

        // Free the memory held by the stored training samples and labels.
        void seal() {
            classes.dealloc();
            vectors.dealloc();
        }

        // Compute the class scores for feature vector v into result: one entry
        // per output unit, each at least MIN_SCORE. v must have ninput elements.
        // v is only read; normalization is applied to the internal copy in
        // `input`, so scoring the same vector twice gives the same result.
        void score(floatarray &result,floatarray &v) {
            CHECK_CONDITION(!training);
            if(v.length() != ninput) {
                throw_fmt("trained with input dimension %d, but got %d",
                           ninput, v.length());
            }
            result.resize(noutput);
            copy(input,v);
            if(norm) {
                // normalize the private copy, never the caller's vector
                normalize_input_classify(input,stdev,m_x);
            }
            forward();
            // Copy output layer to result avoiding scores less than MIN_SCORE
            for(int i=0;i<output.length();i++) {
                result(i) = max(MIN_SCORE, output(i));
            }
        }

        // Write the network in "bp3-net" text format:
        // - a header "bp3-net ninput nhidden noutput norm";
        // - if norm, one "mean stdev" line per input dimension;
        // - the weights and offsets, one value per line.
        // Throws if stream is null.
        void save(FILE *stream) {
            CHECK_CONDITION(!training||filedump);
            if(!stream) {
                throw "cannot open output file for bp3 for writing";
            }
            fprintf(stream,"bp3-net %d %d %d %d\n",ninput,nhidden,noutput,int(norm));
            if(norm) {
                ASSERT(valid(m_x));
                ASSERT(valid(stdev));
                // write mean and standard deviation with %.17g so the doubles
                // round-trip exactly; "%f" kept only 6 decimals and turned small
                // statistics into 0. load() parses either form with "%lf".
                for(int d=0;d<ninput;d++) {
                    fprintf(stream,"%.17g %.17g\n",m_x(d),stdev(d));
                }
            }
            write(stream,weights_hidden_input);
            write(stream,hidden_offsets);
            write(stream,weights_output_hidden);
            write(stream,output_offsets);
        }

        // Read a network written by save(), replacing any network this object
        // already holds.
        // Throws "bad input format" when:
        // - the stream is null;
        // - the header is malformed or the layer sizes are out of range;
        // - a mean/stdev line cannot be parsed.
        // Truncated weight data makes read() throw "bad file".
        void load(FILE *stream) {
            CHECK_CONDITION(!training);
            int norm_tmp;
            if(!stream) {
                throw "bad input format";
            }
            if(fscanf(stream,"bp3-net %d %d %d %d",&ninput,&nhidden,&noutput,&norm_tmp)!=4 ||
                ninput<1||ninput>1000000||nhidden<1||nhidden>1000000||noutput<1||
                noutput>1000000) {
                throw "bad input format";
            }
            norm = bool(norm_tmp);
            // create() is a no-op once init is set; force reallocation so the
            // arrays match the dimensions just read
            init = false;
            create();
            if(norm) {  // read mean and standard deviation
                for(int d=0;d<ninput;d++) {
                    double m_x_tmp,stdev_tmp;
                    if(fscanf(stream,"%lf %lf\n",&m_x_tmp,&stdev_tmp)!=2) {
                        throw "bad input format";
                    }
                    m_x(d) = m_x_tmp;
                    stdev(d) = stdev_tmp;
                }
                ASSERT(valid(m_x));
                ASSERT(valid(stdev));
            }
            read(weights_hidden_input,stream);
            read(hidden_offsets,stream);
            read(weights_output_hidden,stream);
            read(output_offsets,stream);
            init = true;
        }

        // Allocate the following for the configured layer sizes:
        // - weights, offsets and activations;
        // - the four confusion matrices;
        // - the normalization statistics.
        // Does nothing once the network is initialized (init), so existing
        // parameters survive a retraining.
        void create() {
            if(init)
                return;
            // parameters; weight matrices are (to-layer, from-layer)
            weights_hidden_input.resize(nhidden,ninput);
            weights_output_hidden.resize(noutput,nhidden);
            hidden_offsets.resize(nhidden);
            output_offsets.resize(noutput);
            // activations
            input.resize(ninput);
            hidden.resize(nhidden);
            output.resize(noutput);
            // confusion matrices of the current and of the best epoch
            confusion_train = make_ConfusionMatrix(noutput);
            confusion_test = make_ConfusionMatrix(noutput);
            best_confusion_train = make_ConfusionMatrix(noutput);
            best_confusion_test = make_ConfusionMatrix(noutput);
            // normalization statistics
            m_x.resize(ninput);
            stdev.resize(ninput);
        }

        void forward_one_input(int sample_index, int &target, int &predicted) {

            copy(input,vectors[sample_index]);
            target = classes(sample_index);

            forward();
            predicted = argmax(output);
            ASSERT(predicted>=0 && predicted<noutput);
        }

        // Evaluate the network on stored samples [n1,n2):
        // - _error receives the summed squared output error;
        // - cls_error receives the fraction of misclassified samples;
        // - cm receives the confusion counts (it is cleared first).
        // An empty range yields zero errors instead of failing / dividing 0 by 0.
        void test_on_db(int n1, int n2, float &_error, float &cls_error,
            ConfusionMatrix& cm) {
            ASSERT(n2>=n1);
            int cls, predicted;
            _error = 0.0;
            cls_error = 0.0;
            cm.clear();
            for(int sample_index=n1;sample_index<n2;sample_index++) {
                forward_one_input(sample_index, cls, predicted);
                compute_error(cls);
                _error += norm2squared(error);
                cls_error += (predicted!=cls);
                cm.increment(cls,predicted);
            }
            if(n2>n1)
                cls_error /= (float(n2-n1));
        }

        // Print and log one line of statistics and log the confusion matrix cm
        // for epoch ep. str == "Train" selects the training format; anything
        // else selects the test format, which also reports best_cls_error and
        // best_error.
        void logging_train_test(const char* str,float _error,float cls_error,
            ConfusionMatrix& cm,int ep,float best_cls_error,float best_error) {
            char temp_string[4096];
            if(strcmp(str,"Train")==0) {
                sprintf(temp_string, "Train cls error: %10.4f%%\t"
                        "Train error: %10.4f\n",cls_error*100.,_error);
                logger_confusion_train.format(  "Train confusion matrix"
                                                " for epoch %d",ep);
                logger_confusion_train.confusion(cm);
            } else {
                sprintf(temp_string, "Test cls error:     %7.4f%%\t"
                                        "Test error:  %10.4f\t"
                                        "Best test cls error: %7.4f%%\t"
                                        "Best test error: %10.4f\n",
                                        cls_error*100.,_error,
                                        best_cls_error*100., best_error);
                logger_confusion_test.format(   "Test confusion matrix"
                                                " for epoch %d",ep);
                logger_confusion_test.confusion(cm);
            }
            // temp_string contains literal '%' characters (from "%%"), so it
            // must be printed verbatim, not passed to printf as a format
            fputs(temp_string,stdout);
            logger_errors(temp_string);
        }


        void train() {
            CHECK_CONDITION(learningrate>0.0f);
            ASSERT(testportion>=0.0f&&testportion<=1.0f);

            floatarray histo_train_error;histo_train_error.resize(epochs+1);
            floatarray histo_train_cls_error;histo_train_cls_error.resize(epochs+1);
            floatarray histo_test_error;histo_test_error.resize(epochs+1);
            floatarray histo_test_cls_error;histo_test_cls_error.resize(epochs+1);
            floatarray histo_lr;histo_lr.resize(epochs+1);

            char temp_string[1024];
            int predicted = -1;
            int cls = -1;
            float train_cls_error;
            //float best_train_cls_error = FLT_MAX;
            float train_error;
            float best_train_error = FLT_MAX;

            float test_cls_error = 0.0f;
            float best_test_cls_error = FLT_MAX;
            float test_error = 0.0f;
            float best_test_error = FLT_MAX;

            int n = vectors.length();
            int ntrain = int(float(n)*(1.0f-testportion));
            int ntest = n-ntrain;

            sprintf(temp_string, "ep:%d nh:%d lr:%g tp:%g np:%d i:%d o:%d\n",
                    epochs,nhidden,learningrate,testportion,n,ninput,noutput);
            printf(temp_string);logger_errors(temp_string);

            sprintf(temp_string, "=== Start training on %d samples "
                    "(testing on %d) for %d epochs ===\n",ntrain,ntest,epochs);
            printf(temp_string);logger_errors(temp_string);

            //local copy of the current mlp
            floatarray weights_hidden_input_temp;
            floatarray hidden_offsets_temp;
            floatarray weights_output_hidden_temp;
            floatarray output_offsets_temp;

            float old_train_error = 0.0f;
            sprintf(temp_string, "Epoch: %d\tLR: %f\n\n",0,learningrate);
            printf(temp_string);logger_errors(temp_string);

            test_on_db(0,ntrain,train_error,train_cls_error,*confusion_train);
            old_train_error = best_train_error = train_error;
            logging_train_test( "Train",train_error,train_cls_error,
                                *confusion_train,-1,0,0);

            test_on_db(ntrain,n,test_error,test_cls_error,*confusion_test);
            best_test_error = test_error;
            best_test_cls_error = test_cls_error;
            logging_train_test( "Test ",test_error,test_cls_error,
                                *confusion_test,-1,
                                best_test_cls_error,best_test_error);

            histo_train_error[0] = train_error/ntrain;
            histo_train_cls_error[0] = train_cls_error*100.;
            histo_test_error[0] = test_error/ntest;
            histo_test_cls_error[0] = test_cls_error*100.;
            histo_lr[0] = learningrate;

            for(int epoch=0;epoch<epochs;epoch++) {
                sprintf(temp_string, "Epoch: %d\tLR: %f\n",epoch+1,learningrate);
                printf(temp_string);logger_errors(temp_string);
                //confusion_test->printReduced(stdout);
                //keep current network in memory
                copy(weights_hidden_input_temp, weights_hidden_input);
                copy(hidden_offsets_temp,       hidden_offsets);
                copy(weights_output_hidden_temp,weights_output_hidden);
                copy(output_offsets_temp,       output_offsets);

                //train
                for(int sample_index=0;sample_index<ntrain;sample_index++) {
                    forward_one_input(sample_index, cls, predicted);
                    compute_error(cls);
                    backward();
                    update();
                    if(sample_index % 1000 == 0) {
                        printf("%d/%d\r", sample_index, ntrain);fflush(stdout);
                    }
                }
                printf("\n");

                test_on_db(0,ntrain,train_error,train_cls_error,*confusion_train);
                logging_train_test( "Train",train_error,train_cls_error,
                                    *confusion_train,epoch,0,0);

                //test
                if(testportion>0.0f) {
                    test_on_db(ntrain,n,test_error,test_cls_error,*confusion_test);

                    if(test_cls_error < best_test_cls_error) {
                        best_test_cls_error = test_cls_error;
                    }
                    if(test_error < best_test_error) {
                        if(filedump) {
                            copy(weights_hidden_input_best,weights_hidden_input);
                            copy(hidden_offsets_best,hidden_offsets);
                            copy(weights_output_hidden_best,weights_output_hidden);
                            copy(output_offsets_best,output_offsets);
                            save(stdio(buf,"w"));
                        }
                        best_test_error = test_error;
                        best_confusion_train->set(*confusion_train);
                        best_confusion_test->set(*confusion_test);
                    }
                    logging_train_test( "Test",test_error,test_cls_error,
                                        *confusion_test,epoch,
                                        best_test_cls_error,best_test_error);
                    histo_test_error[epoch+1] = test_error/ntest;
                    histo_test_cls_error[epoch+1] = test_cls_error*100;
                } else {
                    if(train_error < best_train_error) {
                        if(filedump) {
                            copy(weights_hidden_input_best,weights_hidden_input);
                            copy(hidden_offsets_best,hidden_offsets);
                            copy(weights_output_hidden_best,weights_output_hidden);
                            copy(output_offsets_best,output_offsets);
                            save(stdio(buf,"w"));
                        }
                        best_train_error = train_error;
                        best_confusion_train->set(*confusion_train);
                    }
                }
                histo_train_error[epoch+1] = train_error/ntrain;
                histo_train_cls_error[epoch+1] = train_cls_error*100.;
                histo_lr[epoch+1] = learningrate;

                if(vlr) {
                    if(train_error / old_train_error >= eta_tol) {
                        learningrate *= eta_minus;
                        printf("eta -\n");
                        logger_errors("eta -");
                        copy(weights_hidden_input,  weights_hidden_input_temp);
                        copy(hidden_offsets,        hidden_offsets_temp);
                        copy(weights_output_hidden, weights_output_hidden_temp);
                        copy(output_offsets,        output_offsets_temp);
                    } else {
                        if(train_error < old_train_error) {
                            learningrate *= eta_plus;
                            printf("eta +\n");
                            logger_errors("eta +");
                        } else {
                            printf("eta =\n");
                            logger_errors("eta =");
                        }
                        old_train_error = train_error;
                    }
                }
                if((freq_error_rates>0) && ((epoch%freq_error_rates)==0 || (epoch==0)))
                logger_errors.train_test_curves(
                    histo_train_error,histo_train_cls_error,
                    histo_test_error,histo_test_cls_error,
                    histo_lr,epoch+1);
                //confusion_test->printReduced(stdout);
            }
            if(filedump) {  // get the best weights
                copy(weights_hidden_input,weights_hidden_input_best);
                copy(hidden_offsets,hidden_offsets_best);
                copy(weights_output_hidden,weights_output_hidden_best);
                copy(output_offsets,output_offsets_best);
            }
            fflush(stdout);

            logger_errors.train_test_curves(
                    histo_train_error,histo_train_cls_error,
                    histo_test_error,histo_test_cls_error,
                    histo_lr,epochs);
        }

        // Allocate and zero the training buffers (error, hidden/output deltas).
        // If the network is not initialized yet, also draw every weight and
        // offset uniformly from [-range,range]. Draw order: hidden weights, hidden
        // offsets, output weights, output offsets.
        void init_backprop(float range) {
            error.resize(noutput);
            hidden_deltas.resize(nhidden);
            output_deltas.resize(noutput);
            fill(error,0.0f);
            fill(hidden_deltas,0.0f);
            fill(output_deltas,0.0f);
            if(init)
                return;     // keep the parameters of an existing network
            const float lo = -range;
            for(int h=0;h<nhidden;h++)
                for(int k=0;k<ninput;k++)
                    weights_hidden_input(h,k) = random_range(lo,range);
            for(int h=0;h<nhidden;h++)
                hidden_offsets(h) = random_range(lo,range);
            for(int o=0;o<noutput;o++)
                for(int h=0;h<nhidden;h++)
                    weights_output_hidden(o,h) = random_range(lo,range);
            for(int o=0;o<noutput;o++)
                output_offsets(o) = random_range(lo,range);
        }

        // Release the training-only buffers and drop the stored samples/labels.
        void dealloc_train() {
            output_deltas.dealloc();
            hidden_deltas.dealloc();
            error.dealloc();
            vectors.clear();
            classes.clear();
        }

        // Propagate `input` through the hidden layer into `output`.
        void forward() {
            // input -> hidden
            bp_propagate_activations(input,ninput,hidden,nhidden,
                                     weights_hidden_input,hidden_offsets,
                                     sigmoid_inited);
            // hidden -> output
            bp_propagate_activations(hidden,nhidden,output,noutput,
                                     weights_output_hidden,output_offsets,
                                     sigmoid_inited);
        }

        // error = output - onehot(cls): difference between the network output
        // and the target vector of class cls.
        void compute_error(int cls) {
            copy(error,output);
            error(cls) = error(cls) - 1;
        }

        void backward() {
            int i;
            for(i=0;i<noutput;i++) {
                float deriv = output(i)*(1.0-output(i));
                output_deltas(i) = -deriv*error(i);
            }
            bp_propagate_deltas(hidden_deltas,noutput,
                                hidden,output_deltas,nhidden,
                                weights_output_hidden);
        }

        void update() {
            bp_update_weights(output_offsets,output_deltas,noutput,
                              hidden,nhidden,weights_output_hidden,
                              learningrate);
            bp_update_weights(hidden_offsets,hidden_deltas,nhidden,
                              input,ninput,weights_hidden_input,
                              learningrate);
        }
    };
}

namespace ocropus {
    // Create a new backpropagation MLP classifier (the caller owns it).
    IClassifier *make_BpnetClassifier() {
        using namespace iupr_bpnet;
        return new BpnetClassifier();
    }
    // Create a new backpropagation MLP classifier that, while training,
    // writes the best network seen so far to the file `path` (caller owns it).
    IClassifier *make_BpnetClassifierDumpIntoFile(const char *path) {
        using namespace iupr_bpnet;
        return new BpnetClassifier(path);
    }
}
