Introduction

Have we done a budget for a case like this before? What other matters have we handled with a similar profile to this one? In which other files with characteristics like this one have we brought a summary judgment motion?

Questions like these arise regularly in a law firm, and answering them can be challenging. Searching databases by keyword or filter is inefficient at best and inadequate at worst. Hence the impetus to apply machine learning to the problem.

One approach worth exploring is the k-nearest neighbors (KNN) algorithm, one of the oldest and simplest machine learning algorithms. It is typically used for classification and regression problems, but it can also be used simply to find analogues in a dataset.

In this post, I introduce the NearestNeighbors() class from the scikit-learn library. Once fitted to a dataset, this class takes an example, x, and simply returns the element(s) in the dataset that most closely resemble it.

The K-Nearest Neighbors Algorithm

In a nutshell, KNN treats each element of a dataset as a point in a multi-dimensional space, with one dimension per feature: an element's location is given by its feature values. The algorithm then measures the distances between these points and uses the closest ones (the "nearest neighbors") to make predictions.
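To make this concrete, here is a minimal from-scratch sketch of the "nearest neighbors" step using NumPy. It assumes Euclidean distance (the straight-line distance between two points), which is a common choice but not the only one. The function name nearest_neighbors is hypothetical and not part of any library.

```python
import numpy as np


def nearest_neighbors(data, query, k=1):
    """Hypothetical helper (not a library function).

    Returns (distances, indices) of the k rows in `data` closest to `query`,
    using Euclidean distance.
    """
    data = np.asarray(data, dtype=float)
    query = np.asarray(query, dtype=float)
    # Each row is a point; each column is one dimension (one feature).
    # Euclidean distance: square root of the summed squared differences.
    distances = np.sqrt(((data - query) ** 2).sum(axis=1))
    # Indices of the k smallest distances, closest first
    order = np.argsort(distances)[:k]
    return distances[order], order


# Five elements with five features each
data = [[10] * 5, [20] * 5, [30] * 5, [40] * 5, [50] * 5]
dist, idx = nearest_neighbors(data, [9] * 5, k=1)
print(dist, idx)
```

Prediction methods such as classification or regression would then build on these neighbors, for example by taking a majority vote or an average.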

There is no shortage of excellent tutorials on the KNN algorithm. My favourites include "Develop k-Nearest Neighbors in Python from Scratch" on Machine Learning Mastery and "The k-Nearest Neighbors Algorithm in Python" on Real Python. Codecademy also covers KNN in its "Build a Machine Learning Model with Python" skills path.

All of the tutorials I have studied apply KNN to classification or regression problems. For instance:

  • What kind of flower is this example likely to be given the measurements of particular features? (Classification)
  • What is the likely market price of this property given its features, such as square footage, number of bedrooms, etc.? (Regression)

These tutorials also tend to focus either on coding the KNN algorithm from scratch in Python or on using the KNeighborsClassifier() and KNeighborsRegressor() classes from scikit-learn. All of these approaches rely on calculating "nearest neighbors", but the neighbors themselves are not the output. Rather, the algorithm returns a prediction: a category for classification (x is an example of category y) or a value for regression (x is valued at $y).
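The difference in output is easiest to see side by side. The sketch below fits scikit-learn's KNeighborsClassifier, KNeighborsRegressor and NearestNeighbors on the same tiny dataset and queries each with the same point. The data, labels and values are made up purely for illustration.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors

# Made-up toy data: one feature per row
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
labels = np.array(["small", "small", "small", "large", "large", "large"])
values = np.array([1.5, 2.5, 3.5, 10.5, 11.5, 12.5])

query = np.array([[2.2]])

clf = KNeighborsClassifier(n_neighbors=3).fit(X, labels)
reg = KNeighborsRegressor(n_neighbors=3).fit(X, values)
nn = NearestNeighbors(n_neighbors=3).fit(X)

print(clf.predict(query))    # a category label (majority vote of the 3 neighbors)
print(reg.predict(query))    # a number (mean of the 3 neighbors' values)
print(nn.kneighbors(query))  # the neighbors themselves: (distances, indices)
```

The classifier and regressor find the same three neighbors internally, but only NearestNeighbors hands them back to you.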

But what if all you want to know is which elements in the dataset are the nearest neighbors to x? You could code this from scratch. Or you could use the NearestNeighbors() class from scikit-learn!

NearestNeighbors in Action

Let's apply NearestNeighbors() to a very simple dataset.

In the snippet below, we import our dependencies and create our sample dataset: five elements (rows), each with five features (columns). To create this dataset, we first define a dictionary with five keys (A to E), each associated with a list of five values. This dictionary is then converted into a pandas DataFrame.

import pandas as pd
from sklearn.neighbors import NearestNeighbors
   
samples = {'A': [10, 20, 30, 40, 50],
           'B': [10, 20, 30, 40, 50],
           'C': [10, 20, 30, 40, 50],
           'D': [10, 20, 30, 40, 50],
           'E': [10, 20, 30, 40, 50]}
  
dataset = pd.DataFrame(samples)
print(dataset)
    A   B   C   D   E
0  10  10  10  10  10
1  20  20  20  20  20
2  30  30  30  30  30
3  40  40  40  40  40
4  50  50  50  50  50

Now suppose we have a new element, x, and we want to know which element in our dataset is most similar to it (that is, its nearest neighbor). In this example, x has a value of 9 in each column.

In the next snippet, we create a new DataFrame containing the features for x. Then we create an instance of NearestNeighbors() with n_neighbors=1, assign it to the variable neigh, and fit it to the dataset.

Finally, we run our query using the kneighbors() method to find the nearest neighbor to x.

x = {'A': [9],
     'B': [9],
     'C': [9],
     'D': [9],
     'E': [9]}

x = pd.DataFrame(x)

neigh = NearestNeighbors(n_neighbors=1)
neigh.fit(dataset)

print(neigh.kneighbors(x))
(array([[2.23606798]]), array([[0]]))

The kneighbors() method returns a tuple of two arrays: the distances to the nearest neighbors and their row indices. Here, the nearest neighbor to x is the element at row 0 of our dataset, and the distance between the two is 2.236.
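To check where 2.236 comes from, you can compute the distances by hand. By default, scikit-learn's NearestNeighbors uses the Minkowski metric with p=2, which is ordinary Euclidean distance. The sketch below rebuilds the same toy data and uses NumPy to measure the distance from x to every row.

```python
import numpy as np
import pandas as pd

# Rebuild the toy data: row i holds 10 * (i + 1) in every column A to E
dataset = pd.DataFrame({col: [10, 20, 30, 40, 50] for col in "ABCDE"})
x = pd.DataFrame({col: [9] for col in "ABCDE"})

# Euclidean distance from x to every row: square root of the summed squared differences
distances = np.linalg.norm(dataset.to_numpy() - x.to_numpy(), axis=1)
print(distances)

# Row 0 differs by 1 in each of the 5 columns, so its distance is sqrt(5)
print(distances.argmin(), np.sqrt(5))
```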

Let's try another example, y. This example has a value of 60 in each column.

y = {'A': [60],
     'B': [60],
     'C': [60],
     'D': [60],
     'E': [60]}

y = pd.DataFrame(y)

print(neigh.kneighbors(y))
(array([[22.36067977]]), array([[4]]))

As expected, the nearest neighbor to y is the element at row 4 of our dataset.
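kneighbors() is not limited to one query or one neighbor. As a sketch, the snippet below rebuilds the toy data, stacks x and y into a single query DataFrame, and overrides n_neighbors at query time to retrieve the three closest rows for each, ordered from nearest to farthest.

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors

dataset = pd.DataFrame({col: [10, 20, 30, 40, 50] for col in "ABCDE"})
# x (all 9s) and y (all 60s) as two rows of one query DataFrame
queries = pd.DataFrame({col: [9, 60] for col in "ABCDE"})

neigh = NearestNeighbors(n_neighbors=1).fit(dataset)

# n_neighbors passed here overrides the value set at construction
distances, indices = neigh.kneighbors(queries, n_neighbors=3)
print(indices)
```

Each row of indices corresponds to one query, so this is a convenient way to pull a short list of analogues rather than a single match.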

You can find more information about NearestNeighbors() in the scikit-learn documentation.

Applying KNN to Legal Data

Admittedly, the preceding example is very abstract compared with how KNN might be applied to legal data. Imagine that each row of the dataset is one case, and each column is one feature of that case (e.g., jurisdiction or case type).
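One wrinkle: KNN measures distances, so text features such as jurisdiction or case type have to be converted to numbers first. One common option is one-hot encoding, sketched below with pandas get_dummies(). The matters, column names and values are entirely hypothetical. The new matter must be encoded with exactly the same columns as the dataset, which is what reindex() takes care of.

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors

# Hypothetical matters: every value here is invented for illustration only
cases = pd.DataFrame({
    "jurisdiction": ["Ontario", "Quebec", "Ontario", "Alberta"],
    "case_type": ["contract", "employment", "employment", "contract"],
    "num_parties": [2, 3, 2, 4],
})
categorical = ["jurisdiction", "case_type"]

# One-hot encode the categorical columns so every feature is numeric
features = pd.get_dummies(cases, columns=categorical, dtype=int)
neigh = NearestNeighbors(n_neighbors=1).fit(features)

# A new, equally hypothetical matter, encoded with the dataset's columns
new_case = pd.DataFrame({"jurisdiction": ["Ontario"],
                         "case_type": ["employment"],
                         "num_parties": [3]})
new_features = pd.get_dummies(new_case, columns=categorical, dtype=int)
new_features = new_features.reindex(columns=features.columns, fill_value=0)

distances, indices = neigh.kneighbors(new_features)
print(cases.iloc[indices[0]])  # the most similar existing matter
```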

No sugarcoating: one of the trickiest parts of using KNN to find analogues in legal data is putting the dataset together in the first place. This will usually involve a lot of wrangling. But, hey, it's good practice!
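One piece of wrangling worth flagging: because KNN relies on raw distances, a feature measured in large units (say, dollars claimed) can swamp one measured in small units (say, number of parties). A common remedy is to standardize each feature before fitting. The sketch below uses scikit-learn's StandardScaler on hypothetical data (the column names and figures are invented) to show how scaling can change which matter comes back as the nearest neighbor.

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

# Hypothetical numeric features; values invented for illustration only
cases = pd.DataFrame({
    "amount_claimed": [50_000, 53_000, 400_000],
    "num_parties": [9, 2, 2],
})
query = pd.DataFrame({"amount_claimed": [51_000], "num_parties": [2]})

# Unscaled: the dollar column dominates the distance
raw = NearestNeighbors(n_neighbors=1).fit(cases)
print(raw.kneighbors(query, return_distance=False))

# Scaled: each column standardized to mean 0 and unit variance first
scaler = StandardScaler().fit(cases)
scaled = NearestNeighbors(n_neighbors=1).fit(scaler.transform(cases))
print(scaled.kneighbors(scaler.transform(query), return_distance=False))
```

Without scaling, a $1,000 gap in the amount claimed counts for far more than a difference of seven parties, so row 0 wins; after scaling, row 1 (same number of parties, similar amount) comes back instead. Whether to scale, and which scaler to use, depends on the data.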

Final Thoughts

Have you experimented with using KNN on legal data? Hit me up! Let's compare notes.