predict

Predict labels using classification tree model

    Description

    label = predict(tree,X) returns a vector of predicted class labels for the predictor data in the table or matrix X, based on the trained classification tree tree.

    label = predict(tree,X,Subtrees=subtrees) also prunes tree to the level specified by subtrees, before predicting labels.

    [label,score,node,cnum] = predict(___) also returns the following, using any of the input argument combinations in the previous syntaxes:

    • A matrix of classification scores (score) indicating the likelihood that a label comes from a particular class. For classification trees, scores are posterior probabilities. For each observation in X, the predicted class label corresponds to the minimum expected misclassification cost among all classes.

    • A vector of predicted node numbers for the classification (node).

    • A vector of predicted class numbers for the classification (cnum).
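    The four outputs can be related to one another with a short sketch. It assumes the fisheriris data set that ships with Statistics and Machine Learning Toolbox and the default misclassification cost; the variable names Mdl, labelFromCnum, and idxMax are chosen here for illustration only.

```matlab
load fisheriris                         % provides meas (150-by-4) and species (150-by-1 cell)
Mdl = fitctree(meas,species);           % train on all rows (illustration only)

X = meas([1 51 101],:);                 % three rows to classify
[label,score,node,cnum] = predict(Mdl,X);

% cnum indexes into Mdl.ClassNames, so this should reproduce label.
labelFromCnum = Mdl.ClassNames(cnum);

% Assumption: with the default 0-1 misclassification cost, the class with
% the minimum expected cost is the class with the largest posterior.
[~,idxMax] = max(score,[],2);
```

    Comparing label with labelFromCnum, and idxMax with cnum, is a quick sanity check that the outputs are being read in the order given by Mdl.ClassNames.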

    Examples

    Examine predictions for a few rows in a data set left out of training.

    Load Fisher's iris data set.

    load fisheriris

    Partition the data into training (50%) and validation (50%) sets.

    n = size(meas,1);
    rng(1) % For reproducibility
    idxTrn = false(n,1);
    idxTrn(randsample(n,round(0.5*n))) = true;
    idxVal = idxTrn == false;                 

    Grow a classification tree using the training set.

    Mdl = fitctree(meas(idxTrn,:),species(idxTrn));

    Predict labels for the validation data, and display several predicted labels. Count the number of misclassified observations.

    label = predict(Mdl,meas(idxVal,:));
    label(randsample(numel(label),5))
    ans = 5×1 cell
        {'setosa'    }
        {'setosa'    }
        {'setosa'    }
        {'virginica' }
        {'versicolor'}
    
    
    numMisclass = sum(~strcmp(label,species(idxVal)))
    numMisclass = 
    3
    

    The software misclassifies three out-of-sample observations.
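    To see which classes the errors fall in, one option is a confusion matrix. The sketch below repeats the setup above so that it runs on its own; the use of confusionmat and the names C and misclassRate are additions for illustration, not part of the original example.

```matlab
load fisheriris
n = size(meas,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true;
idxVal = ~idxTrn;

Mdl = fitctree(meas(idxTrn,:),species(idxTrn));
label = predict(Mdl,meas(idxVal,:));

% Rows are true classes and columns are predicted classes, both in the
% order of Mdl.ClassNames.
C = confusionmat(species(idxVal),label,Order=Mdl.ClassNames);

% Off-diagonal entries are the misclassified observations.
numMisclass = sum(C,"all") - trace(C);
misclassRate = numMisclass/sum(idxVal);
```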

    Load Fisher's iris data set.

    load fisheriris

    Partition the data into training (50%) and validation (50%) sets.

    n = size(meas,1);
    rng(1) % For reproducibility
    idxTrn = false(n,1);
    idxTrn(randsample(n,round(0.5*n))) = true;
    idxVal = idxTrn == false;

    Grow a classification tree using the training set, and then view it.

    Mdl = fitctree(meas(idxTrn,:),species(idxTrn));
    view(Mdl,"Mode","graph")

    [Figure: Classification tree viewer showing the trained tree.]

    The resulting tree has four levels.

    Estimate posterior probabilities for the validation set using subtrees pruned to levels 1 and 3. Display several posterior probabilities.

    [~,Posterior] = predict(Mdl,meas(idxVal,:), ...
        Subtrees=[1 3]);
    Mdl.ClassNames
    ans = 3×1 cell
        {'setosa'    }
        {'versicolor'}
        {'virginica' }
    
    
    Posterior(randsample(size(Posterior,1),5),:,:)
    ans = 
    ans(:,:,1) =
    
        1.0000         0         0
        1.0000         0         0
        1.0000         0         0
             0         0    1.0000
             0    0.8571    0.1429
    
    
    ans(:,:,2) =
    
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
        0.3733    0.3200    0.3067
    
    

    The elements of Posterior are class posterior probabilities:

    • Rows correspond to observations in the validation set.

    • Columns correspond to the classes as listed in Mdl.ClassNames.

    • Pages correspond to the subtrees.

    The subtree pruned to level 1 is more sure of its predictions than the subtree pruned to level 3 (that is, the root node).
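    The identical rows on the second page can be checked against the training data. The sketch below assumes the default empirical class prior of fitctree, in which case the root-node posterior should equal the class proportions in the training set; trainFrac and rootPosterior are illustrative names.

```matlab
load fisheriris
n = size(meas,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true;
Mdl = fitctree(meas(idxTrn,:),species(idxTrn));

% Fraction of training observations in each class, as a row vector in
% the order of Mdl.ClassNames.
trainFrac = cellfun(@(c) mean(strcmp(species(idxTrn),c)),Mdl.ClassNames)';

% Posterior probabilities from the most heavily pruned subtree (root only).
rootLevel = max(Mdl.PruneList);
[~,rootPosterior] = predict(Mdl,meas(~idxTrn,:),Subtrees=rootLevel);

% Largest absolute difference between any row of rootPosterior and trainFrac.
maxDiff = max(abs(rootPosterior - trainFrac),[],"all");
```

    Because every observation ends at the root node in this subtree, each row of rootPosterior is the same, which is why the level-3 page carries no information about individual observations.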

    Input Arguments

    Trained classification tree (tree), specified as a ClassificationTree model object trained with fitctree, or a CompactClassificationTree model object created with compact.

    Predictor data to be classified (X), specified as a numeric matrix or a table.

    Each row of X corresponds to one observation, and each column corresponds to one variable.

    For a numeric matrix:

    • The variables that make up the columns of X must have the same order as the predictor variables used to train tree.

    • If you train tree using a table (for example, Tbl), then X can be a numeric matrix if Tbl contains all numeric predictor variables. To treat numeric predictors in Tbl as categorical during training, identify categorical predictors using the CategoricalPredictors name-value argument of fitctree. If Tbl contains heterogeneous predictor variables (for example, numeric and categorical data types) and X is a numeric matrix, then predict issues an error.

    For a table:

    • predict does not support multicolumn variables or cell arrays other than cell arrays of character vectors.

    • If you train tree using a table (for example, Tbl), then all predictor variables in X must have the same variable names and data types as those used to train tree (stored in tree.PredictorNames). However, the column order of X does not need to correspond to the column order of Tbl. Tbl and X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.

    • If you train tree using a numeric matrix, then the predictor names in tree.PredictorNames and corresponding predictor variable names in X must be the same. To specify predictor names during training, use the PredictorNames name-value argument of fitctree. All predictor variables in X must be numeric vectors. X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.

    Data Types: table | double | single
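    The table and matrix rules above can be illustrated with a small sketch. The predictor names below are hypothetical labels for the four fisheriris measurements, and Tbl, Xtbl, labelTbl, and labelMat are illustrative variable names.

```matlab
load fisheriris
names = ["SepalLength","SepalWidth","PetalLength","PetalWidth"]; % hypothetical names
Tbl = array2table(meas,VariableNames=names);
Tbl.Species = species;
Mdl = fitctree(Tbl,"Species");

% Table input: the column order can differ from training, and extra
% variables (here, the response Species) are ignored.
Xtbl = Tbl(1:5,["Species","PetalWidth","PetalLength","SepalWidth","SepalLength"]);
labelTbl = predict(Mdl,Xtbl);

% Numeric matrix input: accepted here because every predictor in Tbl is
% numeric, but the columns must follow the order in Mdl.PredictorNames.
labelMat = predict(Mdl,meas(1:5,:));
```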

    Pruning level (subtrees), specified as a vector of nonnegative integers in ascending order or "all".

    If you specify a vector, then all elements must be at least 0 and at most max(tree.PruneList). 0 indicates the full, unpruned tree, and max(tree.PruneList) indicates the completely pruned tree (that is, just the root node).

    If you specify "all", then predict operates on all subtrees (that is, the entire pruning sequence). This specification is equivalent to using 0:max(tree.PruneList).

    predict prunes tree to each level specified by subtrees, and then estimates the corresponding output arguments. The size of subtrees determines the size of some output arguments.

    For the function to invoke subtrees, the properties PruneList and PruneAlpha of tree must be nonempty. In other words, grow tree by setting Prune="on" when you use fitctree, or by pruning tree using prune.

    Example: Subtrees="all"

    Data Types: single | double | char | string
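    As a brief sketch of the equivalence described above, the following compares Subtrees="all" with the explicit vector 0:max(tree.PruneList). It assumes a tree grown with the fitctree defaults, so that the pruning sequence is available; maxLevel, labelAll, and labelVec are illustrative names.

```matlab
load fisheriris
Mdl = fitctree(meas,species);   % assumes the default pruning settings of fitctree

maxLevel = max(Mdl.PruneList);  % level of the completely pruned tree (root only)

% One column of labels per pruning level, from the full tree to the root.
labelAll = predict(Mdl,meas,Subtrees="all");
labelVec = predict(Mdl,meas,Subtrees=0:maxLevel);

sameLabels = isequal(labelAll,labelVec);
```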

    Output Arguments

    Predicted class labels (label), returned as a categorical or character array, logical or numeric vector, or cell array of character vectors. Each entry of label corresponds to the class with the minimal expected cost for the corresponding row of X.

    Suppose subtrees is a numeric vector containing T elements, and X has N rows.

    • If the response data type is char and T = 1, then label is a character matrix containing N rows. Each row contains the predicted label produced by subtrees.

    • If the response data type is char and T > 1, then label is an N-by-T cell array. Column j of label contains the vector of predicted labels produced by subtree subtrees(j).

    • Otherwise, label is an N-by-T array that has the same data type as the response. Column j of label contains the vector of predicted labels produced by subtree subtrees(j). (The software treats string arrays as cell arrays of character vectors.)

    Posterior probabilities (score), returned as a numeric matrix of size N-by-K, where N is the number of observations (rows) in X, and K is the number of classes (in tree.ClassNames). score(i,j) is the posterior probability that row i in X is of class j in tree.ClassNames.

    If subtrees has T elements, and X has N rows, then score is an N-by-K-by-T array, and node and cnum are N-by-T matrices.

    Node numbers for the predicted classes (node), returned as a numeric vector. Each entry corresponds to the predicted node in tree for the corresponding row of X.

    Class numbers corresponding to the predicted labels (cnum), returned as a numeric vector. Each entry of cnum corresponds to the predicted class number for the corresponding row of X.
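    The output sizes described above can be inspected directly. This is a minimal sketch that assumes the fisheriris data and a tree whose pruning sequence has at least two levels; N, K, T, and levels are illustrative names matching the notation in the text.

```matlab
load fisheriris
Mdl = fitctree(meas,species);

X = meas(1:10,:);               % N = 10 observations
levels = [0 1];                 % T = 2 pruning levels
N = size(X,1);
K = numel(Mdl.ClassNames);      % number of classes
T = numel(levels);

[label,score,node,cnum] = predict(Mdl,X,Subtrees=levels);

% Per the descriptions above: score is N-by-K-by-T, and label, node,
% and cnum are N-by-T.
sizeScore = size(score);
sizeLabel = size(label);
sizeNode  = size(node);
sizeCnum  = size(cnum);
```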

    More About

    Algorithms

    predict generates predictions by following the branches of tree until it reaches a leaf node or a missing value. If predict reaches a leaf node, it returns the classification of that node.

    If predict reaches a node with a missing value for a predictor, its behavior depends on the setting of the Surrogate name-value argument when fitctree constructs tree.

    • Surrogate = "off" (default): predict returns the label with the largest number of training samples that reach the node.

    • Surrogate = "on": predict uses the best surrogate split at the node. If all surrogate split variables with a positive predictive measure of association are missing, predict returns the label with the largest number of training samples that reach the node. For a definition, see Predictive Measure of Association.
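    The traversal described above can be imitated by hand using the tree's node properties. The following is a simplified sketch, not the toolbox implementation: it assumes numeric predictors only, Surrogate="off", the default misclassification cost, and that observations with a value below the cut point go to the left child. The names x, manualLabel, predLabel, and predNode are illustrative.

```matlab
load fisheriris
Mdl = fitctree(meas,species);           % Surrogate is "off" unless specified

x = meas(120,:);                        % one observation to classify

node = 1;                               % start at the root node
while Mdl.IsBranchNode(node)
    j = Mdl.CutPredictorIndex(node);    % index of the predictor this node splits on
    if isnan(x(j))
        break                           % missing value: stop at the current node
    end
    if x(j) < Mdl.CutPoint(node)
        node = Mdl.Children(node,1);    % follow the left branch
    else
        node = Mdl.Children(node,2);    % follow the right branch
    end
end
manualLabel = Mdl.NodeClass(node);      % class assigned to the node reached

% Compare with the label and node number that predict returns.
[predLabel,~,predNode] = predict(Mdl,x);
```

    Setting one element of x to NaN before the loop is a simple way to explore the missing-value case: the loop stops at the first node that splits on the missing predictor.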

    Alternative Functionality

    Simulink Block

    To integrate the prediction of a classification tree model into Simulink®, you can use the ClassificationTree Predict block in the Statistics and Machine Learning Toolbox™ library or a MATLAB® Function block with the predict function. For examples, see Predict Class Labels Using ClassificationTree Predict Block and Predict Class Labels Using MATLAB Function Block.

    When deciding which approach to use, consider the following:

    • If you use the Statistics and Machine Learning Toolbox library block, you can use the Fixed-Point Tool (Fixed-Point Designer) to convert a floating-point model to fixed point.

    • Support for variable-size arrays must be enabled for a MATLAB Function block with the predict function.

    • If you use a MATLAB Function block, you can use MATLAB functions for preprocessing or post-processing before or after predictions in the same MATLAB Function block.

    Version History

    Introduced in R2011a