001    /**
002     * Copyright (C) 2007-2008, Jens Lehmann
003     *
004     * This file is part of DL-Learner.
005     * 
006     * DL-Learner is free software; you can redistribute it and/or modify
007     * it under the terms of the GNU General Public License as published by
008     * the Free Software Foundation; either version 3 of the License, or
009     * (at your option) any later version.
010     *
011     * DL-Learner is distributed in the hope that it will be useful,
012     * but WITHOUT ANY WARRANTY; without even the implied warranty of
013     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014     * GNU General Public License for more details.
015     *
016     * You should have received a copy of the GNU General Public License
017     * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018     *
019     */
020    package org.dllearner.algorithms.refinement2;
021    
022    import java.util.List;
023    
024    import org.dllearner.core.configurators.ROLComponent2Configurator;
025    import org.dllearner.core.owl.DatatypeSomeRestriction;
026    import org.dllearner.core.owl.Description;
027    import org.dllearner.core.owl.Negation;
028    import org.dllearner.core.owl.Thing;
029    import org.dllearner.utilities.owl.ConceptComparator;
030    
031    /**
032     * This heuristic combines the following criteria to assign a
033     * double score value to a node:
034     * <ul>
035     * <li>quality/accuracy of a concept (based on the full training set, not
036     *   the negative example coverage as the flexible heuristic)</li>
037     * <li>horizontal expansion</li>
038     * <li>accuracy gain: The heuristic takes into account the accuracy
039     *   difference between a node and its parent. If there is no gain (even
040     *   though we know that the refinement is proper) it is unlikely (although
041     *   not excluded) that the refinement is a necessary path to take towards a
042     *   solution.</li>
043     * </ul> 
044     *
045     * The heuristic has two parameters:
046     * <ul>
047     * <li>expansion penalty factor: describes how much accuracy gain is worth
048     *   an increase of horizontal expansion by one (typical value: 0.01)</li>
049     * <li>gain bonus factor: describes how accuracy gain should be weighted
050     *   versus accuracy itself (typical value: 1.00)</li>
051     * </ul>
052     *   
053     * The value of a node is calculated as follows:
054     * 
055     * <p><code>value = accuracy + gain bonus factor * accuracy gain - expansion penalty
056     * factor * horizontal expansion - node children penalty factor * number of children of node</code></p>
057     * 
058     * <p><code>accuracy = (TP + TN)/(P + N)</code></p>
059     * 
060     * <p><code>
061     * TP = number of true positives (= covered positives)<br />
062     * TN = number of true negatives (= nr of negatives examples - covered negatives)<br />
063     * P = number of positive examples<br />
064     * N = number of negative examples<br />
065     * </code></p>
066     * 
067     * @author Jens Lehmann
068     *
069     */
070    public class MultiHeuristic implements ExampleBasedHeuristic {
071            
072            private ConceptComparator conceptComparator = new ConceptComparator();
073            private ROLComponent2Configurator configurator;
074            
075            // heuristic parameters
076            private double expansionPenaltyFactor = 0.02;
077            private double gainBonusFactor = 0.5;
078            private double nodeChildPenalty = 0.0001; // (use higher values than 0.0001 for simple learning problems);
079            private double startNodeBonus = 0.1; //was 2.0
080            // penalise errors on positive examples harder than on negative examples
081            // (positive weight = 1)
082            private double negativeWeight = 1.0; // was 0.8;
083            
084            // examples
085            private int nrOfNegativeExamples;
086            private int nrOfExamples;
087            
088            @Deprecated
089            public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples) {
090                    this.nrOfNegativeExamples = nrOfNegativeExamples;
091                    nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
092    //              this(nrOfPositiveExamples, nrOfNegativeExamples, 0.02, 0.5);
093            }
094            
095            public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples, ROLComponent2Configurator configurator) {
096                    this.nrOfNegativeExamples = nrOfNegativeExamples;
097                    nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
098                    this.configurator = configurator;
099                    negativeWeight = configurator.getNegativeWeight();
100                    startNodeBonus = configurator.getStartNodeBonus();
101                    expansionPenaltyFactor = configurator.getExpansionPenaltyFactor();
102            }
103            
104    //      public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples, double expansionPenaltyFactor, double gainBonusFactor) {
105    //              this.nrOfNegativeExamples = nrOfNegativeExamples;
106    //              nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
107    //              this.expansionPenaltyFactor = expansionPenaltyFactor;
108    //              this.gainBonusFactor = gainBonusFactor;
109    //      }
110    
111            
112            /* (non-Javadoc)
113             * @see java.util.Comparator#compare(java.lang.Object, java.lang.Object)
114             */
115            public int compare(ExampleBasedNode node1, ExampleBasedNode node2) {
116                    double score1 = getNodeScore(node1);
117                    double score2 = getNodeScore(node2);
118                    double diff = score1 - score2;
119                    if(diff>0)
120                            return 1;
121                    else if(diff<0)
122                            return -1;
123                    else
124                            // we cannot return 0 here otherwise different nodes/concepts with the
125                            // same score may be ignored (not added to a set because an equal element exists)
126                            return conceptComparator.compare(node1.getConcept(), node2.getConcept());
127            }
128    
129            public double getNodeScore(ExampleBasedNode node) {
130                    double accuracy = getWeightedAccuracy(node.getCoveredPositives().size(),node.getCoveredNegatives().size());
131                    ExampleBasedNode parent = node.getParent();
132                    double gain = 0;
133                    if(parent != null) {
134                            double parentAccuracy =  getWeightedAccuracy(parent.getCoveredPositives().size(),parent.getCoveredNegatives().size());
135                            gain = accuracy - parentAccuracy;
136                    } else {
137                            accuracy += startNodeBonus;
138                    }
139                    int he = node.getHorizontalExpansion() - getHeuristicLengthBonus(node.getConcept());
140                    return accuracy + gainBonusFactor * gain - expansionPenaltyFactor * he - nodeChildPenalty * node.getChildren().size();
141            }
142            
143            private double getWeightedAccuracy(int coveredPositives, int coveredNegatives) {
144                    return (coveredPositives + negativeWeight * (nrOfNegativeExamples - coveredNegatives))/(double)nrOfExamples;
145            }
146            
147            public static double getNodeScore(ExampleBasedNode node, int nrOfPositiveExamples, int nrOfNegativeExamples, ROLComponent2Configurator configurator) {
148                    MultiHeuristic multi = new MultiHeuristic(nrOfPositiveExamples, nrOfNegativeExamples, configurator);
149                    return multi.getNodeScore(node);
150            }
151            
152            // this function can be used to give some constructs a length bonus
153            // compared to their syntactic length
154            private int getHeuristicLengthBonus(Description description) {
155                    int bonus = 0;
156                    
157                    // do not count TOP symbols (in particular in ALL r.TOP and EXISTS r.TOP)
158                    // as they provide no extra information
159                    if(description instanceof Thing)
160                            bonus = 1; //2;
161                    
162                    // we put a penalty on negations, because they often overfit
163                    // (TODO: make configurable)
164                    else if(description instanceof Negation) {
165                            bonus = -configurator.getNegationPenalty();
166                    }
167                    
168    //              if(description instanceof BooleanValueRestriction)
169    //                      bonus = -1;
170                    
171                    // some bonus for doubles because they are already penalised by length 3
172                    else if(description instanceof DatatypeSomeRestriction) {
173    //                      System.out.println(description);
174                            bonus = 3; //2;
175                    }
176                    
177                    List<Description> children = description.getChildren();
178                    for(Description child : children) {
179                            bonus += getHeuristicLengthBonus(child);
180                    }
181                    return bonus;
182            }
183    }