Classification and regression using an oblique decision tree (ODT) in which each node is split by a linear combination of predictors. Different methods are provided for selecting the linear combinations, while the splitting values are chosen by one of three criteria.

Usage

ODT(X, ...)

# S3 method for formula
ODT(
  formula,
  data = NULL,
  Xsplit = NULL,
  split = "auto",
  lambda = "log",
  NodeRotateFun = "RotMatPPO",
  FunDir = getwd(),
  paramList = NULL,
  glmnetParList = NULL,
  MaxDepth = Inf,
  numNode = Inf,
  MinLeaf = 10,
  Levels = NULL,
  subset = NULL,
  weights = NULL,
  na.action = na.fail,
  catLabel = NULL,
  Xcat = 0,
  Xscale = "Min-max",
  TreeRandRotate = FALSE,
  ...
)

# S3 method for default
ODT(
  X,
  y,
  Xsplit = NULL,
  split = "auto",
  lambda = "log",
  NodeRotateFun = "RotMatPPO",
  FunDir = getwd(),
  paramList = NULL,
  glmnetParList = NULL,
  MaxDepth = Inf,
  numNode = Inf,
  MinLeaf = 10,
  Levels = NULL,
  subset = NULL,
  weights = NULL,
  na.action = na.fail,
  catLabel = NULL,
  Xcat = 0,
  Xscale = "Min-max",
  TreeRandRotate = FALSE,
  ...
)

Arguments

X

An n by d numeric matrix (preferred) or data frame.

...

Optional parameters to be passed to the low level function.

formula

Object of class formula with a response describing the model to fit. If this is a data frame, it is taken as the model frame (see model.frame).

data

Training data of class data.frame containing variables named in the formula. If data is missing it is obtained from the current environment by formula.

Xsplit

Splitting variables used to construct linear model trees (default NULL). This argument is only used when split="linear".

split

The criterion used for splitting the nodes. For classification: "entropy" (information gain) or "gini" (Gini impurity index). For regression: "mse" (mean squared error). For linear model trees: "linear" (mean squared error of the linear model). "auto" (default): if the response in data or y is a factor, "gini" is used; otherwise "mse" is used.
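As a rough illustration of the quantities these criteria are based on, the sketch below computes Gini impurity, entropy and mean squared error for a single node in base R. The helper names are hypothetical and are not part of the package; the package's internal formulas may differ, for example in the log base or in how child nodes are weighted.

```r
# Illustrative helpers (hypothetical names, not part of the package API).
gini_impurity <- function(y) {
  p <- table(y) / length(y)   # class proportions in the node
  1 - sum(p^2)
}

entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]               # drop empty classes to avoid log(0)
  -sum(p * log2(p))           # log base 2 is an assumption made here
}

mse <- function(y) mean((y - mean(y))^2)

y_class <- factor(c("a", "a", "b", "b"))
gini_impurity(y_class)  # 0.5
entropy(y_class)        # 1
mse(c(1, 2, 3))         # 0.6666667
```

A split is then typically scored by how much it reduces the chosen quantity from the parent node to its child nodes.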

lambda

Penalty level applied to the partition criterion selected by split. Three preset options are provided: lambda=0: no penalty; lambda=2: AIC penalty; lambda="log" (default): BIC penalty. In addition, lambda can be any value from 0 to n (the training set size).
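In the usual AIC and BIC definitions the complexity term is multiplied by 2 and log(n) respectively, so lambda="log" presumably corresponds to a coefficient of log(n). The sketch below only shows that mapping; resolve_lambda is a hypothetical helper, and how the coefficient enters the split criterion inside the package is not shown here.

```r
# Hypothetical helper: translate the lambda argument into a numeric penalty
# coefficient for a training set of size n. Not the package's internal code.
resolve_lambda <- function(lambda, n) {
  if (identical(lambda, "log")) {
    return(log(n))                     # BIC-style coefficient (assumption)
  }
  lambda <- as.numeric(lambda)
  stopifnot(lambda >= 0, lambda <= n)  # documented range: 0 to n
  lambda
}

resolve_lambda(0, n = 100)      # 0 (no penalty)
resolve_lambda(2, n = 100)      # 2 (AIC-style)
resolve_lambda("log", n = 100)  # log(100), about 4.605
```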

NodeRotateFun

Name (a character string) of the function that builds the linear combination of predictors at each split node. Options include:

  • "RotMatPPO": projection pursuit optimization model (PPO), see RotMatPPO (default, model="PPR").

  • "RotMatRF": single feature similar to CART, see RotMatRF.

  • "RotMatRand": random rotation, see RotMatRand.

  • "RotMatMake": users can define this function, for details see RotMatMake.

FunDir

The path to the user-defined NodeRotateFun function (default: current working directory).

paramList

List of parameters used by the function NodeRotateFun. If left unchanged, default values are used; for details see defaults.

glmnetParList

List of parameters used by the functions glmnet and cv.glmnet in package glmnet. glmnetParList=list(lambda = 0) gives ordinary least squares (OLS) regression; glmnetParList=list(family = "gaussian") (default) fits a regression model; glmnetParList=list(family = "binomial") or glmnetParList=list(family = "multinomial") fits a classification model. If left unchanged, default values are used; for details see glmnet and cv.glmnet.

MaxDepth

The maximum depth of the tree (default Inf).

numNode

Number of nodes that can be used by the tree (default Inf).

MinLeaf

Minimum leaf node size (default 10).

Levels

The category labels of the response variable, used when split is not "mse".

subset

An index vector indicating which rows should be used. (NOTE: If given, this argument must be named.)

weights

Vector of non-negative observational weights; fractional weights are allowed (default NULL).

na.action

A function to specify the action to be taken if NAs are found. (NOTE: If given, this argument must be named.)

catLabel

A list of category labels for the categorical predictors (default NULL; for details see Examples).

Xcat

A vector indicating which predictors are categorical variables. The default Xcat=0 means that categorical variables receive no special treatment. When Xcat=NULL, a predictor x that satisfies the condition "(length(table(x))<10) & (length(x)>20)" is treated as a categorical variable.
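To make the Xcat=NULL rule concrete, the sketch below applies the quoted condition to each column of a small data frame. The helper is_categorical and the toy data are made up for illustration and are not part of the package.

```r
# Hypothetical helper applying the condition quoted above to one predictor.
is_categorical <- function(x) (length(table(x)) < 10) & (length(x) > 20)

set.seed(1)
X <- data.frame(
  grp = sample(c("A", "B", "C"), 30, replace = TRUE),  # at most 3 distinct values
  num = rnorm(30)                                      # 30 distinct values
)

# Column indices that would be judged categorical under this rule.
which(vapply(X, is_categorical, logical(1)))
```

Here only grp (column 1) is flagged: it has fewer than 10 distinct values and more than 20 observations, whereas num has 30 distinct values.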

Xscale

Predictor standardization method. "Min-max" (default), "Quantile" and "No" denote Min-max transformation, Quantile transformation and no transformation, respectively.
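A minimal base R sketch of what these two transformations usually mean for a single predictor is given below. The function names are hypothetical, and reading the Quantile transformation as the empirical CDF is an assumption; the package's own implementation may differ in detail.

```r
# Hypothetical helpers, not part of the package API.
# Min-max: rescale to [0, 1] using the observed range.
min_max <- function(x) (x - min(x)) / (max(x) - min(x))

# Quantile: map each value to its empirical CDF value (assumed interpretation).
quantile_scale <- function(x) ecdf(x)(x)

x <- c(2, 4, 6, 10)
min_max(x)         # 0.00 0.25 0.50 1.00
quantile_scale(x)  # 0.25 0.50 0.75 1.00
```

Min-max scaling preserves the relative spacing of the values, while the quantile version depends only on their ranks and is therefore less sensitive to outliers.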

TreeRandRotate

Whether to randomly rotate the training data before building the tree (default FALSE, see RandRot).

y

A response vector of length n.

Value

An object of class ODT containing a list of components:

  • call: The original call to ODT.

  • terms: An object of class c("terms", "formula") (see terms.object) summarizing the formula. Used by various methods, but typically not of direct relevance to users.

  • split, Levels, NodeRotateFun: Key parameters used to build the tree.

  • predicted: the predicted values of the training data.

  • projections: Projection direction for each split node.

  • paramList: Parameters in a named list to be used by NodeRotateFun.

  • data: The list of data related parameters used to build the tree.

  • tree: The list of tree related parameters used to build the tree.

  • structure: A set of tree structure data records.

    • nodeRotaMat: Records the split variables (first column), split node serial number (second column) and rotation direction (third column) for each node. (A 0 in both the first and third columns indicates a leaf node.)

    • nodeNumLabel: Records each leaf node's category for classification or predicted value for regression (the second column is the data size). (All columns being 0 means the node is not a leaf node.)

    • nodeCutValue: Records the split point of each node. (0 indicates a leaf node.)

    • nodeCutIndex: Record the index values of the partitioning variables selected based on the partition criterion split.

    • childNode: Record the number of child nodes after each splitting.

    • nodeDepth: Record the depth of the tree where each node is located.

    • nodeIndex: Record the indices of the data used in each node.

    • glmnetFit: Record the model fitted by function glmnet used in each node.

References

Zhan, H., Liu, Y., & Xia, Y. (2022). Consistency of The Oblique Decision Tree and Its Random Forest. arXiv preprint arXiv:2211.12653.

Author

Yu Liu and Yingcun Xia

Examples

# Classification with Oblique Decision Tree.
data(seeds)
set.seed(221212)
train <- sample(1:209, 100)
train_data <- data.frame(seeds[train, ])
test_data <- data.frame(seeds[-train, ])
tree <- ODT(varieties_of_wheat ~ ., train_data, split = "entropy")
pred <- predict(tree, test_data[, -8])
# classification error
(mean(pred != test_data[, 8]))
#> [1] 0.09174312

# Regression with Oblique Decision Tree.
data(body_fat)
set.seed(221212)
train <- sample(1:252, 100)
train_data <- data.frame(body_fat[train, ])
test_data <- data.frame(body_fat[-train, ])
tree <- ODT(Density ~ ., train_data,
  split = "mse",
  NodeRotateFun = "RotMatPPO", paramList = list(model = "Log", dimProj = "Rand")
)
pred <- predict(tree, test_data[, -1])
# estimation error
mean((pred - test_data[, 1])^2)
#> [1] 0.0005496953

# Use "Z" as the splitting variable to build a linear model tree for "X" and "y".
set.seed(10)
cutpoint=50
X=matrix(rnorm(100*10),100,10)
age=sample(seq(20,80),100,replace = TRUE)
height=sample(seq(50,200),100,replace = TRUE)
weight=sample(seq(5,150),100,replace = TRUE)
Z=cbind(age=age,height=height,weight=weight)
mu=rep(0,100)
mu[age<=cutpoint]=X[age<=cutpoint,1]+X[age<=cutpoint,2]
mu[age>cutpoint]=X[age>cutpoint,1]+X[age>cutpoint,3]
y=mu+rnorm(100)
# Regression model tree
my.tree <- ODT(X=X, y=y, Xsplit=Z, split = "linear", lambda = 0,
NodeRotateFun = "RotMatRF",
glmnetParList=list(lambda = 0, family = "gaussian"))
pred <- predict(my.tree, X, Xsplit=Z)
# fitting error
mean((pred - y)^2)
#> [1] 0.9035932
mean((my.tree$predicted - y)^2)
#> [1] 0.9035932
# Classification model tree
y1 = (y>0)*1
my.tree <- ODT(X=X, y=y1, Xsplit=Z, split = "linear",lambda = 0,
               NodeRotateFun = "RotMatRF",MinLeaf = 10, MaxDepth = 5,
               glmnetParList=list(family = "binomial"))
#> Warning: one multinomial or binomial class has fewer than 8  observations; dangerous ground
(class <- predict(my.tree, X, Xsplit=Z, type="pred"))
#> Warning: number of rows of result is not a multiple of vector length (arg 1)
#>   [1] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
#>  [19] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
#>  [37] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
#>  [55] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
#>  [73] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
#>  [91] "0" "0" "0" "0" "0" "0" "0" "0" "0" "0"
(prob <- predict(my.tree, X, Xsplit=Z, type="prob"))
#> Warning: number of rows of result is not a multiple of vector length (arg 1)
#>        0 1
#>   [1,] 1 0
#>   [2,] 1 0
#>   [3,] 1 0
#>   [4,] 1 0
#>   [5,] 1 0
#>   [6,] 1 0
#>   [7,] 1 0
#>   [8,] 1 0
#>   [9,] 1 0
#>  [10,] 1 0
#>  [11,] 1 0
#>  [12,] 1 0
#>  [13,] 1 0
#>  [14,] 1 0
#>  [15,] 1 0
#>  [16,] 1 0
#>  [17,] 1 0
#>  [18,] 1 0
#>  [19,] 1 0
#>  [20,] 1 0
#>  [21,] 1 0
#>  [22,] 1 0
#>  [23,] 1 0
#>  [24,] 1 0
#>  [25,] 1 0
#>  [26,] 1 0
#>  [27,] 1 0
#>  [28,] 1 0
#>  [29,] 1 0
#>  [30,] 1 0
#>  [31,] 1 0
#>  [32,] 1 0
#>  [33,] 1 0
#>  [34,] 1 0
#>  [35,] 1 0
#>  [36,] 1 0
#>  [37,] 1 0
#>  [38,] 1 0
#>  [39,] 1 0
#>  [40,] 1 0
#>  [41,] 1 0
#>  [42,] 1 0
#>  [43,] 1 0
#>  [44,] 1 0
#>  [45,] 1 0
#>  [46,] 1 0
#>  [47,] 1 0
#>  [48,] 1 0
#>  [49,] 1 0
#>  [50,] 1 0
#>  [51,] 1 0
#>  [52,] 1 0
#>  [53,] 1 0
#>  [54,] 1 0
#>  [55,] 1 0
#>  [56,] 1 0
#>  [57,] 1 0
#>  [58,] 1 0
#>  [59,] 1 0
#>  [60,] 1 0
#>  [61,] 1 0
#>  [62,] 1 0
#>  [63,] 1 0
#>  [64,] 1 0
#>  [65,] 1 0
#>  [66,] 1 0
#>  [67,] 1 0
#>  [68,] 1 0
#>  [69,] 1 0
#>  [70,] 1 0
#>  [71,] 1 0
#>  [72,] 1 0
#>  [73,] 1 0
#>  [74,] 1 0
#>  [75,] 1 0
#>  [76,] 1 0
#>  [77,] 1 0
#>  [78,] 1 0
#>  [79,] 1 0
#>  [80,] 1 0
#>  [81,] 1 0
#>  [82,] 1 0
#>  [83,] 1 0
#>  [84,] 1 0
#>  [85,] 1 0
#>  [86,] 1 0
#>  [87,] 1 0
#>  [88,] 1 0
#>  [89,] 1 0
#>  [90,] 1 0
#>  [91,] 1 0
#>  [92,] 1 0
#>  [93,] 1 0
#>  [94,] 1 0
#>  [95,] 1 0
#>  [96,] 1 0
#>  [97,] 1 0
#>  [98,] 1 0
#>  [99,] 1 0
#> [100,] 1 0

# Projection analysis of the oblique decision tree.
data(iris)
tree <- ODT(Species ~ ., data = iris, split="gini",
            paramList = list(model = "PPR", numProj = 1))
print(round(tree[["projections"]],3))
#>       Sepal.Length Sepal.Width Petal.Length Petal.Width
#> proj1       -0.167      -0.215        0.763       0.587
#> proj2       -0.087      -0.223        0.676       0.697

### Train ODT on one-of-K encoded categorical data ###
# Note that the category variable must be placed at the beginning of the predictor X
# as in the following example.
set.seed(22)
Xcol1 <- sample(c("A", "B", "C"), 100, replace = TRUE)
Xcol2 <- sample(c("1", "2", "3", "4", "5"), 100, replace = TRUE)
Xcon <- matrix(rnorm(100 * 3), 100, 3)
X <- data.frame(Xcol1, Xcol2, Xcon)
Xcat <- c(1, 2)
catLabel <- NULL
y <- as.factor(sample(c(0, 1), 100, replace = TRUE))
tree <- ODT(X, y, split = "entropy", Xcat = NULL)
#> Warning: The categorical variable 1, 2 has been transformed into an one-of-K encode variables!
head(X)
#>   Xcol1 Xcol2          X1         X2          X3
#> 1     B     5 -0.04178453  2.3962339 -0.01443979
#> 2     A     4 -1.66084623 -0.4397486  0.57251733
#> 3     B     2 -0.57973333 -0.2878683  1.24475578
#> 4     B     1 -0.82075051  1.3702900  0.01716528
#> 5     C     5 -0.76337897 -0.9620213  0.25846351
#> 6     A     5 -0.37720294 -0.1853976  1.04872159

# one-of-K encode each categorical feature and store in X1
numCat <- apply(X[, Xcat, drop = FALSE], 2, function(x) length(unique(x)))
# initialize training data matrix X
X1 <- matrix(0, nrow = nrow(X), ncol = sum(numCat))
catLabel <- vector("list", length(Xcat))
names(catLabel) <- colnames(X)[Xcat]
col.idx <- 0L
# convert categorical feature to K dummy variables
for (j in seq_along(Xcat)) {
  catMap <- (col.idx + 1):(col.idx + numCat[j])
  catLabel[[j]] <- levels(as.factor(X[, Xcat[j]]))
  X1[, catMap] <- (matrix(X[, Xcat[j]], nrow(X), numCat[j]) ==
    matrix(catLabel[[j]], nrow(X), numCat[j], byrow = TRUE)) + 0
  col.idx <- col.idx + numCat[j]
}
X <- cbind(X1, X[, -Xcat])
colnames(X) <- c(paste(rep(seq_along(numCat), numCat), unlist(catLabel),
  sep = "."
), "X1", "X2", "X3")

# Print the result after processing of category variables.
head(X)
#>   1.A 1.B 1.C 2.1 2.2 2.3 2.4 2.5          X1         X2          X3
#> 1   0   1   0   0   0   0   0   1 -0.04178453  2.3962339 -0.01443979
#> 2   1   0   0   0   0   0   1   0 -1.66084623 -0.4397486  0.57251733
#> 3   0   1   0   0   1   0   0   0 -0.57973333 -0.2878683  1.24475578
#> 4   0   1   0   1   0   0   0   0 -0.82075051  1.3702900  0.01716528
#> 5   0   0   1   0   0   0   0   1 -0.76337897 -0.9620213  0.25846351
#> 6   1   0   0   0   0   0   0   1 -0.37720294 -0.1853976  1.04872159
catLabel
#> $Xcol1
#> [1] "A" "B" "C"
#> 
#> $Xcol2
#> [1] "1" "2" "3" "4" "5"
#> 

tree <- ODT(X, y, split = "gini", Xcat = c(1, 2), catLabel = catLabel,NodeRotateFun = "RotMatRF")