001    /**
002     * Copyright (C) 2007-2008, Jens Lehmann
003     *
004     * This file is part of DL-Learner.
005     * 
006     * DL-Learner is free software; you can redistribute it and/or modify
007     * it under the terms of the GNU General Public License as published by
008     * the Free Software Foundation; either version 3 of the License, or
009     * (at your option) any later version.
010     *
011     * DL-Learner is distributed in the hope that it will be useful,
012     * but WITHOUT ANY WARRANTY; without even the implied warranty of
013     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014     * GNU General Public License for more details.
015     *
016     * You should have received a copy of the GNU General Public License
017     * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018     *
019     */
020    package org.dllearner.scripts;
021    
022    import java.io.File;
023    import java.io.FileNotFoundException;
024    import java.text.DecimalFormat;
025    import java.util.Collections;
026    import java.util.HashSet;
027    import java.util.LinkedList;
028    import java.util.List;
029    import java.util.Random;
030    import java.util.Set;
031    
032    import org.apache.log4j.ConsoleAppender;
033    import org.apache.log4j.Level;
034    import org.apache.log4j.Logger;
035    import org.apache.log4j.SimpleLayout;
036    import org.dllearner.cli.Start;
037    import org.dllearner.core.ComponentInitException;
038    import org.dllearner.core.ComponentManager;
039    import org.dllearner.core.LearningAlgorithm;
040    import org.dllearner.core.LearningProblem;
041    import org.dllearner.core.ReasonerComponent;
042    import org.dllearner.core.owl.Description;
043    import org.dllearner.core.owl.Individual;
044    import org.dllearner.learningproblems.PosNegLP;
045    import org.dllearner.learningproblems.PosOnlyLP;
046    import org.dllearner.parser.ParseException;
047    import org.dllearner.utilities.Helper;
048    import org.dllearner.utilities.datastructures.Datastructures;
049    import org.dllearner.utilities.statistics.Stat;
050    
051    /**
052     * Performs cross validation for the given problem. Supports
053     * k-fold cross-validation and leave-one-out cross-validation.
054     * 
055     * @author Jens Lehmann
056     *
057     */
058    public class CrossValidation {
059    
060            private static Logger logger = Logger.getRootLogger();  
061            
062            // statistical values
063            private Stat runtime = new Stat();
064            private Stat accuracy = new Stat();
065            private Stat length = new Stat();       
066            
067            public static void main(String[] args) {
068                    File file = new File(args[0]);
069                    
070                    boolean leaveOneOut = false;
071                    int folds = 10;
072                    
073                    // use second argument as number of folds; if not specified
074                    // leave one out cross validation is used
075                    if(args.length > 1)
076                            folds = Integer.parseInt(args[1]);
077                    else
078                            leaveOneOut = true;
079                    
080                    if(folds < 2) {
081                            System.out.println("At least 2 fold needed.");
082                            System.exit(0);
083                    }
084                    
085                    // create logger (a simple logger which outputs
086                    // its messages to the console)
087                    SimpleLayout layout = new SimpleLayout();
088                    ConsoleAppender consoleAppender = new ConsoleAppender(layout);
089                    logger.removeAllAppenders();
090                    logger.addAppender(consoleAppender);
091                    logger.setLevel(Level.WARN);
092                    // disable OWL API info output
093                    java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);
094                    
095                    new CrossValidation(file, folds, leaveOneOut);
096                    
097            }
098            
099            public CrossValidation(File file, int folds, boolean leaveOneOut) {
100                    this(file, folds, leaveOneOut, null);
101            }
102                            
103            public CrossValidation(File file, int folds, boolean leaveOneOut, LearningAlgorithm la) {               
104                    
105                    DecimalFormat df = new DecimalFormat(); 
106                    ComponentManager cm = ComponentManager.getInstance();
107                    
108                    // the first read of the file is used to detect the examples
109                    // and set up the splits correctly according to our validation
110                    // method
111                    Start start = null;
112                    try {
113                            start = new Start(file);
114                    } catch (ComponentInitException e) {
115                            // TODO Auto-generated catch block
116                            e.printStackTrace();
117                    } catch (FileNotFoundException e) {
118                            // TODO Auto-generated catch block
119                            e.printStackTrace();
120                    } catch (ParseException e) {
121                            // TODO Auto-generated catch block
122                            e.printStackTrace();
123                    }
124                    
125                    LearningProblem lp = start.getLearningProblem();
126    //              ReasonerComponent rs = start.getReasonerComponent();
127    //              start.getReasonerComponent().releaseKB();
128    
129                    // the training and test sets used later on
130                    List<Set<Individual>> trainingSetsPos = new LinkedList<Set<Individual>>();
131                    List<Set<Individual>> trainingSetsNeg = new LinkedList<Set<Individual>>();
132                    List<Set<Individual>> testSetsPos = new LinkedList<Set<Individual>>();
133                    List<Set<Individual>> testSetsNeg = new LinkedList<Set<Individual>>();
134                    
135                    if(lp instanceof PosNegLP) {
136    
137                            // get examples and shuffle them to 
138                            Set<Individual> posExamples = ((PosNegLP)lp).getPositiveExamples();
139                            List<Individual> posExamplesList = new LinkedList<Individual>(posExamples);
140                            Collections.shuffle(posExamplesList, new Random(1));                    
141                            Set<Individual> negExamples = ((PosNegLP)lp).getNegativeExamples();
142                            List<Individual> negExamplesList = new LinkedList<Individual>(negExamples);
143                            Collections.shuffle(negExamplesList, new Random(2));
144                            
145                            // sanity check whether nr. of folds makes sense for this benchmark
146                            if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) {
147                                    System.out.println("The number of folds is higher than the number of "
148                                                    + "positive/negative examples. This can result in empty test sets. Exiting.");
149                                    System.exit(0);
150                            }
151                            
152                            if(leaveOneOut) {
153                                    // note that leave-one-out is not identical to k-fold with
154                                    // k = nr. of examples in the current implementation, because
155                                    // with n folds and n examples there is no guarantee that a fold
156                                    // is never empty (this is an implementation issue)
157                                    int nrOfExamples = posExamples.size() + negExamples.size();
158                                    for(int i = 0; i < nrOfExamples; i++) {
159                                            // ...
160                                    }
161                                    System.out.println("Leave-one-out not supported yet.");
162                                    System.exit(1);
163                            } else {
164                                    // calculating where to split the sets, ; note that we split
165                                    // positive and negative examples separately such that the 
166                                    // distribution of positive and negative examples remains similar
167                                    // (note that there better but more complex ways to implement this,
168                                    // which guarantee that the sum of the elements of a fold for pos
169                                    // and neg differs by at most 1 - it can differ by 2 in our implementation,
170                                    // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
171                                    int[] splitsPos = calculateSplits(posExamples.size(),folds);
172                                    int[] splitsNeg = calculateSplits(negExamples.size(),folds);
173                                    
174    //                              System.out.println(splitsPos[0]);
175    //                              System.out.println(splitsNeg[0]);
176                                    
177                                    // calculating training and test sets
178                                    for(int i=0; i<folds; i++) {
179                                            Set<Individual> testPos = getTestingSet(posExamplesList, splitsPos, i);
180                                            Set<Individual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
181                                            testSetsPos.add(i, testPos);
182                                            testSetsNeg.add(i, testNeg);
183                                            trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
184                                            trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));                           
185                                    }       
186                                    
187                            }
188                            
189                    } else if(lp instanceof PosOnlyLP) {
190                            System.out.println("Cross validation for positive only learning not supported yet.");
191                            System.exit(0);
192                            // Set<Individual> posExamples = ((PosOnlyLP)lp).getPositiveExamples();
193                            // int[] splits = calculateSplits(posExamples.size(),folds);
194                    } else {
195                            System.out.println("Cross validation for learning problem " + lp + " not supported.");
196                            System.exit(0);
197                    }
198                    
199                    // run the algorithm
200                    for(int currFold=0; currFold<folds; currFold++) {
201                            // we always perform a full initialisation to make sure that
202                            // no objects are reused
203                            try {
204                                    start = new Start(file);
205                            } catch (ComponentInitException e) {
206                                    e.printStackTrace();
207                            } catch (FileNotFoundException e) {
208                                    // TODO Auto-generated catch block
209                                    e.printStackTrace();
210                            } catch (ParseException e) {
211                                    // TODO Auto-generated catch block
212                                    e.printStackTrace();
213                            }
214                            lp = start.getLearningProblem();
215                            Set<String> pos = Datastructures.individualSetToStringSet(trainingSetsPos.get(currFold));
216                            Set<String> neg = Datastructures.individualSetToStringSet(trainingSetsNeg.get(currFold));
217                            cm.applyConfigEntry(lp, "positiveExamples", pos);
218                            cm.applyConfigEntry(lp, "negativeExamples", neg);
219    //                      System.out.println("pos: " + pos.size());
220    //                      System.out.println("neg: " + neg.size());
221    //                      System.exit(0);
222                            
223                            la = start.getLearningAlgorithm();
224                            // init again, because examples have changed
225                            try {
226    //                              start.getReasonerComponent().init();                            
227                                    lp.init();
228                                    la.init();
229                            } catch (ComponentInitException e) {
230                                    // TODO Auto-generated catch block
231                                    e.printStackTrace();
232                            }
233                            
234                            long algorithmStartTime = System.nanoTime();
235                            la.start();
236                            long algorithmDuration = System.nanoTime() - algorithmStartTime;
237                            runtime.addNumber(algorithmDuration/(double)1000000000);
238                            
239                            Description concept = la.getCurrentlyBestDescription();
240                            
241                            ReasonerComponent rs = start.getReasonerComponent();
242                            Set<Individual> tmp = rs.hasType(concept, testSetsPos.get(currFold));
243                            Set<Individual> tmp2 = Helper.difference(testSetsPos.get(currFold), tmp);
244                            Set<Individual> tmp3 = rs.hasType(concept, testSetsNeg.get(currFold));
245                            
246                            System.out.println("test set errors pos: " + tmp2);
247                            System.out.println("test set errors neg: " + tmp3);
248                            
249                            // calculate training accuracies 
250                            int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold));
251                            int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainingSetsNeg.get(currFold));
252                            int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
253                            double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainingSetsPos.get(currFold).size()+
254                                            trainingSetsNeg.get(currFold).size()));                 
255                            
256                            // calculate test accuracies
257                            int correctPosClassified = getCorrectPosClassified(rs, concept, testSetsPos.get(currFold));
258                            int correctNegClassified = getCorrectNegClassified(rs, concept, testSetsNeg.get(currFold));
259                            int correctExamples = correctPosClassified + correctNegClassified;
260                            double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+
261                                            testSetsNeg.get(currFold).size()));
262                            accuracy.addNumber(currAccuracy);
263                            
264                            length.addNumber(concept.getLength());
265                            
266                            System.out.println("fold " + currFold + " (" + file + "):");
267                            System.out.println("  training: " + pos.size() + " positive and " + neg.size() + " negative examples");
268                            System.out.println("  testing: " + correctPosClassified + "/" + testSetsPos.get(currFold).size() + " correct positives, " 
269                                            + correctNegClassified + "/" + testSetsNeg.get(currFold).size() + " correct negatives");
270                            System.out.println("  concept: " + concept);
271                            System.out.println("  accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)");
272                            System.out.println("  length: " + df.format(concept.getLength()));
273                            System.out.println("  runtime: " + df.format(algorithmDuration/(double)1000000000) + "s");
274                            
275                            // free all resources
276                            rs.releaseKB();
277                            cm.freeAllComponents();                 
278                    }
279                    
280                    System.out.println();
281                    System.out.println("Finished " + folds + "-folds cross-validation on " + file + ".");
282                    System.out.println("runtime: " + statOutput(df, runtime, "s"));
283                    System.out.println("length: " + statOutput(df, length, ""));
284                    System.out.println("accuracy: " + statOutput(df, accuracy, "%"));
285                    
286            }
287            
288            private int getCorrectPosClassified(ReasonerComponent rs, Description concept, Set<Individual> testSetPos) {
289                    return rs.hasType(concept, testSetPos).size();
290            }
291            
292            private int getCorrectNegClassified(ReasonerComponent rs, Description concept, Set<Individual> testSetNeg) {
293                    return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
294            }
295            
296            private Set<Individual> getTestingSet(List<Individual> examples, int[] splits, int fold) {
297                    int fromIndex;
298                    // we either start from 0 or after the last fold ended
299                    if(fold == 0)
300                            fromIndex = 0;
301                    else
302                            fromIndex = splits[fold-1];
303                    // the split corresponds to the ends of the folds
304                    int toIndex = splits[fold];
305                    
306    //              System.out.println("from " + fromIndex + " to " + toIndex);
307                    
308                    Set<Individual> testingSet = new HashSet<Individual>();
309                    // +1 because 2nd element is exclusive in subList method
310                    testingSet.addAll(examples.subList(fromIndex, toIndex));
311                    return testingSet;
312            }
313            
314            private Set<Individual> getTrainingSet(Set<Individual> examples, Set<Individual> testingSet) {
315                    return Helper.difference(examples, testingSet);
316            }
317            
318            // takes nr. of examples and the nr. of folds for this examples;
319            // returns an array which says where each fold ends, i.e.
320            // splits[i] is the index of the last element of fold i in the examples
321            private int[] calculateSplits(int nrOfExamples, int folds) {
322                    int[] splits = new int[folds];
323                    for(int i=1; i<=folds; i++) {
324                            // we always round up to the next integer
325                            splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds);
326                    }
327                    return splits;
328            }
329            
330            public static String statOutput(DecimalFormat df, Stat stat, String unit) {
331                    String str = "av. " + df.format(stat.getMean()) + unit;
332                    str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; ";
333                    str += "min " + df.format(stat.getMin()) + unit + "; ";
334                    str += "max " + df.format(stat.getMax()) + unit + ")";          
335                    return str;
336            }
337    
338            public Stat getAccuracy() {
339                    return accuracy;
340            }
341    
342            public Stat getLength() {
343                    return length;
344            }
345    
346            public Stat getRuntime() {
347                    return runtime;
348            }
349    
350    }