001/**
002 * Portions Copyright 2001 Sun Microsystems, Inc.
003 * Portions Copyright 1999-2001 Language Technologies Institute, 
004 * Carnegie Mellon University.
005 * All Rights Reserved.  Use is subject to license terms.
006 * 
007 * See the file "license.terms" for information on usage and
008 * redistribution of this file, and for a DISCLAIMER OF ALL 
009 * WARRANTIES.
010 */
011package com.sun.speech.freetts.cart;
012
013import java.io.BufferedReader;
014import java.io.DataInputStream;
015import java.io.DataOutputStream;
016import java.io.IOException;
017import java.io.InputStreamReader;
018import java.io.PrintWriter;
019import java.net.URL;
020import java.nio.ByteBuffer;
021import java.util.StringTokenizer;
022import java.util.logging.Level;
023import java.util.logging.Logger;
024import java.util.regex.Pattern;
025
026import com.sun.speech.freetts.Item;
027import com.sun.speech.freetts.PathExtractor;
028import com.sun.speech.freetts.PathExtractorImpl;
029import com.sun.speech.freetts.util.Utilities;
030
031/**
032 * Implementation of a Classification and Regression Tree (CART) that is
033 * used more like a binary decision tree, with each node containing a
034 * decision or a final value.  The decision nodes in the CART trees
035 * operate on an Item and have the following format:
036 *
037 * <pre>
038 *   NODE feat operand value qfalse 
039 * </pre>
040 *
 * <p>Where <code>feat</code> is a string that represents a feature
042 * to pass to the <code>findFeature</code> method of an item.
043 *
044 * <p>The <code>value</code> represents the value to be compared against
045 * the feature obtained from the item via the <code>feat</code> string.
046 * The <code>operand</code> is the operation to do the comparison.  The
047 * available operands are as follows:
048 *
049 * <ul>
050 *   <li>&lt; - the feature is less than value 
051 *   <li>= - the feature is equal to the value 
052 *   <li>> - the feature is greater than the value 
053 *   <li>MATCHES - the feature matches the regular expression stored in value 
054 *   <li>IN - [[[TODO: still guessing because none of the CART's in
055 *     Flite seem to use IN]]] the value is in the list defined by the
056 *     feature.
057 * </ul>
058 *
059 * <p>[[[TODO: provide support for the IN operator.]]]
060 *
061 * <p>For &lt; and >, this CART coerces the value and feature to
062 * float's. For =, this CART coerces the value and feature to string and
063 * checks for string equality. For MATCHES, this CART uses the value as a
064 * regular expression and compares the obtained feature to that.
065 *
066 * <p>A CART is represented by an array in this implementation. The
067 * <code>qfalse</code> value represents the index of the array to go to if
068 * the comparison does not match. In this implementation, qtrue index
069 * is always implied, and represents the next element in the
070 * array. The root node of the CART is the first element in the array.
071 *
072 * <p>The interpretations always start at the root node of the CART
073 * and continue until a final node is found.  The final nodes have the
074 * following form:
075 *
076 * <pre>
077 *   LEAF value
078 * </pre>
079 *
080 * <p>Where <code>value</code> represents the value of the node.
081 * Reaching a final node indicates the interpretation is over and the
082 * value of the node is the interpretation result.
083 */
084public class CARTImpl implements CART {
085    /** Logger instance. */
086    private static final Logger LOGGER =
087        Logger.getLogger(CARTImpl.class.getName());
088    /**
089     * Entry in file represents the total number of nodes in the
090     * file.  This should be at the top of the file.  The format
091     * should be "TOTAL n" where n is an integer value.
092     */
093    final static String TOTAL = "TOTAL";
094
095    /**
096     * Entry in file represents a node.  The format should be
097     * "NODE feat op val f" where 'feat' represents a feature, op
098     * represents an operand, val is the value, and f is the index
     * of the node to go to if there isn't a match.
100     */
101    final static String NODE = "NODE";
102
103    /**
104     * Entry in file represents a final node.  The format should be
105     * "LEAF val" where val represents the value.
106     */
107    final static String LEAF = "LEAF";
108
109    /**
110     * OPERAND_MATCHES
111     */
112    final static String OPERAND_MATCHES = "MATCHES";
113
114
115    /**
116     * The CART. Entries can be DecisionNode or LeafNode.  An
117     * ArrayList could be used here -- I chose not to because I
118     * thought it might be quicker to avoid dealing with the dynamic
119     * resizing.
120     */
121    Node[] cart = null;
122    
123    /**
     * The number of nodes added to the CART so far; this is also the
     * index at which the next parsed node is stored.
125     */
126    transient int curNode = 0;
127
128    /**
129     * Creates a new CART by reading from the given URL.
130     *
131     * @param url the location of the CART data
132     *
133     * @throws IOException if errors occur while reading the data
134     */ 
135    public CARTImpl(URL url) throws IOException {
136        BufferedReader reader;
137        String line;
138
139        reader = new BufferedReader(new InputStreamReader(url.openStream()));
140        line = reader.readLine();
141        while (line != null) {
142            if (!line.startsWith("***")) {
143                parseAndAdd(line);
144            }
145            line = reader.readLine();
146        }
147        reader.close();
148    }
149
150    /**
151     * Creates a new CART by reading from the given reader.
152     *
153     * @param reader the source of the CART data
154     * @param nodes the number of nodes to read for this cart
155     *
156     * @throws IOException if errors occur while reading the data
157     */ 
158    public CARTImpl(BufferedReader reader, int nodes) throws IOException {
159        this(nodes);
160        String line;
161        for (int i = 0; i < nodes; i++) {
162            line = reader.readLine();
163            if (!line.startsWith("***")) {
164                parseAndAdd(line);
165            }
166        }
167    }
168    
169    /**
170     * Creates a new CART that will be populated with nodes later.
171     *
172     * @param numNodes the number of nodes
173     */
174    private CARTImpl(int numNodes) {
175        cart = new Node[numNodes];
176    }
177
178    /**
179     * Dumps this CART to the output stream.
180     *
181     * @param os the output stream
182     *
183     * @throws IOException if an error occurs during output
184     */
185    public void dumpBinary(DataOutputStream os) throws IOException {
186        os.writeInt(cart.length);
187        for (int i = 0; i < cart.length; i++) {
188            cart[i].dumpBinary(os);
189        }
190    }
191
192    /**
193     * Dump the CART tree as a dot file.
194     * <p>
195     * The dot tool is part of the graphviz distribution at
196     * <a href="http://www.graphviz.org/">http://www.graphviz.org/</a>.
197     * If installed, call it as "dot -O -Tpdf *.dot" from the console to
198     * generate pdfs.
199     * </p>
200     * @param out The PrintWriter to write to.
201     */
202    public void dumpDot(PrintWriter out) {
203        out.write("digraph \"" + "CART Tree" + "\" {\n");
204        out.write("rankdir = LR\n");
205
206        for (Node n : cart) {
207            out.println("\tnode" + Math.abs(n.hashCode()) + " [ label=\""
208                    + n.toString() + "\", color=" + dumpDotNodeColor(n)
209                    + ", shape=" + dumpDotNodeShape(n) + " ]\n");
210            if (n instanceof DecisionNode) {
211                DecisionNode dn = (DecisionNode) n;
212                if (dn.qtrue < cart.length && cart[dn.qtrue] != null) {
213                    out.write("\tnode" + Math.abs(n.hashCode()) + " -> node"
214                            + Math.abs(cart[dn.qtrue].hashCode()) + " [ label="
215                            + "TRUE" + " ]\n");
216                }
217                if (dn.qfalse < cart.length && cart[dn.qfalse] != null) {
218                    out.write("\tnode" + Math.abs(n.hashCode()) + " -> node"
219                            + Math.abs(cart[dn.qfalse].hashCode())
220                            + " [ label=" + "FALSE" + " ]\n");
221                }
222            }
223        }
224
225        out.write("}\n");
226        out.close();
227    }
228
229    protected String dumpDotNodeColor(Node n) {
230        if (n instanceof LeafNode) {
231            return "green";
232        }
233        return "red";
234    }
235
236    protected String dumpDotNodeShape(Node n) {
237        return "box";
238    }
239
240    /**
241     * Loads a CART from the input byte buffer.
242     *
243     * @param bb the byte buffer
244     *
245     * @return the CART
246     *
247     * @throws IOException if an error occurs during output
248     *
249     * Note that cart nodes are really saved as strings that
250     * have to be parsed.
251     */
252    public static CART loadBinary(ByteBuffer bb) throws IOException {
253        int numNodes = bb.getInt();
254        CARTImpl cart = new CARTImpl(numNodes);
255
256        for (int i = 0; i < numNodes; i++) {
257            String nodeCreationLine = Utilities.getString(bb);
258            cart.parseAndAdd(nodeCreationLine);
259        }
260        return cart;
261    }
262
263    /**
264     * Loads a CART from the input stream.
265     *
266     * @param is the input stream
267     *
268     * @return the CART
269     *
270     * @throws IOException if an error occurs during output
271     *
272     * Note that cart nodes are really saved as strings that
273     * have to be parsed.
274     */
275    public static CART loadBinary(DataInputStream is) throws IOException {
276        int numNodes = is.readInt();
277        CARTImpl cart = new CARTImpl(numNodes);
278
279        for (int i = 0; i < numNodes; i++) {
280            String nodeCreationLine = Utilities.getString(is);
281            cart.parseAndAdd(nodeCreationLine);
282        }
283        return cart;
284    }
285    
286    /**
287     * Creates a node from the given input line and add it to the CART.
288     * It expects the TOTAL line to come before any of the nodes.
289     *
290     * @param line a line of input to parse
291     */
292    protected void parseAndAdd(String line) {
293        StringTokenizer tokenizer = new StringTokenizer(line," ");
294        String type = tokenizer.nextToken();        
295        if (type.equals(LEAF) || type.equals(NODE)) {
296            cart[curNode] = getNode(type, tokenizer, curNode);
297            cart[curNode].setCreationLine(line);
298            curNode++;
299        } else if (type.equals(TOTAL)) {
300            cart = new Node[Integer.parseInt(tokenizer.nextToken())];
301            curNode = 0;
302        } else {
303            throw new Error("Invalid CART type: " + type);
304        }
305    }
306
307    /**
308     * Gets the node based upon the type and tokenizer.
309     *
310     * @param type <code>NODE</code> or <code>LEAF</code>
311     * @param tokenizer the StringTokenizer containing the data to get
312     * @param currentNode the index of the current node we're looking at
313     *
314     * @return the node
315     */
316    protected Node getNode(String type,
317                           StringTokenizer tokenizer,
318                           int currentNode) {
319        if (type.equals(NODE)) {
320            String feature = tokenizer.nextToken();
321            String operand = tokenizer.nextToken();
322            Object value = parseValue(tokenizer.nextToken());
323            int qfalse = Integer.parseInt(tokenizer.nextToken());
324            if (operand.equals(OPERAND_MATCHES)) {
325                return new MatchingNode(feature,
326                                        value.toString(),
327                                        currentNode + 1,
328                                        qfalse);
329            } else {
330                return new ComparisonNode(feature,
331                                          value,
332                                          operand,
333                                          currentNode + 1,
334                                          qfalse);
335            }
336        } else if (type.equals(LEAF)) {
337            return new LeafNode(parseValue(tokenizer.nextToken()));
338        }
339
340        return null;
341    }
342
343    /**
344     * Coerces a string into a value.
345     *
346     * @param string of the form "type(value)"; for example, "Float(2.3)"
347     *
348     * @return the value
349     */
350    protected Object parseValue(String string) {
351        int openParen = string.indexOf("(");
352        String type = string.substring(0,openParen);
353        String value = string.substring(openParen + 1, string.length() - 1);
354        if (type.equals("String")) {
355            return value;
356        } else if (type.equals("Float")) {
357            return new Float(Float.parseFloat(value));
358        } else if (type.equals("Integer")) {
359            return new Integer(Integer.parseInt(value));
360        } else if (type.equals("List")) {
361            StringTokenizer tok = new StringTokenizer(value, ",");
362            int size = tok.countTokens();
363
364            int[] values = new int[size];
365            for (int i = 0; i < size; i++) {
366                float fval = Float.parseFloat(tok.nextToken());
367                values[i] = Math.round(fval);
368            }
369            return values;
370        } else {
371            throw new Error("Unknown type: " + type);
372        }
373    }
374    
375    /**
376     * Passes the given item through this CART and returns the
377     * interpretation.
378     *
379     * @param item the item to analyze
380     *
381     * @return the interpretation
382     */
383    public Object interpret(Item item) {
384        int nodeIndex = 0;
385        DecisionNode decision;
386
387        while (!(cart[nodeIndex] instanceof LeafNode)) {
388            decision = (DecisionNode) cart[nodeIndex];
389            nodeIndex = decision.getNextNode(item);
390        }
391        if (LOGGER.isLoggable(Level.FINER)) {
392            LOGGER.finer("LEAF " + cart[nodeIndex].getValue());
393        }
394        return ((LeafNode) cart[nodeIndex]).getValue();
395    }
396
397    /**
398     * A node for the CART.
399     */
400    static abstract class Node {
401        /**
402         * The value of this node.
403         */
404        protected Object value;
405        private String creationLine;
406
407        /**
408         * Create a new Node with the given value.
409         */
410        public Node(Object value) {
411            this.value = value;
412        }
413    
414        /**
415         * Get the value.
416         */
417        public Object getValue() {
418            return value;
419        }
420
421        /**
422         * Return a string representation of the type of the value.
423         */
424        public String getValueString() {
425            if (value == null) {
426                return "NULL()";
427            } else if (value instanceof String) {
428                return "String(" + value.toString() + ")";
429            } else if (value instanceof Float) {
430                return "Float(" + value.toString() + ")";
431            } else if (value instanceof Integer) {
432                return "Integer(" + value.toString() + ")";
433            } else {
434                return value.getClass().toString() + "(" + value.toString() + ")";
435            }
436        }    
437
438        /**
439         * sets the line of text used to create this node.
440         * @param line the creation line
441         */
442        public void setCreationLine(String line) {
443            creationLine = line;
444        }
445
446        /**
447         * Dumps the binary form of this node.
448         * @param os the output stream to output the node on
449         * @throws IOException if an IO error occurs
450         */
451        final public void dumpBinary(DataOutputStream os) throws IOException {
452            Utilities.outString(os, creationLine);
453        }
454    }
455
456    /**
457     * A decision node that determines the next Node to go to in the CART.
458     */
459    abstract static class DecisionNode extends Node {
460        /**
461         * The feature used to find a value from an Item.
462         */
463        private PathExtractor path;
464
465        /**
466         * Index of Node to go to if the comparison doesn't match.
467         */
468        protected int qfalse;
469
470        /**
471         * Index of Node to go to if the comparison matches.
472         */
473        protected int qtrue;
474
475        /**
476         * The feature used to find a value from an Item.
477         */
478        public String getFeature() {
479            return path.toString();
480        }
481
482
483        /**
484         * Find the feature associated with this DecisionNode
485         * and the given item
486         * @param item the item to start from
487         * @return the object representing the feature
488         */
489        public Object findFeature(Item item) {
490            return path.findFeature(item);
491        }
492
493
494        /**
495         * Returns the next node based upon the
496         * descision determined at this node
497         * @param item the current item.
498         * @return the index of the next node
499         */
500        public final int getNextNode(Item item) {
501            return getNextNode(findFeature(item));
502        }
503
504        /**
505         * Create a new DecisionNode.
506         * @param feature the string used to get a value from an Item
507         * @param value the value to compare to
508         * @param qtrue the Node index to go to if the comparison matches
509         * @param qfalse the Node machine index to go to upon no match
510         */
511        public DecisionNode(String feature,
512                            Object value,
513                            int qtrue,
514                            int qfalse) {
515            super(value);
516            this.path = new PathExtractorImpl(feature, true);
517            this.qtrue = qtrue;
518            this.qfalse = qfalse;
519        }
520    
521        /**
522         * Get the next Node to go to in the CART.  The return
523         * value is an index in the CART.
524         */
525        abstract public int getNextNode(Object val);
526    }
527
528    /**
529     * A decision Node that compares two values.
530     */
531    static class ComparisonNode extends DecisionNode {
532        /**
533         * LESS_THAN
534         */
535        final static String LESS_THAN = "<";
536    
537        /**
538         * EQUALS
539         */
540        final static String EQUALS = "=";
541    
542        /**
543         * GREATER_THAN
544         */
545        final static String GREATER_THAN = ">";
546    
547        /**
548         * The comparison type.  One of LESS_THAN, GREATER_THAN, or
549         *  EQUAL_TO.
550         */
551        String comparisonType;
552
553        /**
554         * Create a new ComparisonNode with the given values.
555         * @param feature the string used to get a value from an Item
556         * @param value the value to compare to
557         * @param comparisonType one of LESS_THAN, EQUAL_TO, or GREATER_THAN
558         * @param qtrue the Node index to go to if the comparison matches
559         * @param qfalse the Node index to go to upon no match
560         */
561        public ComparisonNode(String feature,
562                              Object value,
563                              String comparisonType,
564                              int qtrue,
565                              int qfalse) {
566            super(feature, value, qtrue, qfalse);
567            if (!comparisonType.equals(LESS_THAN)
568                && !comparisonType.equals(EQUALS)
569                && !comparisonType.equals(GREATER_THAN)) {
570                throw new Error("Invalid comparison type: " + comparisonType);
571            } else {
572                this.comparisonType = comparisonType;
573            }
574        }
575
576        /**
577         * Compare the given value and return the appropriate Node index.
578         * IMPLEMENTATION NOTE:  LESS_THAN and GREATER_THAN, the Node's
579         * value and the value passed in are converted to floating point
580         * values.  For EQUAL, the Node's value and the value passed in
581         * are treated as String compares.  This is the way of Flite, so
582         * be it Flite.
583         * @param val the value to compare
584         */
585        public int getNextNode(Object val) {
586            boolean yes = false;
587            int ret;
588
589            if (comparisonType.equals(LESS_THAN)
590                || comparisonType.equals(GREATER_THAN)) {
591                float cart_fval;
592                float fval;
593                if (value instanceof Float) {
594                    cart_fval = ((Float) value).floatValue();
595                } else {
596                    cart_fval = Float.parseFloat(value.toString());
597                }
598                if (val instanceof Float) {
599                    fval = ((Float) val).floatValue();
600                } else {
601                    fval = Float.parseFloat(val.toString());
602                }
603                if (comparisonType.equals(LESS_THAN)) {
604                    yes = (fval < cart_fval);
605                } else {
606                    yes =  (fval > cart_fval);
607                }
608            } else { // comparisonType = "="
609                String sval = val.toString();
610                String cart_sval = value.toString();
611                yes = sval.equals(cart_sval);
612            }
613            if (yes) {
614                ret = qtrue;
615            } else {
616                ret = qfalse;
617            }
618
619            if (LOGGER.isLoggable(Level.FINER)) {
620                LOGGER.finer(trace(val, yes, ret));
621            }
622
623            return ret;
624        }
625
626        private String trace(Object value, boolean match, int next) {
627            return
628                "NODE " + getFeature() + " ["
629                + value + "] " 
630                + comparisonType + " [" 
631                + getValue() + "] "
632                + (match ? "Yes" : "No") + " next " +
633                    next;
634        }
635
636        /**
637         * Get a string representation of this Node.
638         */
639        public String toString() {
640            return
641                "NODE " + getFeature() + " "
642                + comparisonType + " "
643                + getValueString() + " "
644                + Integer.toString(qtrue) + " "
645                + Integer.toString(qfalse);
646        }
647    }
648
649    /**
650     * A Node that checks for a regular expression match.
651     */
652    static class MatchingNode extends DecisionNode {
653        Pattern pattern;
654    
655        /**
656         * Create a new MatchingNode with the given values.
657         * @param feature the string used to get a value from an Item
658         * @param regex the regular expression
659         * @param qtrue the Node index to go to if the comparison matches
660         * @param qfalse the Node index to go to upon no match
661         */
662        public MatchingNode(String feature,
663                            String regex,
664                            int qtrue,
665                            int qfalse) {
666            super(feature, regex, qtrue, qfalse);
667            this.pattern = Pattern.compile(regex);
668        }
669
670        /**
671         * Compare the given value and return the appropriate CART index.
672         * @param val the value to compare -- this must be a String
673         */
674        public int getNextNode(Object val) {
675            return pattern.matcher((String) val).matches()
676                ? qtrue
677                : qfalse;
678        }
679
680        /**
681         * Get a string representation of this Node.
682         */
683        public String toString() {
684            StringBuffer buf = new StringBuffer(
685                NODE + " " + getFeature() + " " + OPERAND_MATCHES);
686            buf.append(getValueString() + " ");
687            buf.append(Integer.toString(qtrue) + " ");
688            buf.append(Integer.toString(qfalse));
689            return buf.toString();
690        }
691    }
692
693    /**
694     * The final Node of a CART.  This just a marker class.
695     */
696    static class LeafNode extends Node {
697        /**
698         * Create a new LeafNode with the given value.
699         * @param the value of this LeafNode
700         */
701        public LeafNode(Object value) {
702            super(value);
703        }
704
705        /**
706         * Get a string representation of this Node.
707         */
708        public String toString() {
709            return "LEAF " + getValueString();
710        }
711    }
712}
713