001 /**
002 * Copyright (C) 2007-2008, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program. If not, see <http://www.gnu.org/licenses/>.
018 *
019 */
020 package org.dllearner.scripts;
021
022 import java.io.File;
023 import java.io.FileNotFoundException;
024 import java.text.DecimalFormat;
025 import java.util.Collections;
026 import java.util.HashSet;
027 import java.util.LinkedList;
028 import java.util.List;
029 import java.util.Random;
030 import java.util.Set;
031
032 import org.apache.log4j.ConsoleAppender;
033 import org.apache.log4j.Level;
034 import org.apache.log4j.Logger;
035 import org.apache.log4j.SimpleLayout;
036 import org.dllearner.cli.Start;
037 import org.dllearner.core.ComponentInitException;
038 import org.dllearner.core.ComponentManager;
039 import org.dllearner.core.LearningAlgorithm;
040 import org.dllearner.core.LearningProblem;
041 import org.dllearner.core.ReasonerComponent;
042 import org.dllearner.core.owl.Description;
043 import org.dllearner.core.owl.Individual;
044 import org.dllearner.learningproblems.PosNegLP;
045 import org.dllearner.learningproblems.PosOnlyLP;
046 import org.dllearner.parser.ParseException;
047 import org.dllearner.utilities.Helper;
048 import org.dllearner.utilities.datastructures.Datastructures;
049 import org.dllearner.utilities.statistics.Stat;
050
051 /**
052 * Performs cross validation for the given problem. Supports
053 * k-fold cross-validation and leave-one-out cross-validation.
054 *
055 * @author Jens Lehmann
056 *
057 */
058 public class CrossValidation {
059
060 private static Logger logger = Logger.getRootLogger();
061
062 // statistical values
063 private Stat runtime = new Stat();
064 private Stat accuracy = new Stat();
065 private Stat length = new Stat();
066
067 public static void main(String[] args) {
068 File file = new File(args[0]);
069
070 boolean leaveOneOut = false;
071 int folds = 10;
072
073 // use second argument as number of folds; if not specified
074 // leave one out cross validation is used
075 if(args.length > 1)
076 folds = Integer.parseInt(args[1]);
077 else
078 leaveOneOut = true;
079
080 if(folds < 2) {
081 System.out.println("At least 2 fold needed.");
082 System.exit(0);
083 }
084
085 // create logger (a simple logger which outputs
086 // its messages to the console)
087 SimpleLayout layout = new SimpleLayout();
088 ConsoleAppender consoleAppender = new ConsoleAppender(layout);
089 logger.removeAllAppenders();
090 logger.addAppender(consoleAppender);
091 logger.setLevel(Level.WARN);
092 // disable OWL API info output
093 java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);
094
095 new CrossValidation(file, folds, leaveOneOut);
096
097 }
098
099 public CrossValidation(File file, int folds, boolean leaveOneOut) {
100 this(file, folds, leaveOneOut, null);
101 }
102
103 public CrossValidation(File file, int folds, boolean leaveOneOut, LearningAlgorithm la) {
104
105 DecimalFormat df = new DecimalFormat();
106 ComponentManager cm = ComponentManager.getInstance();
107
108 // the first read of the file is used to detect the examples
109 // and set up the splits correctly according to our validation
110 // method
111 Start start = null;
112 try {
113 start = new Start(file);
114 } catch (ComponentInitException e) {
115 // TODO Auto-generated catch block
116 e.printStackTrace();
117 } catch (FileNotFoundException e) {
118 // TODO Auto-generated catch block
119 e.printStackTrace();
120 } catch (ParseException e) {
121 // TODO Auto-generated catch block
122 e.printStackTrace();
123 }
124
125 LearningProblem lp = start.getLearningProblem();
126 // ReasonerComponent rs = start.getReasonerComponent();
127 // start.getReasonerComponent().releaseKB();
128
129 // the training and test sets used later on
130 List<Set<Individual>> trainingSetsPos = new LinkedList<Set<Individual>>();
131 List<Set<Individual>> trainingSetsNeg = new LinkedList<Set<Individual>>();
132 List<Set<Individual>> testSetsPos = new LinkedList<Set<Individual>>();
133 List<Set<Individual>> testSetsNeg = new LinkedList<Set<Individual>>();
134
135 if(lp instanceof PosNegLP) {
136
137 // get examples and shuffle them to
138 Set<Individual> posExamples = ((PosNegLP)lp).getPositiveExamples();
139 List<Individual> posExamplesList = new LinkedList<Individual>(posExamples);
140 Collections.shuffle(posExamplesList, new Random(1));
141 Set<Individual> negExamples = ((PosNegLP)lp).getNegativeExamples();
142 List<Individual> negExamplesList = new LinkedList<Individual>(negExamples);
143 Collections.shuffle(negExamplesList, new Random(2));
144
145 // sanity check whether nr. of folds makes sense for this benchmark
146 if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) {
147 System.out.println("The number of folds is higher than the number of "
148 + "positive/negative examples. This can result in empty test sets. Exiting.");
149 System.exit(0);
150 }
151
152 if(leaveOneOut) {
153 // note that leave-one-out is not identical to k-fold with
154 // k = nr. of examples in the current implementation, because
155 // with n folds and n examples there is no guarantee that a fold
156 // is never empty (this is an implementation issue)
157 int nrOfExamples = posExamples.size() + negExamples.size();
158 for(int i = 0; i < nrOfExamples; i++) {
159 // ...
160 }
161 System.out.println("Leave-one-out not supported yet.");
162 System.exit(1);
163 } else {
164 // calculating where to split the sets, ; note that we split
165 // positive and negative examples separately such that the
166 // distribution of positive and negative examples remains similar
167 // (note that there better but more complex ways to implement this,
168 // which guarantee that the sum of the elements of a fold for pos
169 // and neg differs by at most 1 - it can differ by 2 in our implementation,
170 // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
171 int[] splitsPos = calculateSplits(posExamples.size(),folds);
172 int[] splitsNeg = calculateSplits(negExamples.size(),folds);
173
174 // System.out.println(splitsPos[0]);
175 // System.out.println(splitsNeg[0]);
176
177 // calculating training and test sets
178 for(int i=0; i<folds; i++) {
179 Set<Individual> testPos = getTestingSet(posExamplesList, splitsPos, i);
180 Set<Individual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
181 testSetsPos.add(i, testPos);
182 testSetsNeg.add(i, testNeg);
183 trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
184 trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
185 }
186
187 }
188
189 } else if(lp instanceof PosOnlyLP) {
190 System.out.println("Cross validation for positive only learning not supported yet.");
191 System.exit(0);
192 // Set<Individual> posExamples = ((PosOnlyLP)lp).getPositiveExamples();
193 // int[] splits = calculateSplits(posExamples.size(),folds);
194 } else {
195 System.out.println("Cross validation for learning problem " + lp + " not supported.");
196 System.exit(0);
197 }
198
199 // run the algorithm
200 for(int currFold=0; currFold<folds; currFold++) {
201 // we always perform a full initialisation to make sure that
202 // no objects are reused
203 try {
204 start = new Start(file);
205 } catch (ComponentInitException e) {
206 e.printStackTrace();
207 } catch (FileNotFoundException e) {
208 // TODO Auto-generated catch block
209 e.printStackTrace();
210 } catch (ParseException e) {
211 // TODO Auto-generated catch block
212 e.printStackTrace();
213 }
214 lp = start.getLearningProblem();
215 Set<String> pos = Datastructures.individualSetToStringSet(trainingSetsPos.get(currFold));
216 Set<String> neg = Datastructures.individualSetToStringSet(trainingSetsNeg.get(currFold));
217 cm.applyConfigEntry(lp, "positiveExamples", pos);
218 cm.applyConfigEntry(lp, "negativeExamples", neg);
219 // System.out.println("pos: " + pos.size());
220 // System.out.println("neg: " + neg.size());
221 // System.exit(0);
222
223 la = start.getLearningAlgorithm();
224 // init again, because examples have changed
225 try {
226 // start.getReasonerComponent().init();
227 lp.init();
228 la.init();
229 } catch (ComponentInitException e) {
230 // TODO Auto-generated catch block
231 e.printStackTrace();
232 }
233
234 long algorithmStartTime = System.nanoTime();
235 la.start();
236 long algorithmDuration = System.nanoTime() - algorithmStartTime;
237 runtime.addNumber(algorithmDuration/(double)1000000000);
238
239 Description concept = la.getCurrentlyBestDescription();
240
241 ReasonerComponent rs = start.getReasonerComponent();
242 Set<Individual> tmp = rs.hasType(concept, testSetsPos.get(currFold));
243 Set<Individual> tmp2 = Helper.difference(testSetsPos.get(currFold), tmp);
244 Set<Individual> tmp3 = rs.hasType(concept, testSetsNeg.get(currFold));
245
246 System.out.println("test set errors pos: " + tmp2);
247 System.out.println("test set errors neg: " + tmp3);
248
249 // calculate training accuracies
250 int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold));
251 int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainingSetsNeg.get(currFold));
252 int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
253 double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainingSetsPos.get(currFold).size()+
254 trainingSetsNeg.get(currFold).size()));
255
256 // calculate test accuracies
257 int correctPosClassified = getCorrectPosClassified(rs, concept, testSetsPos.get(currFold));
258 int correctNegClassified = getCorrectNegClassified(rs, concept, testSetsNeg.get(currFold));
259 int correctExamples = correctPosClassified + correctNegClassified;
260 double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+
261 testSetsNeg.get(currFold).size()));
262 accuracy.addNumber(currAccuracy);
263
264 length.addNumber(concept.getLength());
265
266 System.out.println("fold " + currFold + " (" + file + "):");
267 System.out.println(" training: " + pos.size() + " positive and " + neg.size() + " negative examples");
268 System.out.println(" testing: " + correctPosClassified + "/" + testSetsPos.get(currFold).size() + " correct positives, "
269 + correctNegClassified + "/" + testSetsNeg.get(currFold).size() + " correct negatives");
270 System.out.println(" concept: " + concept);
271 System.out.println(" accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)");
272 System.out.println(" length: " + df.format(concept.getLength()));
273 System.out.println(" runtime: " + df.format(algorithmDuration/(double)1000000000) + "s");
274
275 // free all resources
276 rs.releaseKB();
277 cm.freeAllComponents();
278 }
279
280 System.out.println();
281 System.out.println("Finished " + folds + "-folds cross-validation on " + file + ".");
282 System.out.println("runtime: " + statOutput(df, runtime, "s"));
283 System.out.println("length: " + statOutput(df, length, ""));
284 System.out.println("accuracy: " + statOutput(df, accuracy, "%"));
285
286 }
287
288 private int getCorrectPosClassified(ReasonerComponent rs, Description concept, Set<Individual> testSetPos) {
289 return rs.hasType(concept, testSetPos).size();
290 }
291
292 private int getCorrectNegClassified(ReasonerComponent rs, Description concept, Set<Individual> testSetNeg) {
293 return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
294 }
295
296 private Set<Individual> getTestingSet(List<Individual> examples, int[] splits, int fold) {
297 int fromIndex;
298 // we either start from 0 or after the last fold ended
299 if(fold == 0)
300 fromIndex = 0;
301 else
302 fromIndex = splits[fold-1];
303 // the split corresponds to the ends of the folds
304 int toIndex = splits[fold];
305
306 // System.out.println("from " + fromIndex + " to " + toIndex);
307
308 Set<Individual> testingSet = new HashSet<Individual>();
309 // +1 because 2nd element is exclusive in subList method
310 testingSet.addAll(examples.subList(fromIndex, toIndex));
311 return testingSet;
312 }
313
314 private Set<Individual> getTrainingSet(Set<Individual> examples, Set<Individual> testingSet) {
315 return Helper.difference(examples, testingSet);
316 }
317
318 // takes nr. of examples and the nr. of folds for this examples;
319 // returns an array which says where each fold ends, i.e.
320 // splits[i] is the index of the last element of fold i in the examples
321 private int[] calculateSplits(int nrOfExamples, int folds) {
322 int[] splits = new int[folds];
323 for(int i=1; i<=folds; i++) {
324 // we always round up to the next integer
325 splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds);
326 }
327 return splits;
328 }
329
330 public static String statOutput(DecimalFormat df, Stat stat, String unit) {
331 String str = "av. " + df.format(stat.getMean()) + unit;
332 str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; ";
333 str += "min " + df.format(stat.getMin()) + unit + "; ";
334 str += "max " + df.format(stat.getMax()) + unit + ")";
335 return str;
336 }
337
338 public Stat getAccuracy() {
339 return accuracy;
340 }
341
342 public Stat getLength() {
343 return length;
344 }
345
346 public Stat getRuntime() {
347 return runtime;
348 }
349
350 }