Easy-to-interpret classification trees for single-cell machine learning

Earlier, in How a perfect single-cell subpopulation random forest classifier is refined, we demonstrated the random forest approach, and we also showed that LASSO regression can be used for single-cell classification. Both machine learning algorithms make capable single-cell classifiers. We also tried a variety of other algorithms, for example an SVM single-cell classifier that performs no worse than LASSO.

Whether it is random forest, LASSO regression, or support vector machine, the resulting models are somewhat abstract and hard to explain intuitively. The decision tree model we introduce next is different.

Train a decision tree model

First, copy and paste the code from the previous post, How a perfect single-cell subpopulation random forest classifier is refined, to divide the single-cell expression matrix into a training set and a test set. Then simply install and load the rpart package and run its rpart function:
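Since that earlier post's split isn't reproduced here, a minimal sketch of such a train/test split looks like the following (simulated data stands in for the real expression matrix and cell labels; object names match the code below):

```r
set.seed(123)
# simulated stand-in for the real cells-x-genes expression matrix and labels
predictor_data <- matrix(rnorm(200 * 20), nrow = 200,
                         dimnames = list(paste0("cell", 1:200),
                                         paste0("gene", 1:20)))
cell_labels <- factor(sample(c("B cells", "NK cells", "CD4 T cells"),
                             200, replace = TRUE))
# hold out 30% of the cells as a test set
n <- nrow(predictor_data)
train_idx <- sample(n, size = round(0.7 * n))
test_expr <- predictor_data[-train_idx, ]
test_y    <- cell_labels[-train_idx]
predictor_data <- predictor_data[train_idx, ]   # training matrix
target         <- cell_labels[train_idx]        # training labels
```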


library(rpart)

df = cbind(target = target, as.data.frame(predictor_data))
# build a "target ~ gene1 + gene2 + ..." formula string; the separator must
# be '+' (not '-') for a valid model formula
multivariate <- paste(colnames(df)[-1], collapse = ' + ')
s <- paste0('target ~ ', multivariate)
s # as.formula(s)
# target ~ . is equivalent to the formula built above
fit <- rpart(target ~ .,
             data = df, method = "class")

save(fit, file = 'rpart_output.Rdata')


It can be seen that the expression matrix we input holds the expression of 2000 genes across nearly 2000 cells:

>  dim(predictor_data)
[1] 1977 2000
> predictor_data[1:4,1:4]
                    ISG15     CPSF3L     MRPL20      ATAD3C
AAACATACAACCAC -0.8353028 -0.2741145  1.5056616 -0.04970561
AAACATTGAGCTAC -0.8353028 -0.2741145 -0.5625993 -0.04970561
AAACATTGATCAGC  0.3922351 -0.2741145  1.2450489 -0.04970561
AAACCGTGCTTCCG  2.2197621 -0.2741145 -0.5625993 -0.04970561

Our decision tree model combines these 2000 genes to split the cells into their classes. Let's visualize the fitted tree:

library(rpart.plot)
rpart.plot(fit, branch=1, branch.type=2, type=2, extra=102,  
           shadow.col="gray", box.col="green",  
           border.col="blue", split.col="red",  
           split.cex=1.2 )
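If rpart.plot is not available, base R can draw a rough version of a classification tree with plot() and text(). A sketch on the built-in iris data (the single-cell fit object isn't reproduced here):

```r
library(rpart)  # ships with standard R installations
fit_demo <- rpart(Species ~ ., data = iris, method = "class")
plot(fit_demo, branch = 1, margin = 0.1)  # draw the tree skeleton
text(fit_demo, use.n = TRUE, cex = 0.8)   # label splits and leaf class counts
```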


As shown below, although there are 6 single-cell subpopulations, only 5 genes are enough to distinguish them:

5 genes are enough to tell them apart

Checking the model on the test set

Similarly, the trained model also needs to be evaluated on another dataset:

# for method = "class", predict() returns per-class probabilities for each cell
test_outputs = predict(fit, as.data.frame(test_expr))
head( test_outputs ) 
# take the class with the highest probability as the predicted label
pred_y = colnames(test_outputs)[apply(test_outputs, 1, which.max)]
pred_y = factor(pred_y,levels = levels(test_y))

pdf('rpart-performance.pdf',width = 10)
# (draw the performance figure here, then close the device with dev.off())
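The probability-to-label step and a quick accuracy check can be sketched on built-in data (iris here, since the single-cell test set isn't reproduced in this post):

```r
library(rpart)
set.seed(42)
# train on 100 flowers, test on the remaining 50
idx <- sample(nrow(iris), 100)
fit_demo <- rpart(Species ~ ., data = iris[idx, ], method = "class")
# for method = "class", predict() returns a matrix of class probabilities
probs <- predict(fit_demo, iris[-idx, ])
# the same which.max trick as above turns probabilities into labels
pred <- colnames(probs)[apply(probs, 1, which.max)]
pred <- factor(pred, levels = levels(iris$Species))
# a confusion matrix makes the mixing between classes explicit
cm <- table(predicted = pred, truth = iris$Species[-idx])
accuracy <- sum(diag(cm)) / sum(cm)
```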

It can be seen that the problem is still the mixing of CD8 T cells with NK cells, as well as of CD4 with CD8 T cells, which remains unsolved for now:

Mixing of CD8 and NK cells

We can simply visualize the 5 genes used by the decision tree model:

library(Seurat)
sce=CreateSeuratObject(counts = t(predictor_data) ) # Seurat expects genes x cells
Idents(sce) = target
# the 5 split genes picked by the decision tree
cg = c('FTL','HLA-DRA','NKG7','FCER1G','CST3')
VlnPlot(sce,features = cg , 
        stack = T) 
ggplot2::ggsave('rpart_models_VlnPlot.pdf' )


As follows:

The 5 Genes of a Decision Tree Model

Comparing against the decision tree above:

  • It is indeed the FTL gene that separates monocytes from the other cells
  • HLA-DRA then separates B cells and dendritic cells from the rest, and CST3 tells B cells and dendritic cells apart
  • Among the T cells, NKG7 splits off the CD4 cells, and FCER1G then separates CD8 T cells from NK cells
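The split genes don't have to be read off the plot by hand: rpart records the variables it actually used in fit$frame$var. A sketch on the built-in iris data (the single-cell fit object isn't available here):

```r
library(rpart)
fit_demo <- rpart(Species ~ ., data = iris, method = "class")
# non-leaf rows of fit$frame hold the splitting variables; "<leaf>" marks leaves
used <- setdiff(unique(as.character(fit_demo$frame$var)), "<leaf>")
used  # the handful of variables the tree needs
```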

Compared with the earlier random forest, LASSO regression, and support vector machine, such a model is very easy to explain.

Tags: Deep Learning Machine Learning neural networks AI Decision Tree

Posted by allspiritseve on Tue, 28 Feb 2023 14:27:35 +1030