001 package org.maltparser.parser.guide.decision;
002
003 import java.lang.reflect.Constructor;
004 import java.lang.reflect.InvocationTargetException;
005 import java.util.HashMap;
006
007 import org.maltparser.core.exception.MaltChainedException;
008 import org.maltparser.core.feature.FeatureModel;
009 import org.maltparser.core.feature.FeatureVector;
010 import org.maltparser.core.syntaxgraph.DependencyStructure;
011 import org.maltparser.parser.DependencyParserConfig;
012 import org.maltparser.parser.guide.ClassifierGuide;
013 import org.maltparser.parser.guide.GuideException;
014 import org.maltparser.parser.guide.instance.AtomicModel;
015 import org.maltparser.parser.guide.instance.DecisionTreeModel;
016 import org.maltparser.parser.guide.instance.FeatureDivideModel;
017 import org.maltparser.parser.guide.instance.InstanceModel;
018 import org.maltparser.parser.history.action.GuideDecision;
019 import org.maltparser.parser.history.action.MultipleDecision;
020 import org.maltparser.parser.history.action.SingleDecision;
021 import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
022 /**
023 *
024 * @author Johan Hall
025 * @since 1.1
026 **/
027 public class BranchedDecisionModel implements DecisionModel {
028 private ClassifierGuide guide;
029 private String modelName;
030 private FeatureModel featureModel;
031 private InstanceModel instanceModel;
032 private int decisionIndex;
033 private DecisionModel parentDecisionModel;
034 private HashMap<Integer,DecisionModel> children;
035 private String branchedDecisionSymbols;
036
037 public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
038 this.branchedDecisionSymbols = "";
039 setGuide(guide);
040 setFeatureModel(featureModel);
041 setDecisionIndex(0);
042 setModelName("bdm"+decisionIndex);
043 setParentDecisionModel(null);
044 }
045
046 public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
047 if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
048 this.branchedDecisionSymbols = branchedDecisionSymbol;
049 } else {
050 this.branchedDecisionSymbols = "";
051 }
052 setGuide(guide);
053 setParentDecisionModel(parentDecisionModel);
054 setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
055 setFeatureModel(parentDecisionModel.getFeatureModel());
056 if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
057 setModelName("bdm"+decisionIndex+branchedDecisionSymbols);
058 } else {
059 setModelName("bdm"+decisionIndex);
060 }
061 this.parentDecisionModel = parentDecisionModel;
062 }
063
064 public void updateFeatureModel() throws MaltChainedException {
065 featureModel.update();
066 }
067
068 public void updateCardinality() throws MaltChainedException {
069 featureModel.updateCardinality();
070 }
071
072
073 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
074 if (instanceModel != null) {
075 instanceModel.finalizeSentence(dependencyGraph);
076 }
077 if (children != null) {
078 for (DecisionModel child : children.values()) {
079 child.finalizeSentence(dependencyGraph);
080 }
081 }
082 }
083
084 public void noMoreInstances() throws MaltChainedException {
085 if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
086 throw new GuideException("The decision model could not create it's model. ");
087 }
088 featureModel.updateCardinality();
089 if (instanceModel != null) {
090 instanceModel.noMoreInstances();
091 instanceModel.train();
092 }
093 if (children != null) {
094 for (DecisionModel child : children.values()) {
095 child.noMoreInstances();
096 }
097 }
098 }
099
100 public void terminate() throws MaltChainedException {
101 if (instanceModel != null) {
102 instanceModel.terminate();
103 instanceModel = null;
104 }
105 if (children != null) {
106 for (DecisionModel child : children.values()) {
107 child.terminate();
108 }
109 }
110 }
111
112 public void addInstance(GuideDecision decision) throws MaltChainedException {
113 if (decision instanceof SingleDecision) {
114 throw new GuideException("A branched decision model expect more than one decisions. ");
115 }
116 updateFeatureModel();
117 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
118 if (instanceModel == null) {
119 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
120 }
121
122 instanceModel.addInstance(singleDecision);
123 if (decisionIndex+1 < decision.numberOfDecisions()) {
124 if (singleDecision.continueWithNextDecision()) {
125 if (children == null) {
126 children = new HashMap<Integer,DecisionModel>();
127 }
128 DecisionModel child = children.get(singleDecision.getDecisionCode());
129 if (child == null) {
130 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
131 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
132 children.put(singleDecision.getDecisionCode(), child);
133 }
134 child.addInstance(decision);
135 }
136 }
137 }
138
139 public boolean predict(GuideDecision decision) throws MaltChainedException {
140 if (decision instanceof SingleDecision) {
141 throw new GuideException("A branched decision model expect more than one decisions. ");
142 }
143 updateFeatureModel();
144 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
145 if (instanceModel == null) {
146 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
147 }
148 instanceModel.predict(singleDecision);
149 if (decisionIndex+1 < decision.numberOfDecisions()) {
150 if (singleDecision.continueWithNextDecision()) {
151 if (children == null) {
152 children = new HashMap<Integer,DecisionModel>();
153 }
154 DecisionModel child = children.get(singleDecision.getDecisionCode());
155 if (child == null) {
156 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
157 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
158 children.put(singleDecision.getDecisionCode(), child);
159 }
160 child.predict(decision);
161 }
162 }
163
164 return true;
165 }
166
167 public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
168 if (decision instanceof SingleDecision) {
169 throw new GuideException("A branched decision model expect more than one decisions. ");
170 }
171 updateFeatureModel();
172 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
173 if (instanceModel == null) {
174 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
175 }
176 FeatureVector fv = instanceModel.predictExtract(singleDecision);
177 if (decisionIndex+1 < decision.numberOfDecisions()) {
178 if (singleDecision.continueWithNextDecision()) {
179 if (children == null) {
180 children = new HashMap<Integer,DecisionModel>();
181 }
182 DecisionModel child = children.get(singleDecision.getDecisionCode());
183 if (child == null) {
184 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
185 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
186 children.put(singleDecision.getDecisionCode(), child);
187 }
188 child.predictExtract(decision);
189 }
190 }
191
192 return fv;
193 }
194
195 public FeatureVector extract() throws MaltChainedException {
196 updateFeatureModel();
197 return instanceModel.extract(); // TODO handle many feature vectors
198 }
199
200 public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
201 if (decision instanceof SingleDecision) {
202 throw new GuideException("A branched decision model expect more than one decisions. ");
203 }
204
205 boolean success = false;
206 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
207 if (decisionIndex+1 < decision.numberOfDecisions()) {
208 if (singleDecision.continueWithNextDecision()) {
209 if (children == null) {
210 children = new HashMap<Integer,DecisionModel>();
211 }
212 DecisionModel child = children.get(singleDecision.getDecisionCode());
213 if (child != null) {
214 success = child.predictFromKBestList(decision);
215 }
216
217 }
218 }
219 if (!success) {
220 success = singleDecision.updateFromKBestList();
221 if (decisionIndex+1 < decision.numberOfDecisions()) {
222 if (singleDecision.continueWithNextDecision()) {
223 if (children == null) {
224 children = new HashMap<Integer,DecisionModel>();
225 }
226 DecisionModel child = children.get(singleDecision.getDecisionCode());
227 if (child == null) {
228 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
229 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
230 children.put(singleDecision.getDecisionCode(), child);
231 }
232 child.predict(decision);
233 }
234 }
235 }
236 return success;
237 }
238
239
240 public ClassifierGuide getGuide() {
241 return guide;
242 }
243
244 public String getModelName() {
245 return modelName;
246 }
247
248 public FeatureModel getFeatureModel() {
249 return featureModel;
250 }
251
252 public int getDecisionIndex() {
253 return decisionIndex;
254 }
255
256 public DecisionModel getParentDecisionModel() {
257 return parentDecisionModel;
258 }
259
260 private void setFeatureModel(FeatureModel featureModel) {
261 this.featureModel = featureModel;
262 }
263
264 private void setDecisionIndex(int decisionIndex) {
265 this.decisionIndex = decisionIndex;
266 }
267
268 private void setParentDecisionModel(DecisionModel parentDecisionModel) {
269 this.parentDecisionModel = parentDecisionModel;
270 }
271
272 private void setModelName(String modelName) {
273 this.modelName = modelName;
274 }
275
276 private void setGuide(ClassifierGuide guide) {
277 this.guide = guide;
278 }
279
280
281 private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
282 Class<?> decisionModelClass = null;
283 if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
284 decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
285 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
286 decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
287 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
288 decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
289 }
290
291 if (decisionModelClass == null) {
292 throw new GuideException("Could not find an appropriate decision model for the relation to the next decision");
293 }
294
295 try {
296 Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class,
297 java.lang.String.class };
298 Object[] arguments = new Object[3];
299 arguments[0] = getGuide();
300 arguments[1] = this;
301 arguments[2] = branchedDecisionSymbol;
302 Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
303 return (DecisionModel)constructor.newInstance(arguments);
304 } catch (NoSuchMethodException e) {
305 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
306 } catch (InstantiationException e) {
307 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
308 } catch (IllegalAccessException e) {
309 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
310 } catch (InvocationTargetException e) {
311 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
312 }
313 }
314
315 private void initInstanceModel(String subModelName) throws MaltChainedException {
316 FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
317 if (fv == null) {
318 fv = featureModel.getFeatureVector(subModelName);
319 }
320 if (fv == null) {
321 fv = featureModel.getMainFeatureVector();
322 }
323
324 DependencyParserConfig c = guide.getConfiguration();
325
326 // if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") ||
327 // (c.getOptionValue("guide", "tree_split_columns")!=null &&
328 // c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) ||
329 // (c.getOptionValue("guide", "tree_split_structures")!=null &&
330 // c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) {
331 // instanceModel = new DecisionTreeModel(fv, this);
332 // }else
333 if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
334 instanceModel = new AtomicModel(-1, fv, this);
335 } else {
336 instanceModel = new FeatureDivideModel(fv, this);
337 }
338 }
339
340 public String toString() {
341 final StringBuilder sb = new StringBuilder();
342 sb.append(modelName + ", ");
343 for (DecisionModel model : children.values()) {
344 sb.append(model.toString() + ", ");
345 }
346 return sb.toString();
347 }
348 }