001 package org.maltparser.parser.guide.decision; 002 003 import java.lang.reflect.Constructor; 004 import java.lang.reflect.InvocationTargetException; 005 import java.util.HashMap; 006 007 import org.maltparser.core.exception.MaltChainedException; 008 import org.maltparser.core.feature.FeatureModel; 009 import org.maltparser.core.feature.FeatureVector; 010 import org.maltparser.core.syntaxgraph.DependencyStructure; 011 import org.maltparser.parser.DependencyParserConfig; 012 import org.maltparser.parser.guide.ClassifierGuide; 013 import org.maltparser.parser.guide.GuideException; 014 import org.maltparser.parser.guide.instance.AtomicModel; 015 import org.maltparser.parser.guide.instance.DecisionTreeModel; 016 import org.maltparser.parser.guide.instance.FeatureDivideModel; 017 import org.maltparser.parser.guide.instance.InstanceModel; 018 import org.maltparser.parser.history.action.GuideDecision; 019 import org.maltparser.parser.history.action.MultipleDecision; 020 import org.maltparser.parser.history.action.SingleDecision; 021 import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision; 022 /** 023 * 024 * @author Johan Hall 025 * @since 1.1 026 **/ 027 public class BranchedDecisionModel implements DecisionModel { 028 private ClassifierGuide guide; 029 private String modelName; 030 private FeatureModel featureModel; 031 private InstanceModel instanceModel; 032 private int decisionIndex; 033 private DecisionModel parentDecisionModel; 034 private HashMap<Integer,DecisionModel> children; 035 private String branchedDecisionSymbols; 036 037 public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException { 038 this.branchedDecisionSymbols = ""; 039 setGuide(guide); 040 setFeatureModel(featureModel); 041 setDecisionIndex(0); 042 setModelName("bdm"+decisionIndex); 043 setParentDecisionModel(null); 044 } 045 046 public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException { 047 if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) { 048 this.branchedDecisionSymbols = branchedDecisionSymbol; 049 } else { 050 this.branchedDecisionSymbols = ""; 051 } 052 setGuide(guide); 053 setParentDecisionModel(parentDecisionModel); 054 setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1); 055 setFeatureModel(parentDecisionModel.getFeatureModel()); 056 if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) { 057 setModelName("bdm"+decisionIndex+branchedDecisionSymbols); 058 } else { 059 setModelName("bdm"+decisionIndex); 060 } 061 this.parentDecisionModel = parentDecisionModel; 062 } 063 064 public void updateFeatureModel() throws MaltChainedException { 065 featureModel.update(); 066 } 067 068 public void updateCardinality() throws MaltChainedException { 069 featureModel.updateCardinality(); 070 } 071 072 073 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 074 if (instanceModel != null) { 075 instanceModel.finalizeSentence(dependencyGraph); 076 } 077 if (children != null) { 078 for (DecisionModel child : children.values()) { 079 child.finalizeSentence(dependencyGraph); 080 } 081 } 082 } 083 084 public void noMoreInstances() throws MaltChainedException { 085 if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 086 throw new GuideException("The decision model could not create it's model. "); 087 } 088 featureModel.updateCardinality(); 089 if (instanceModel != null) { 090 instanceModel.noMoreInstances(); 091 instanceModel.train(); 092 } 093 if (children != null) { 094 for (DecisionModel child : children.values()) { 095 child.noMoreInstances(); 096 } 097 } 098 } 099 100 public void terminate() throws MaltChainedException { 101 if (instanceModel != null) { 102 instanceModel.terminate(); 103 instanceModel = null; 104 } 105 if (children != null) { 106 for (DecisionModel child : children.values()) { 107 child.terminate(); 108 } 109 } 110 } 111 112 public void addInstance(GuideDecision decision) throws MaltChainedException { 113 if (decision instanceof SingleDecision) { 114 throw new GuideException("A branched decision model expect more than one decisions. "); 115 } 116 updateFeatureModel(); 117 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 118 if (instanceModel == null) { 119 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 120 } 121 122 instanceModel.addInstance(singleDecision); 123 if (decisionIndex+1 < decision.numberOfDecisions()) { 124 if (singleDecision.continueWithNextDecision()) { 125 if (children == null) { 126 children = new HashMap<Integer,DecisionModel>(); 127 } 128 DecisionModel child = children.get(singleDecision.getDecisionCode()); 129 if (child == null) { 130 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 131 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 132 children.put(singleDecision.getDecisionCode(), child); 133 } 134 child.addInstance(decision); 135 } 136 } 137 } 138 139 public boolean predict(GuideDecision decision) throws MaltChainedException { 140 if (decision instanceof SingleDecision) { 141 throw new GuideException("A branched decision model expect more than one decisions. "); 142 } 143 updateFeatureModel(); 144 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 145 if (instanceModel == null) { 146 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 147 } 148 instanceModel.predict(singleDecision); 149 if (decisionIndex+1 < decision.numberOfDecisions()) { 150 if (singleDecision.continueWithNextDecision()) { 151 if (children == null) { 152 children = new HashMap<Integer,DecisionModel>(); 153 } 154 DecisionModel child = children.get(singleDecision.getDecisionCode()); 155 if (child == null) { 156 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 157 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 158 children.put(singleDecision.getDecisionCode(), child); 159 } 160 child.predict(decision); 161 } 162 } 163 164 return true; 165 } 166 167 public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException { 168 if (decision instanceof SingleDecision) { 169 throw new GuideException("A branched decision model expect more than one decisions. "); 170 } 171 updateFeatureModel(); 172 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 173 if (instanceModel == null) { 174 initInstanceModel(singleDecision.getTableContainer().getTableContainerName()); 175 } 176 FeatureVector fv = instanceModel.predictExtract(singleDecision); 177 if (decisionIndex+1 < decision.numberOfDecisions()) { 178 if (singleDecision.continueWithNextDecision()) { 179 if (children == null) { 180 children = new HashMap<Integer,DecisionModel>(); 181 } 182 DecisionModel child = children.get(singleDecision.getDecisionCode()); 183 if (child == null) { 184 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 185 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 186 children.put(singleDecision.getDecisionCode(), child); 187 } 188 child.predictExtract(decision); 189 } 190 } 191 192 return fv; 193 } 194 195 public FeatureVector extract() throws MaltChainedException { 196 updateFeatureModel(); 197 return instanceModel.extract(); // TODO handle many feature vectors 198 } 199 200 public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException { 201 if (decision instanceof SingleDecision) { 202 throw new GuideException("A branched decision model expect more than one decisions. "); 203 } 204 205 boolean success = false; 206 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex); 207 if (decisionIndex+1 < decision.numberOfDecisions()) { 208 if (singleDecision.continueWithNextDecision()) { 209 if (children == null) { 210 children = new HashMap<Integer,DecisionModel>(); 211 } 212 DecisionModel child = children.get(singleDecision.getDecisionCode()); 213 if (child != null) { 214 success = child.predictFromKBestList(decision); 215 } 216 217 } 218 } 219 if (!success) { 220 success = singleDecision.updateFromKBestList(); 221 if (decisionIndex+1 < decision.numberOfDecisions()) { 222 if (singleDecision.continueWithNextDecision()) { 223 if (children == null) { 224 children = new HashMap<Integer,DecisionModel>(); 225 } 226 DecisionModel child = children.get(singleDecision.getDecisionCode()); 227 if (child == null) { 228 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 229 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol()); 230 children.put(singleDecision.getDecisionCode(), child); 231 } 232 child.predict(decision); 233 } 234 } 235 } 236 return success; 237 } 238 239 240 public ClassifierGuide getGuide() { 241 return guide; 242 } 243 244 public String getModelName() { 245 return modelName; 246 } 247 248 public FeatureModel getFeatureModel() { 249 return featureModel; 250 } 251 252 public int getDecisionIndex() { 253 return decisionIndex; 254 } 255 256 public DecisionModel getParentDecisionModel() { 257 return parentDecisionModel; 258 } 259 260 private void setFeatureModel(FeatureModel featureModel) { 261 this.featureModel = featureModel; 262 } 263 264 private void setDecisionIndex(int decisionIndex) { 265 this.decisionIndex = decisionIndex; 266 } 267 268 private void setParentDecisionModel(DecisionModel parentDecisionModel) { 269 this.parentDecisionModel = parentDecisionModel; 270 } 271 272 private void setModelName(String modelName) { 273 this.modelName = modelName; 274 } 275 276 private void setGuide(ClassifierGuide guide) { 277 this.guide = guide; 278 } 279 280 281 private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException { 282 Class<?> decisionModelClass = null; 283 if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) { 284 decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class; 285 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) { 286 decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class; 287 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) { 288 decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class; 289 } 290 291 if (decisionModelClass == null) { 292 throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 293 } 294 295 try { 296 Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 297 java.lang.String.class }; 298 Object[] arguments = new Object[3]; 299 arguments[0] = getGuide(); 300 arguments[1] = this; 301 arguments[2] = branchedDecisionSymbol; 302 Constructor<?> constructor = decisionModelClass.getConstructor(argTypes); 303 return (DecisionModel)constructor.newInstance(arguments); 304 } catch (NoSuchMethodException e) { 305 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 306 } catch (InstantiationException e) { 307 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 308 } catch (IllegalAccessException e) { 309 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 310 } catch (InvocationTargetException e) { 311 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e); 312 } 313 } 314 315 private void initInstanceModel(String subModelName) throws MaltChainedException { 316 FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName); 317 if (fv == null) { 318 fv = featureModel.getFeatureVector(subModelName); 319 } 320 if (fv == null) { 321 fv = featureModel.getMainFeatureVector(); 322 } 323 324 DependencyParserConfig c = guide.getConfiguration(); 325 326 // if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") || 327 // (c.getOptionValue("guide", "tree_split_columns")!=null && 328 // c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) || 329 // (c.getOptionValue("guide", "tree_split_structures")!=null && 330 // c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) { 331 // instanceModel = new DecisionTreeModel(fv, this); 332 // }else 333 if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) { 334 instanceModel = new AtomicModel(-1, fv, this); 335 } else { 336 instanceModel = new FeatureDivideModel(fv, this); 337 } 338 } 339 340 public String toString() { 341 final StringBuilder sb = new StringBuilder(); 342 sb.append(modelName + ", "); 343 for (DecisionModel model : children.values()) { 344 sb.append(model.toString() + ", "); 345 } 346 return sb.toString(); 347 } 348 }