Random Forests For Image Classification
When we talk about image classification, the conversation usually turns quickly to convolutional neural networks (CNNs) and deep learning. These models have become the gold standard for interpreting visual data, and for good reason.
But what happens if we take a completely different route?
In this exploratory example, we'll tackle image classification with something unexpected: a Random Forest. While it is far from the ideal tool for this job, using it offers a great learning opportunity. It helps us understand how traditional machine learning methods behave with visual data and where their limits are.
We'll walk through:
- Why Random Forests are not typically used for image classification
- How to structure images as input for a Random Forest
- The surprising effectiveness of Random Forests on the MNIST digits dataset
- The less-than-impressive results on the CIFAR-10 dataset
- How we can improve the results on CIFAR-10
- What we can learn from this kind of experiment
Why a Random Forest for Image Classification?
Let's be clear. You wouldn't use a Random Forest to classify images in a serious production setting. CNNs dominate image classification tasks because they:
- Learn spatial hierarchies in the data
- Handle complex patterns like edges, textures, and shapes
- Scale well to large datasets and high-resolution images
On the other hand, Random Forests:
- Treat each pixel as an independent feature
- Do not understand local structure or spatial relationships
- Require manual flattening or feature extraction
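To make that last point concrete, here is a minimal sketch (using NumPy, with a random array standing in for a real image) of what "flattening" means in practice: every pixel simply becomes one more column in a tabular feature matrix.
import numpy as np
image = np.random.rand(28, 28)   # stand-in for a single 28x28 grayscale image
features = image.reshape(-1)     # flatten row by row into one feature vector
print(features.shape)            # (784,)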
So why use them at all?
Because they are simple, fast, and a great way to build intuition. Before diving into deep learning, understanding how classical models behave with image data can reveal a lot about the problem and the data itself.
Python Prerequisites
Let's install and import the prerequisites so they are ready to use.
# %pip install --quiet --upgrade pip
# %pip install numpy --quiet
# %pip install PyArrow --quiet
# %pip install Pandas --quiet
# %pip install scikit-learn --quiet

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import numpy as np

Classifying Handwritten Digits with a Random Forest
We'll use the MNIST dataset, which contains 70,000 28x28 grayscale images of handwritten digits (0 to 9). This is a classic benchmark in machine learning and deep learning. We'll use it to see how far a Random Forest can go on this relatively simple task.
Let's load the MNIST dataset from OpenML using the fetch_openml function from scikit-learn:
print("Loading MNIST dataset...")
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist_X, mnist_y = mnist['data'], mnist['target']
mnist_y = mnist_y.astype('int')
print(f"Dataset loaded. Shape of mnist_X: {mnist_X.shape}, Shape of mnist_y: {mnist_y.shape}")
Loading MNIST dataset...
Dataset loaded. Shape of mnist_X: (70000, 784), Shape of mnist_y: (70000,)
The image data has been flattened into a one-dimensional array of length 784 (i.e. 28 x 28). Let's take a look at some of the images and labels.
def show_mnist_data(images, labels, n=6):
plt.figure(figsize=(8, 4))
for i in range(n):
plt.subplot(1, n, i+1)
plt.imshow(images[i].reshape(28, 28), cmap='gray')
plt.title(f"Label: {labels[i]}")
plt.axis('off')
plt.tight_layout()
plt.show()
show_mnist_data(mnist_X, mnist_y)
Training The Digit Classifier
We can now split the data into train/test datasets, train a classifier on the raw pixel values, and determine the accuracy of our simple model.
mnist_X_train, mnist_X_test, mnist_y_train, mnist_y_test = (
train_test_split(mnist_X, mnist_y, test_size=0.2, random_state=42))
mnist_rf = RandomForestClassifier(n_estimators=100, random_state=42)
mnist_rf.fit(mnist_X_train, mnist_y_train)
mnist_y_pred = mnist_rf.predict(mnist_X_test)
print(f"Accuracy: {accuracy_score(mnist_y_test, mnist_y_pred):.2f}")
print("\nClassification Report:\n", classification_report(mnist_y_test, mnist_y_pred))
Accuracy: 0.97
Classification Report:
precision recall f1-score support
0 0.98 0.99 0.99 1343
1 0.98 0.98 0.98 1600
2 0.95 0.97 0.96 1380
3 0.96 0.95 0.96 1433
4 0.96 0.97 0.97 1295
5 0.97 0.96 0.97 1273
6 0.98 0.98 0.98 1396
7 0.97 0.97 0.97 1503
8 0.96 0.95 0.96 1357
9 0.96 0.95 0.95 1420
accuracy 0.97 14000
macro avg 0.97 0.97 0.97 14000
weighted avg 0.97 0.97 0.97 14000
An overall accuracy of 97% is pretty impressive for such a simple model (typically a CNN would achieve an accuracy of somewhere between 98.5% and 99.7%, depending on model complexity and training conditions).
The classification report also shows that the model performs consistently across all the labels.
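If you want to see exactly which digits get mistaken for which, a confusion matrix is a quick way to check. This is a small sketch using scikit-learn's ConfusionMatrixDisplay, reusing the test labels and predictions from above:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
cm = confusion_matrix(mnist_y_test, mnist_y_pred)
ConfusionMatrixDisplay(cm).plot(cmap='Blues')
plt.title("MNIST Confusion Matrix")
plt.show()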
Let's take a quick look at the structure by examining one of the decision trees that make up the forest:
decision_tree = mnist_rf.estimators_[0]
print(f"Number of nodes in the tree: {decision_tree.tree_.node_count}")
print(f"Depth of the tree: {decision_tree.tree_.max_depth}") # type: ignore
Number of nodes in the tree: 9377
Depth of the tree: 34
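That is just one tree. To get a feel for the whole forest, we can aggregate the same statistics across all 100 estimators (a small sketch, assuming mnist_rf from above is still in scope):
depths = [est.tree_.max_depth for est in mnist_rf.estimators_]
nodes = [est.tree_.node_count for est in mnist_rf.estimators_]
print(f"Tree depth: min {min(depths)}, max {max(depths)}, mean {np.mean(depths):.1f}")
print(f"Node count: min {min(nodes)}, max {max(nodes)}, mean {np.mean(nodes):.0f}")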
Visualizing What the Model Cares About
We can also visualize which pixels were most important to the model's decisions using feature importances.
importances = mnist_rf.feature_importances_
plt.imshow(importances.reshape(28, 28), cmap='hot')
plt.title("Pixel Importance Heatmap")
plt.colorbar()
plt.show()
This shows which areas of the image the model found most useful. It gives us a glimpse into the model’s decision process, though not as detailed or meaningful as what you would get from a CNN.
It is clear from the heatmap that the model has learned that the outside edges of the MNIST images are not useful when making predictions. Intuitively this makes sense as these areas of the images are mostly blank.
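One rough way to quantify that observation (a small sketch, reusing the fitted mnist_rf) is to count how many pixels the forest never found useful at all:
unused = np.sum(mnist_rf.feature_importances_ == 0)
print(f"{unused} of {mnist_rf.feature_importances_.size} pixels have zero importance")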
Let's take a look at what the model is predicting:
def show_mnist_predictions(images, labels, predictions, n=6):
plt.figure(figsize=(8, 4))
for i in range(n):
plt.subplot(1, n, i+1)
plt.imshow(images[i].reshape(28, 28), cmap='gray')
plt.title(f"Label: {labels[i]}\nPred: {predictions[i]}")
plt.axis('off')
plt.tight_layout()
plt.show()
show_mnist_predictions(mnist_X_test, mnist_y_test, mnist_y_pred)
And finally, let's take a look at where the model gets it wrong.
def show_mnist_errors(images, labels, predictions, n=6):
plt.figure(figsize=(8, 4))
error_count = 0
for i in range(len(labels)):
if labels[i] != predictions[i]:
plt.subplot(1, n, error_count + 1)
plt.imshow(images[i].reshape(28, 28), cmap='gray')
plt.title(f"Label: {labels[i]}\nPred: {predictions[i]}")
plt.axis('off')
error_count += 1
if error_count >= n:
break
plt.tight_layout()
plt.show()
show_mnist_errors(mnist_X_test, mnist_y_test, mnist_y_pred)
Even these incorrect predictions aren't far off. If you squint at each of the labelled images above, you can understand how the model makes these mistakes.
As we've already said, the Random Forest performed surprisingly well for a model that has no spatial awareness. But keep in mind that MNIST is clean, small, and well-structured. On real-world images, the performance of this method is likely to drop significantly.
Let's try something a little bit harder to show the limitations of this approach.
CIFAR-10: A Step Up in Complexity
To better understand the limitations of using a Random Forest for image classification, we’ll now turn to a more challenging dataset: CIFAR-10.
CIFAR-10 consists of 60,000 color images (32×32 pixels) in 10 classes, including airplanes, cars, birds, cats, and other common objects. Each image has three color channels (RGB) and contains more visual complexity and variation than the grayscale handwritten digits of MNIST.
Unlike MNIST, which is relatively simple, CIFAR-10 introduces challenges that require models to recognize textures, edges, object shapes, and spatial relationships across color channels. This is where the Random Forest approach starts to break down. Without any understanding of spatial structure, it struggles to generalize from pixel values alone.
In the next section, we’ll apply the same method we used for MNIST to CIFAR-10 and observe where and why it struggles.
First let's load the small version of CIFAR dataset from OpenML:
print("Loading CIFAR dataset...")
cifar = fetch_openml('CIFAR_10_small', version=1, as_frame=False)
cifar_X, cifar_y = cifar['data'], cifar['target']
cifar_y = cifar_y.astype('int')
print(f"Dataset loaded. Shape of X:{cifar_X.shape}, Shape of y: {cifar_y.shape}")
Loading CIFAR dataset...
Dataset loaded. Shape of X:(20000, 3072), Shape of y: (20000,)
Again the images arrive flattened into a one-dimensional array, but this time they are significantly bigger than the images from the simple MNIST dataset. Each image is of size 3072 (i.e. 32 x 32 pixels x 3 channels for red, green, and blue), approximately 4 times bigger than before.
Let's take a look at some of the images and labels we will be asking our simple classifier to work with:
CIFAR_LABEL_NAMES = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
def cifar_data_to_rgb(sample: np.ndarray) -> np.ndarray:
# CIFAR images are 32x32 with 3 color channels
red = sample[0:1024].reshape(32, 32)
green = sample[1024:2048].reshape(32, 32)
blue = sample[2048:3072].reshape(32, 32)
return np.dstack((red, green, blue))
def show_cifar_data(images, labels, n=6):
plt.figure(figsize=(8, 4))
for i in range(n):
plt.subplot(1, n, i+1)
plt.imshow(cifar_data_to_rgb(images[i]))
plt.title(f"Label: {CIFAR_LABEL_NAMES[labels[i]]}", fontsize=8)
plt.axis('off')
plt.tight_layout()
plt.show()
show_cifar_data(cifar_X, cifar_y)
As the images are quite small they appear blurry, but there is enough detail for us to make out each image.
Training the CIFAR Image Classifier
We can now follow the same approach that we took with MNIST: split the data into train/test datasets, train a classifier on the raw pixel values, and determine the accuracy of our simple model.
cifar_X_train, cifar_X_test, cifar_y_train, cifar_y_test = (
train_test_split(cifar_X, cifar_y, test_size=0.2, random_state=42))
cifar_rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)
cifar_rf.fit(cifar_X_train, cifar_y_train)
cifar_y_pred = cifar_rf.predict(cifar_X_test)
print(f"Accuracy: {accuracy_score(cifar_y_test, cifar_y_pred):.2f}")
print("\nClassification Report:\n", classification_report(cifar_y_test, cifar_y_pred, target_names=CIFAR_LABEL_NAMES))
Accuracy: 0.43
Classification Report:
precision recall f1-score support
airplane 0.49 0.55 0.52 401
automobile 0.47 0.50 0.48 410
bird 0.36 0.29 0.32 409
cat 0.25 0.21 0.23 378
deer 0.34 0.41 0.38 369
dog 0.37 0.34 0.36 374
frog 0.50 0.52 0.51 460
horse 0.49 0.45 0.47 376
ship 0.56 0.55 0.56 411
truck 0.44 0.48 0.46 412
accuracy 0.43 4000
macro avg 0.43 0.43 0.43 4000
weighted avg 0.43 0.43 0.43 4000
Our simple Random Forest classifier only achieved 43% accuracy on the CIFAR-10 dataset. While this is significantly better than random guessing (which would yield 10% for 10 classes), it is far below what modern models can achieve.
In addition, the classification report indicates that the model struggles with some classes. Precision for cats is particularly low at 25%.
For comparison, even a basic CNN trained from scratch on CIFAR-10 typically reaches 75–85% accuracy, and more advanced architectures like ResNet or DenseNet can exceed 90%. These models learn to extract and combine local spatial features, understand hierarchical patterns, and leverage the color channels more effectively.
This stark difference in performance highlights a fundamental limitation of Random Forests for image data: they treat pixels as independent features and have no built-in understanding of image structure or spatial relationships. As a result, they struggle with more complex images like those in CIFAR-10.
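One way to see this concretely (a rough sketch, reusing the test labels and raw-pixel predictions from above) is to list the class pairs the model confuses most often, which typically highlights semantically similar classes:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(cifar_y_test, cifar_y_pred)
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)                      # ignore correct predictions
top = np.argsort(off_diag, axis=None)[::-1][:5]    # five largest off-diagonal counts
for true_idx, pred_idx in zip(*np.unravel_index(top, off_diag.shape)):
    print(f"{CIFAR_LABEL_NAMES[true_idx]} predicted as {CIFAR_LABEL_NAMES[pred_idx]}: "
          f"{off_diag[true_idx, pred_idx]} times")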
Let's take a look at the predictions being made by the model:
def show_cifar_predictions(images, labels, predictions, n=6):
plt.figure(figsize=(8, 4))
for i in range(n):
plt.subplot(1, n, i+1)
plt.imshow(cifar_data_to_rgb(images[i]))
plt.title(f"Label: {CIFAR_LABEL_NAMES[labels[i]]}\nPred: {CIFAR_LABEL_NAMES[predictions[i]]}", fontsize=8)
plt.axis('off')
plt.tight_layout()
plt.show()
# Show sample predictions
show_cifar_predictions(cifar_X_test, cifar_y_test, cifar_y_pred)
Hmm. Some clear and obvious mistakes here.
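To dig into those mistakes a little more systematically, here is a minimal error viewer, mirroring the show_mnist_errors helper we used earlier:
def show_cifar_errors(images, labels, predictions, n=6):
    plt.figure(figsize=(8, 4))
    error_count = 0
    for i in range(len(labels)):
        if labels[i] != predictions[i]:
            plt.subplot(1, n, error_count + 1)
            plt.imshow(cifar_data_to_rgb(images[i]))
            plt.title(f"Label: {CIFAR_LABEL_NAMES[labels[i]]}\nPred: {CIFAR_LABEL_NAMES[predictions[i]]}", fontsize=8)
            plt.axis('off')
            error_count += 1
            if error_count >= n:
                break
    plt.tight_layout()
    plt.show()
show_cifar_errors(cifar_X_test, cifar_y_test, cifar_y_pred)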
Improving Performance with HOG Features
While using raw pixel values as input features for a Random Forest is simple, it ignores one of the most important aspects of image data: spatial structure. Each pixel is treated independently, without any context about edges, shapes, or patterns, elements that are often critical for recognizing objects in images.
To overcome this, we can try to extract handcrafted features that capture more meaningful information. One technique we can use for this is the Histogram of Oriented Gradients (HOG).
What is HOG?
HOG is a feature descriptor that summarizes the distribution of edge directions (or gradients) in small regions of an image. Instead of looking at pixel intensities directly, HOG focuses on how intensity changes, which helps it capture shapes and contours.
In practice, the image is divided into small cells, and for each cell, HOG calculates a histogram of gradient orientations. These histograms are then normalized and combined into a single feature vector that describes the image's local structure.
Why Use HOG?
- It captures edge and texture information, which is important for object recognition.
- It reduces dimensionality compared to raw pixels, while keeping relevant features.
- It makes traditional models like Random Forests more competitive, especially on more complex datasets like CIFAR-10 or real-world imagery.
In the next section, we'll apply HOG to our dataset and see whether it improves our Random Forest classifier's ability to recognize images. But first, let's take a look at what our CIFAR-10 dataset looks like when we apply the HOG algorithm from scikit-image to the images.
from skimage.feature import hog
pixels_per_cell=(8, 8)
cells_per_block=(3, 3)
def show_cifar_hog_data(images, labels, n=6):
plt.figure(figsize=(8, 4))
for i in range(n):
plt.subplot(1, n, i+1)
_, hog_image = hog(
cifar_data_to_rgb(images[i]),
pixels_per_cell=pixels_per_cell,
cells_per_block=cells_per_block,
visualize=True,
channel_axis=-1)
plt.imshow(hog_image, cmap='gray')
plt.title(f"Label: {CIFAR_LABEL_NAMES[labels[i]]}", fontsize=8)
plt.axis('off')
plt.tight_layout()
plt.show()
show_cifar_hog_data(cifar_X_train, cifar_y_train)
The resulting gradient images are not very interpretable, at least to my eye. The second image looks a bit like the shape of a cat's head, but the other images seem to have lost all their detail.
The reason for this is simple: CIFAR images are small and low-resolution, while HOG was originally designed for larger, high-contrast grayscale images (like pedestrians in surveillance footage). At just 32×32 pixels, there is limited spatial detail to extract meaningful gradient patterns, and compressing these into histograms often discards subtle but important visual information. As a result, the HOG-transformed images lack the clarity and structure we might expect, making them difficult for humans to interpret effectively.
Let's see if the model can do better. First let's transform our train and test datasets:
def cifar_hog_features(images):
return np.array([
hog(
cifar_data_to_rgb(img),
pixels_per_cell=pixels_per_cell,
cells_per_block=cells_per_block,
channel_axis=-1)
for img in images])
cifar_X_train_hog_features = cifar_hog_features(cifar_X_train)
cifar_X_test_hog_features = cifar_hog_features(cifar_X_test)
print(f"Shape of cifar_X_train_hog: {cifar_X_test_hog_features.shape}")
Shape of cifar_X_train_hog: (4000, 324)
Applying the HOG algorithm has significantly reduced the dimensionality of the data. Each sample now has 324 features, around an order of magnitude smaller than the original 3072 raw pixel values.
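As a quick sanity check on that number (assuming scikit-image's default of 9 gradient orientations): a 32x32 image with 8x8 cells gives a 4x4 grid of cells, 3x3-cell blocks fit into 2x2 positions, and each block contributes 3 x 3 x 9 histogram values.
cells = 32 // 8                     # 4 cells per side
blocks = cells - 3 + 1              # 2 block positions per side
print(blocks * blocks * 3 * 3 * 9)  # 324 features per image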
Let's train a new model using these features and compare the results:
cifar_hog_rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)
cifar_hog_rf.fit(cifar_X_train_hog_features, cifar_y_train)
cifar_y_pred = cifar_hog_rf.predict(cifar_X_test_hog_features)
print(f"Accuracy: {accuracy_score(cifar_y_test, cifar_y_pred):.2f}")
print(
"\nClassification Report:\n",
classification_report(cifar_y_test, cifar_y_pred, target_names=CIFAR_LABEL_NAMES))
Accuracy: 0.50
Classification Report:
precision recall f1-score support
airplane 0.56 0.56 0.56 401
automobile 0.56 0.64 0.60 410
bird 0.43 0.33 0.37 409
cat 0.33 0.30 0.32 378
deer 0.41 0.47 0.43 369
dog 0.41 0.42 0.42 374
frog 0.56 0.62 0.59 460
horse 0.55 0.51 0.53 376
ship 0.56 0.58 0.57 411
truck 0.62 0.57 0.59 412
accuracy 0.50 4000
macro avg 0.50 0.50 0.50 4000
weighted avg 0.50 0.50 0.50 4000
A significant improvement from 43% to 50% accuracy (a relative improvement of roughly 16%), but still nowhere near what we would expect from a CNN-based model.
We also still have problems with specific classes: precision for cats has improved, but only to 33%.
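If you want to see exactly where the HOG features helped, a quick per-class comparison (a small sketch, assuming both fitted models and the test sets from above are still in scope) makes the gains and the remaining weak spots obvious:
from sklearn.metrics import f1_score
raw_f1 = f1_score(cifar_y_test, cifar_rf.predict(cifar_X_test), average=None)
hog_f1 = f1_score(cifar_y_test, cifar_hog_rf.predict(cifar_X_test_hog_features), average=None)
for name, raw, improved in zip(CIFAR_LABEL_NAMES, raw_f1, hog_f1):
    print(f"{name:<10} raw pixels: {raw:.2f}  HOG: {improved:.2f}")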
Reflections
This series of experiments offered a practical look at how far we can push a Random Forest classifier on image data, and where it begins to fall short.
On MNIST, a dataset of simple, high-contrast handwritten digits, Random Forests performed surprisingly well. The model achieved accuracy around 97 percent using only raw pixel values. This result shows that for small, well-aligned, and relatively clean images, traditional machine learning models can still be effective without the need for complex feature extraction or deep learning.
However, when we moved to CIFAR-10, the limitations of this approach became clear. The Random Forest’s performance dropped to about 43 percent, showing how poorly it generalizes to more complex data. CIFAR-10 images include color, diverse shapes, and cluttered backgrounds. Without an understanding of spatial structure or texture, the model struggles to make sense of the data using raw pixels alone.
To improve this, we applied HOG feature extraction to the CIFAR-10 images, aiming to provide more meaningful input to the Random Forest. HOG does capture some useful edge and gradient information, but in this case, the improvement was only modest. It also made the images less interpretable. Because CIFAR images are small and visually complex, HOG features are less effective. The method was originally designed for larger grayscale images with clearer object outlines, such as those in pedestrian detection. As a result, its benefit in this setting is limited.
Conclusions
These experiments show both the potential and the limitations of using a Random Forest for image classification.
For simple tasks, Random Forests perform remarkably well. They are fast, require minimal tuning, and offer interpretable results. However, as we move toward more realistic and complex datasets, the shortcomings become obvious. The model cannot effectively capture the spatial and hierarchical patterns that are essential for accurate image understanding.
Adding handcrafted features such as HOG provides a small improvement, but it is not enough to match the performance of more advanced models. On modern image classification tasks, especially those involving color and high variability, CNNs remain the most effective approach. They are designed to handle the very challenges that traditional models cannot address.
In summary, Random Forests can be a useful learning tool and a quick baseline for simple image tasks. But when tackling real-world vision problems, deep learning methods are not just more powerful; they are necessary.
The full source code of this notebook can be accessed on GitHub.