Categories: Machine Learning

Keras VGG16 flat-features extractor

Using a pre-trained deep learning model as a feature extractor is a proven way to improve classification accuracy. One of the most famous models is Oxford's VGG16, which was trained on over a million images (ImageNet) to recognize 1,000 classes, ranging from animals to vehicles and everyday objects.

Using VGG16 as part of another neural network is relatively easy, especially if you are using Keras. You can simply remove the top layers (the fully-connected layers that act as the classifier), take the output (a tensor, i.e. a multidimensional array), and pass it as the input to your own model, as shown by some nice examples here.
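
A minimal sketch of that approach, assuming a recent Keras (2.x with `keras.Input`/`keras.Model`, or 3.x). The classification head (one 256-unit `Dense` layer), `NUM_CLASSES = 5` and the helper name `build_transfer_model` are hypothetical placeholders for your own task, and `weights=None` is only used so the snippet runs without downloading the ImageNet weights.

```python
import keras
from keras import layers
from keras.applications.vgg16 import VGG16

NUM_CLASSES = 5  # hypothetical: number of classes in your own task
INPUT_SHAPE = (224, 224, 3)


def build_transfer_model(num_classes=NUM_CLASSES, weights="imagenet"):
    # Convolutional base only: the fully-connected classifier on top is dropped
    base = VGG16(weights=weights, include_top=False, input_shape=INPUT_SHAPE)
    base.trainable = False  # keep the pre-trained convolution weights frozen

    inputs = keras.Input(shape=INPUT_SHAPE)
    x = base(inputs, training=False)  # tensor of shape (None, 7, 7, 512)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)  # hypothetical head size
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


# weights=None skips the download; use "imagenet" for real transfer learning
model = build_transfer_model(weights=None)
```

Only the two new `Dense` layers are trainable here, so `model.fit(images, labels)` trains the new classifier on top of the frozen VGG16 features.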

But what if we want to use a classifier that is not a neural network, and therefore is not provided by Keras?

One of the easiest solutions I have found is to append a Flatten layer on top of the stripped VGG16 (no top), then call predict on the dataset to get a numpy array of flat (1D) features:

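from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Model
from keras.layers import Input, Flatten

# VGG16 standard input shape
EXPECTED_DIM = (224, 224, 3)

# convolutional base only: the fully-connected classifier layers are dropped
vgg16 = VGG16(weights='imagenet', include_top=False)
inputs = Input(shape=EXPECTED_DIM, name='input')  # avoids shadowing built-in input()
output = vgg16(inputs)
x = Flatten(name='flatten')(output)
extractor = Model(inputs=inputs, outputs=x)

# dataset is a numpy array of RGB images shaped (dataset_rows, 224, 224, 3)
# preprocess_input applies the same ImageNet preprocessing VGG16 was trained with
# (astype makes a float copy, so the original dataset is left untouched)
# features will be a numpy array of shape (dataset_rows, 25088)
features = extractor.predict(preprocess_input(dataset.astype('float32')))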
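
Where does 25088 come from? A quick check, written as a small hypothetical helper (`vgg16_flat_feature_size` is not a Keras function): the 3x3 convolutions in VGG16 use "same" padding, so only the five 2x2 max-pooling layers (stride 2) shrink the image, and the last block has 512 filters.

```python
def vgg16_flat_feature_size(height=224, width=224, num_pools=5, last_filters=512):
    """Length of the flattened VGG16 (no top) output for a given input size."""
    # Each of the five max-pooling layers halves the spatial size (rounding down);
    # the convolutions keep it unchanged.
    for _ in range(num_pools):
        height //= 2
        width //= 2
    # Flatten turns (height, width, last_filters) into a single vector
    return height * width * last_filters


print(vgg16_flat_feature_size())          # 25088 = 7 * 7 * 512
print(vgg16_flat_feature_size(150, 150))  # 8192 = 4 * 4 * 512
```

Because VGG16 without its top accepts other image sizes, the same arithmetic tells you how long the feature vectors become if you change EXPECTED_DIM.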

Then you can easily pass the features (and the labels) to a classifier from another library, such as the famous scikit-learn. You can also apply the same method to the other pre-trained models generously provided by Keras.
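
A minimal sketch of that hand-off, assuming scikit-learn is installed. The random arrays are hypothetical stand-ins for the `features` returned by `extractor.predict` and for your labels, `train_flat_feature_classifier` is an illustrative helper name, and logistic regression is just one example choice (an SVM or a random forest would slot in the same way).

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def train_flat_feature_classifier(features, labels, seed=0):
    """Fit a scikit-learn classifier on flat CNN features.

    Returns the fitted pipeline and its accuracy on a held-out 25% split.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.25, random_state=seed, stratify=labels)
    # Standardize each feature, then fit a linear classifier
    clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    clf.fit(X_train, y_train)
    return clf, clf.score(X_test, y_test)


# Hypothetical stand-in for extractor.predict(...): 80 rows of 25088 features
rng = np.random.default_rng(0)
features = rng.normal(size=(80, 25088))
labels = np.repeat([0, 1], 40)
features[labels == 1] += 0.5  # shift one class so the toy data is learnable

clf, accuracy = train_flat_feature_classifier(features, labels)
print(f"held-out accuracy: {accuracy:.2f}")
```

With real VGG16 features you would replace the stand-in arrays with `features` from the extractor and your own label array.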

Cheers!

 

Published by yohanes.gultom@gmail.com
