
Linear Regression and cross validation in Java using Weka

I stumbled upon a question on the internet about how to predict prices from price history on Android. Assuming the history is quite small (a few hundred records) and there are not many attributes (fewer than 20), I quickly thought that the Weka Java API would be one of the easiest ways to achieve this.

Unfortunately, I couldn’t find a straightforward tutorial or example, since most of them cover the GUI version of Weka. So I decided to whip up an example (using the bleeding-edge weka-dev 3.9.2) and post a brief explanation here.

For this example, I use the Daily Demand Forecasting Orders dataset, a regression dataset from the UCI Machine Learning Repository. I chose it because its characteristics are quite similar to those of price prediction.

To convert the CSV dataset into Weka's Instances format, I use the methods below:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

...

public static final String DATASET_FILE = "Daily_Demand_Forecasting_Orders.csv";
public static final int DATASET_SIZE = 60;
public static final int DATASET_ATTRIBUTES_NUM = 13;

...

// Reads the semicolon-separated CSV from the classpath and turns each data row into a Weka instance
private Instances loadDataset() {
    Instances dataset = this.createEmptyDataset();
    // getResourceAsStream() also works for paths containing spaces and for files packaged inside a JAR
    InputStream in = getClass().getClassLoader().getResourceAsStream(DATASET_FILE);
    if (in == null) {
        throw new RuntimeException("Dataset not found on the classpath: " + DATASET_FILE);
    }
    // try-with-resources closes the reader (and the underlying stream) exactly once
    try (BufferedReader br = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
        String sCurrentLine;
        int line = 1;
        while ((sCurrentLine = br.readLine()) != null) {
            // Line 1 is the CSV header, so skip it
            if (line > 1) {
                try {
                    String[] fields = sCurrentLine.split(";");
                    double[] values = new double[DATASET_ATTRIBUTES_NUM];
                    for (int i = 0; i < DATASET_ATTRIBUTES_NUM; i++) {
                        values[i] = Double.parseDouble(fields[i].trim());
                    }
                    dataset.add(new DenseInstance(1.0, values));
                } catch (NumberFormatException | ArrayIndexOutOfBoundsException ex) {
                    // Skip malformed rows instead of aborting the whole load
                    System.err.println("Skipping line " + line + ": " + ex.getMessage());
                }
            }
            line++;
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return dataset;
}

// Creates an empty dataset from the attribute definitions; the last attribute is the target (class)
private Instances createEmptyDataset() {
    ArrayList<Attribute> header = this.createHeader();
    Instances instances = new Instances(DATASET_FILE, header, DATASET_SIZE);
    instances.setClassIndex(DATASET_ATTRIBUTES_NUM - 1);
    return instances;
}

where createHeader() defines the dataset attributes. Each new Attribute(name) is a numeric attribute, listed in the same order as the CSV columns:

private ArrayList<Attribute> createHeader() {
    ArrayList<Attribute> header = new ArrayList<>();
    header.add(new Attribute("Week_of_the_month"));
    header.add(new Attribute("Day_of_the_week_"));
    header.add(new Attribute("Non_urgent_order"));
    header.add(new Attribute("Urgent_order"));
    header.add(new Attribute("Order_type_A"));
    header.add(new Attribute("Order_type_B"));
    header.add(new Attribute("Order_type_C"));
    header.add(new Attribute("Fiscal_sector_orders"));
    header.add(new Attribute("Orders_from_the_traffic_controller_sector"));
    header.add(new Attribute("Banking_orders_(1)"));
    header.add(new Attribute("Banking_orders_(2)"));
    header.add(new Attribute("Banking_orders_(3)"));
    header.add(new Attribute("Target_(Total_orders)"));
    return header;
}

Then we can use the loadDataset() method to build a LinearRegression model:

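import weka.classifiers.functions.LinearRegression;
import weka.core.Instances;

...

Instances dataset = loadDataset();
LinearRegression lr = new LinearRegression();
// Ridge (L2) regularization parameter; 1.0E-8 is also Weka's default value
lr.setRidge(1.0E-8);
// buildClassifier() declares "throws Exception", so call it from a method that handles or declares it
lr.buildClassifier(dataset);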
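To make setRidge() less of a black box, here is a from-scratch sketch of ridge-regularized least squares in plain Java. It solves (XᵀX + λD)θ = Xᵀy, where X has a leading column of ones and D is the identity with the intercept entry zeroed, so the intercept is not penalized. The class RidgeSketch and its fit() method are hypothetical illustrations, not Weka's implementation; Weka's LinearRegression also removes colinear attributes and applies M5-based attribute selection by default, so its coefficients can differ.

```java
import java.util.Arrays;

// Hypothetical illustration of ridge-regularized least squares; not Weka's code.
class RidgeSketch {

    /**
     * Fits y ~ theta[0] + theta[1]*x1 + ... + theta[p]*xp by solving
     * (X^T X + lambda*D) theta = X^T y, where X has a leading column of ones and
     * D is the identity with 0 in the intercept position (intercept not penalized).
     */
    static double[] fit(double[][] x, double[] y, double lambda) {
        int n = x.length, p = x[0].length, d = p + 1;
        double[][] a = new double[d][d + 1]; // augmented matrix [X^T X + lambda*D | X^T y]
        for (int r = 0; r < n; r++) {
            double[] row = new double[d];
            row[0] = 1.0; // intercept column
            System.arraycopy(x[r], 0, row, 1, p);
            for (int i = 0; i < d; i++) {
                for (int j = 0; j < d; j++) {
                    a[i][j] += row[i] * row[j];
                }
                a[i][d] += row[i] * y[r];
            }
        }
        for (int i = 1; i < d; i++) {
            a[i][i] += lambda; // L2 penalty on the attribute weights only
        }
        // Gauss-Jordan elimination with partial pivoting
        for (int c = 0; c < d; c++) {
            int pivot = c;
            for (int r = c + 1; r < d; r++) {
                if (Math.abs(a[r][c]) > Math.abs(a[pivot][c])) pivot = r;
            }
            double[] tmp = a[c];
            a[c] = a[pivot];
            a[pivot] = tmp;
            for (int r = 0; r < d; r++) {
                if (r == c) continue;
                double factor = a[r][c] / a[c][c];
                for (int k = c; k <= d; k++) {
                    a[r][k] -= factor * a[c][k];
                }
            }
        }
        double[] theta = new double[d];
        for (int i = 0; i < d; i++) {
            theta[i] = a[i][d] / a[i][i];
        }
        return theta;
    }

    public static void main(String[] args) {
        // Toy rows generated from y = 1 + 2*x1 - 3*x2 without noise
        double[][] x = {{0, 0}, {1, 0}, {0, 1}, {1, 1}, {2, 1}};
        double[] y = {1, 3, -2, 0, 2};
        System.out.println(Arrays.toString(fit(x, y, 1.0E-8)));
    }
}
```

With a tiny ridge such as 1.0E-8 the coefficients come out very close to 1, 2 and -3, i.e. practically ordinary least squares; any λ > 0 keeps the system solvable even when attributes are perfectly colinear, and larger values shrink the weights toward zero.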

Finally, we can use the model, lr, to predict the target value of a new data point:

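import weka.core.DenseInstance;

...

// Values in header order; the last one is the known target (Total orders)
double[] data = new double[]{1.0, 4.0, 316.307, 223.270, 61.543, 175.586, 302.448, 0.0, 65556.0, 44914.0, 188411.0, 14793.0, 539.577};
double expectation = data[data.length - 1];
DenseInstance instance = new DenseInstance(1.0, data);
// Link the instance to the dataset header so Weka knows the attributes and which one is the class
instance.setDataset(dataset);
double prediction = lr.classifyInstance(instance);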
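Conceptually, a fitted linear regression model predicts by adding the intercept to the weighted sum of the input attributes; the target slot of the row (the last value, like 539.577 above) is not an input. A minimal plain-Java sketch with a hypothetical LinearPredictSketch class and made-up coefficients, so the numbers have nothing to do with the model trained on the UCI data:

```java
// Hypothetical helper showing how a fitted linear model turns one row into a prediction.
class LinearPredictSketch {

    /**
     * coef = {intercept, w1, ..., wp}; row = {x1, ..., xp, target}.
     * The target in the last slot is ignored, so a known or unknown value gives the same result.
     */
    static double predict(double[] coef, double[] row) {
        int p = coef.length - 1;
        if (row.length != p + 1) {
            throw new IllegalArgumentException("expected " + (p + 1) + " values, got " + row.length);
        }
        double sum = coef[0];
        for (int i = 0; i < p; i++) {
            sum += coef[i + 1] * row[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] coef = {1.0, 2.0, -3.0}; // made-up coefficients
        double[] row = {2.0, 1.0, 2.5};   // last value is the known target
        double prediction = predict(coef, row);
        double expectation = row[row.length - 1];
        System.out.println("prediction=" + prediction
                + ", absolute error=" + Math.abs(expectation - prediction));
    }
}
```

Comparing prediction with expectation this way is the per-instance version of the error that cross-validation aggregates over the whole dataset.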

Or, if you want to evaluate the model's performance, you can run k-fold cross-validation and check the error. In this example, I run 10-fold cross-validation and measure the root mean squared error (RMSE):

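import java.util.Random;
import weka.classifiers.Evaluation;

...

Evaluation evaluation = new Evaluation(dataset);
// 10 folds with a fixed seed, so the fold assignment is reproducible;
// crossValidateModel trains copies of lr, so the model built above is left untouched
evaluation.crossValidateModel(lr, dataset, 10, new Random(1));
double rmse = evaluation.rootMeanSquaredError();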
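The idea behind crossValidateModel can be sketched in plain Java: shuffle the rows with a seeded Random, split them into k folds, train on k - 1 folds, predict the held-out fold, and pool all squared errors into RMSE = sqrt((1/n) · Σ (yᵢ - ŷᵢ)²). This is a hypothetical illustration (CrossValidationSketch) that uses a one-feature least-squares line as a stand-in for Weka's LinearRegression, and its fold assignment will not match Weka's.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of k-fold cross-validation scored with RMSE.
// A one-feature least-squares line stands in for Weka's LinearRegression.
class CrossValidationSketch {

    /** Fits y = a + b*x using only the rows listed in idx; returns {a, b}. */
    static double[] fitLine(double[] x, double[] y, List<Integer> idx) {
        double mx = 0, my = 0;
        for (int i : idx) {
            mx += x[i];
            my += y[i];
        }
        mx /= idx.size();
        my /= idx.size();
        double sxy = 0, sxx = 0;
        for (int i : idx) {
            sxy += (x[i] - mx) * (y[i] - my);
            sxx += (x[i] - mx) * (x[i] - mx);
        }
        double b = (sxx == 0) ? 0 : sxy / sxx; // flat line if all x values are equal
        return new double[]{my - b * mx, b};
    }

    /** Shuffles rows with a seeded Random, splits them into k folds and pools held-out errors into one RMSE. */
    static double crossValidatedRmse(double[] x, double[] y, int k, long seed) {
        if (k < 2 || k > x.length) {
            throw new IllegalArgumentException("k must be between 2 and the number of rows");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < x.length; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));

        double sumSquaredError = 0;
        for (int fold = 0; fold < k; fold++) {
            List<Integer> train = new ArrayList<>();
            List<Integer> test = new ArrayList<>();
            for (int pos = 0; pos < order.size(); pos++) {
                if (pos % k == fold) {
                    test.add(order.get(pos));
                } else {
                    train.add(order.get(pos));
                }
            }
            double[] line = fitLine(x, y, train); // the model never sees its own test fold
            for (int i : test) {
                double error = y[i] - (line[0] + line[1] * x[i]);
                sumSquaredError += error * error;
            }
        }
        // Each row is held out exactly once, so divide by the total number of rows
        return Math.sqrt(sumSquaredError / x.length);
    }

    public static void main(String[] args) {
        double[] x = new double[20];
        double[] y = new double[20];
        for (int i = 0; i < x.length; i++) {
            x[i] = i;
            y[i] = i * i; // a curve, so a straight line leaves a non-zero error
        }
        System.out.println("10-fold RMSE: " + crossValidatedRmse(x, y, 10, 1L));
    }
}
```

Because every row is predicted by a model that was not trained on it, the pooled RMSE estimates the error on unseen data, which is why it is usually higher than the error measured on the training set itself.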

In the full example, I also apply normalization so that all values are on the same scale. In many cases, this can also improve the model’s performance.
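Assuming min-max scaling, which is the default behavior of Weka's Normalize filter (weka.filters.unsupervised.attribute.Normalize rescales numeric attributes, apart from the class attribute if one is set, to [0, 1]), the transformation can be sketched in plain Java as below. MinMaxSketch is a hypothetical class, not the code from the full example.

```java
import java.util.Arrays;

// Hypothetical min-max scaler: learns per-column ranges from training rows and maps them to [0, 1].
class MinMaxSketch {

    private final double[] min;
    private final double[] max;

    MinMaxSketch(double[][] train) {
        int p = train[0].length;
        min = new double[p];
        max = new double[p];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : train) {
            for (int j = 0; j < p; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    /** Scales one row with the stored ranges; a constant column is mapped to 0 in this sketch. */
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            double range = max[j] - min[j];
            out[j] = (range == 0) ? 0.0 : (row[j] - min[j]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = {{0, 10}, {5, 20}, {10, 30}};
        MinMaxSketch scaler = new MinMaxSketch(train);
        // Unseen values outside the training range land outside [0, 1]
        System.out.println(Arrays.toString(scaler.transform(new double[]{15, 40})));
    }
}
```

The example prints [1.5, 1.5]. The important design choice is learning the ranges from the training data once and reusing them for every instance later passed to classifyInstance(); otherwise the model would receive inputs on a different scale from the one it was trained on.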

Enjoy! ☕

Published by
yohanes.gultom@gmail.com
Tags: java, weka
