Implementing Weighted KNN (WDKNN) in MATLAB

Raghda (Merry) Al taei
5 min read · Oct 13, 2024

The K-nearest neighbors (KNN) algorithm is widely used for classification tasks but struggles with imbalanced datasets, where majority-class neighbors tend to dominate the vote. This article introduces the Weighted KNN (WDKNN) algorithm, which enhances KNN by addressing the challenges posed by imbalanced data.

What is Weighted KNN (WDKNN)?

Weighted KNN (WDKNN) improves upon traditional KNN by assigning weights to neighbors based on their distances. This approach emphasizes minority class instances, helping to achieve better classification accuracy for underrepresented data points.

Method Overview

The method proposed in the original WDKNN paper includes the following steps (a minimal sketch of the voting rule follows the list):

  1. Weight Calculation: Assign weights to the K nearest neighbors based on their distance, often using the inverse of the distance.
  2. Weighted Voting: Compute weighted votes for each class based on the weights assigned to each neighbor.
  3. Normalization: Normalize the weights to ensure fair contribution from all neighbors.
  4. Classification Decision: The class with the highest weighted vote is chosen as the predicted class.
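
To make the four steps concrete, here is a minimal sketch of the voting rule for a single query point, using toy data and the same conventions as the full implementation below (binary 0/1 labels, Euclidean distance, inverse-distance weights):

% Minimal WDKNN voting sketch for one query point (toy data, binary labels)
X_train = [0.1 0.2; 0.4 0.4; 0.9 0.8; 0.8 0.9]; % training features
y_train = [0; 0; 1; 1];                         % training labels
x_query = [0.7 0.7];                            % point to classify
k = 3;

% Step 1: distances to all training points, inverse-distance weights
d = sqrt(sum((X_train - x_query).^2, 2));
[d_sorted, order] = sort(d, 'ascend');
w = 1 ./ (d_sorted(1:k) + 1e-10); % small epsilon avoids division by zero
labels_k = y_train(order(1:k));

% Steps 2-3: weighted vote for the positive class, normalized by total weight
vote_pos = sum(w(labels_k == 1)) / sum(w);

% Step 4: the class holding the majority of the weight wins
predicted = double(vote_pos >= 0.5);
disp(predicted) % 1: the two nearby positive neighbors outweigh the negative one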

MATLAB Implementation

Here’s a MATLAB implementation of the WDKNN algorithm. The script expects 7 datasets saved as .mat files (created by the preprocessing script in Step 2 below), each holding a numeric matrix whose columns are the features and whose final column contains the 0/1 class labels.

Step 1: Create the WDKNN Function

Create a new file named WDKNN.m and implement the following code:

clc;
clear;
close all;

% Number of datasets and initialize accuracy array
num_datasets = 7;
accuracies = zeros(num_datasets, 1);

for i = 1:num_datasets
    disp(['Processing Dataset ', num2str(i)]);

    % Load dataset (saved in .mat format by the preprocessing script below)
    load(['dataArray_', num2str(i), '.mat']); % Loads the matrix 'dataArray'
    data = dataArray;

    % Separate features and labels
    X = data(:, 1:end-1); % Features (all columns except last)
    y = data(:, end);     % Labels (last column)

    % Split the dataset into training and testing sets (80-20 split)
    cv = cvpartition(size(X, 1), 'HoldOut', 0.2);
    idx = cv.test;
    X_train = X(~idx, :);
    y_train = y(~idx);
    X_test = X(idx, :);
    y_test = y(idx);

    % Adjust class weights for balancing (inverse class frequencies)
    class_weights = [1 / sum(y_train == 0), 1 / sum(y_train == 1)];

    % Set the number of neighbors
    k = 5; % Adjust this as needed

    % Initialize predictions
    predictions = zeros(size(X_test, 1), 1);

    % Perform predictions using WDKNN
    for j = 1:size(X_test, 1)
        % Euclidean distances between the test point and all training samples
        distances = sqrt(sum((X_train - X_test(j, :)).^2, 2));

        % Sort distances and get indices of the k nearest neighbors
        [~, sorted_indices] = sort(distances, 'ascend');
        k_indices = sorted_indices(1:k);

        % Extract labels of the k nearest neighbors
        k_nearest_labels = y_train(k_indices);

        % Calculate weights (inverse distance)
        weights = 1 ./ (distances(k_indices) + 1e-10); % Avoid division by zero

        % Adjust weights by class frequency; the transpose is needed because
        % indexing the row vector 'class_weights' with a column vector
        % returns a row vector, which would otherwise expand into a matrix
        adjusted_weights = weights .* class_weights(k_nearest_labels + 1).'; % +1 for class indexing

        % Weighted voting to determine the class of the test instance
        weighted_sum = sum(adjusted_weights .* k_nearest_labels);
        total_weight = sum(adjusted_weights);

        % Predict the class holding the majority of the total weight
        if weighted_sum >= total_weight / 2
            predictions(j) = 1; % Predict as 'positive' (1)
        else
            predictions(j) = 0; % Predict as 'negative' (0)
        end
    end

    % Evaluate accuracy
    accuracies(i) = sum(predictions == y_test) / length(y_test);
    fprintf('Dataset %d Accuracy: %.2f%%\n', i, accuracies(i) * 100);
end

% Display overall results
disp('Accuracies for all datasets:');
disp(accuracies);

% Plot the accuracies
figure;
bar(accuracies * 100); % Convert to percentage for plotting
title('WDKNN Classification Accuracies for Datasets');
xlabel('Datasets');
ylabel('Accuracy (%)');
xticks(1:num_datasets);
xticklabels({'Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', ...
    'Dataset 5', 'Dataset 6', 'Dataset 7'});
ylim([0 110]); % Leave headroom above 100% for the text labels
grid on;

% Annotate bars with accuracy values (use a fresh loop variable so the
% neighbor count 'k' is not overwritten)
for b = 1:num_datasets
    text(b, accuracies(b) * 100 + 2, sprintf('%.2f%%', accuracies(b) * 100), ...
        'HorizontalAlignment', 'center', 'FontSize', 10);
end

Step 2: Load and Process Datasets

Create another MATLAB script, say run_wdknn.m, that reads the raw KEEL-style .dat files, converts the 'positive'/'negative' labels to 1/0, and saves each dataset as the dataArray_*.mat file that WDKNN.m expects:

clc;
clear;

% List of .dat files to read
fileNames = {'ecoli1.dat', 'glass0.dat', 'glass1.dat', ...
             'iris0.dat', 'pima.dat', ...
             'wisconsin.dat', 'yeast1.dat'};

% Number of datasets
num_datasets = length(fileNames);

% Initialize cell array to store the converted data arrays
dataArrays = cell(num_datasets, 1);

for f = 1:num_datasets
    % Open the .dat file
    fileID = fopen(fileNames{f}, 'r');

    % Check if the file opened successfully
    if fileID == -1
        error('Could not open file: %s', fileNames{f});
    end

    % Read the entire file as a cell array of strings
    data = textscan(fileID, '%s', 'Delimiter', '\n');

    % Close the file
    fclose(fileID);

    % Drop KEEL-style header lines (those beginning with '@'), if present
    lines = data{1};
    lines = lines(~startsWith(strtrim(lines), '@'));

    % Preallocate arrays
    numRows = length(lines);            % Total number of data rows
    categoricalData = cell(numRows, 1); % Preallocate for categorical labels
    tempNumericMatrix = [];             % Matrix of numeric feature values

    % Loop through each entry and parse the data
    for i = 1:numRows
        % Split the string by commas
        splitData = strsplit(lines{i}, ',');

        % Convert all but the last element to numbers
        numericParts = str2double(splitData(1:end-1));

        % Append as a new row of the numeric matrix
        tempNumericMatrix = [tempNumericMatrix; numericParts]; %#ok<AGROW>

        % Store the last element (the class label), trimming stray spaces
        categoricalData{i} = strtrim(splitData{end});
    end

    % Determine the number of numeric columns
    numNumericCols = size(tempNumericMatrix, 2);

    % Create a table from the numeric data
    numericTable = array2table(tempNumericMatrix, 'VariableNames', ...
        strcat('Feature', string(1:numNumericCols))); % Name numeric features

    % Add the categorical column as a separate variable
    categoricalTable = table(categorical(categoricalData), 'VariableNames', {'Label'});

    % Combine numeric and categorical tables
    combinedData = [numericTable, categoricalTable];

    % Convert the categorical labels to numeric (1 = positive, 0 = negative)
    if iscategorical(combinedData{:, end})
        numericLabels = zeros(height(combinedData), 1);
        numericLabels(combinedData{:, end} == 'positive') = 1;
        numericLabels(combinedData{:, end} == 'negative') = 0;

        % Replace the last column with the numeric labels
        combinedData = [combinedData(:, 1:end-1), ...
            array2table(numericLabels, 'VariableNames', {'Labels'})];
    end

    % Convert the modified table to a numeric array and keep a copy
    dataArray = table2array(combinedData);
    dataArrays{f} = dataArray;

    % Save the combined data to a .mat file with a unique name
    save(['dataArray_', num2str(f), '.mat'], 'dataArray');

    % Optionally, display the combined data
    disp(['Data for ', fileNames{f}, ':']);
    disp(combinedData);
end
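
With both scripts saved under the names used here, the whole pipeline runs in two lines from the MATLAB command window. The data-preparation script must run first, since WDKNN.m loads the dataArray_*.mat files it produces:

run('run_wdknn.m'); % convert the .dat files to dataArray_*.mat
run('WDKNN.m');     % train, evaluate, and plot the results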

Explanation of the Code

  1. Script Structure: The logic lives in two scripts: run_wdknn.m prepares the data, and WDKNN.m implements the Weighted KNN algorithm, fits it per dataset, and makes predictions.
  2. Weight Calculation: The weights are computed as the inverse of the distances to emphasize closer neighbors, then scaled by inverse class frequency to boost the minority class.
  3. Prediction Method: For each instance in the test set, the algorithm finds the K nearest neighbors, calculates the weights, and performs weighted voting to decide the predicted class.
  4. Data Loading and Processing: run_wdknn.m parses each .dat file, maps the 'positive'/'negative' labels to 1/0, and saves the result as a .mat file that WDKNN.m then loads.
  5. Evaluation: The accuracy of the predictions is computed and printed for each dataset (a per-class breakdown is sketched below).
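
Because these benchmark datasets are imbalanced, overall accuracy can look high even when the minority ('positive') class is mostly misclassified. A short addition you can append inside the evaluation loop of WDKNN.m reports recall per class; it assumes the predictions and y_test variables defined above and confusionmat from the Statistics and Machine Learning Toolbox:

% Per-class breakdown: rows are true classes [0; 1], columns are predictions
C = confusionmat(y_test, predictions);
recall_neg = C(1, 1) / sum(C(1, :)); % specificity (majority class)
recall_pos = C(2, 2) / sum(C(2, :)); % sensitivity (minority class)
fprintf('Recall (negative): %.2f%% | Recall (positive): %.2f%%\n', ...
    100 * recall_neg, 100 * recall_pos);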

Conclusion

The WDKNN algorithm enhances the KNN approach by emphasizing the importance of minority classes in imbalanced datasets. Implemented in MATLAB as above, it can improve classification of the underrepresented class compared with plain KNN. Feel free to adjust the number of neighbors k, the train/test split, and the class weights to see how WDKNN performs in your specific applications.

Written by Raghda (Merry) Al taei

I am a Data Scientist/Analyst with a master's degree in computer engineering (AI) from AmirKabir University.
