How to build a simple multi-label text classifier

On kevincoder.co.za, I write about my journey as a developer working across Django, Go, and everything in between, from large-scale systems to small, useful tools. Programming has been my passion for over 15 years now. I love learning new skills and am thrilled I can share these with you! A big part of my career was built on knowledge shared by others, mainly through open-source projects, forums, and communities like Stack Overflow. This blog is my way of contributing back and sharing what I’ve learned along the way.

While the world goes crazy over ChatGPT and all the cool AI models launching every few weeks, it can be tempting to get caught up in the hype cycle and never actually build anything tangible.

In this article, I'll show you how to build a simple but powerful multi-label text classifier. This type of classifier may not be the most accurate (or the most intelligent, for that matter), but it doesn't need a GPU and it's fast.

Depending on your dataset and needs, this approach can still be useful, especially for use cases such as binary classification (think spam or not spam), product categorization, and tagging.

To get started, you'll need a dataset that contains labeled data. I'll be using one from Kaggle. If you're not familiar with Kaggle, it's a data science community where people share datasets and machine learning code.

Preparing our dataset

We'll use the following dataset:

https://www.kaggle.com/datasets/thedevastator/product-prices-and-sizes-from-walmart-grocery

If you examine this dataset further, you'll notice it's a CSV file with "PRODUCT_NAME" and "CATEGORY" columns, which is exactly what we need to build our classifier.

index,SHIPPING_LOCATION,DEPARTMENT,CATEGORY,SUBCATEGORY,BREADCRUMBS,SKU,PRODUCT_URL,PRODUCT_NAME,BRAND,PRICE_RETAIL,PRICE_CURRENT,PRODUCT_SIZE,PROMOTION,RunDate,tid
0,79936,Deli,"Hummus, Dips, & Salsa",,"Deli/Hummus, Dips, & Salsa",110895339,https://www.walmart.com/ip/Marketside-Roasted-Red-Pepper-Hummus-10-Oz/110895339?fulfillmentIntent=Pickup,"Marketside Roasted Red Pepper Hummus, 10 Oz",Marketside,2.67,2.67,10,,2022-09-11 21:20:04,16163804
1,79936,Deli,"Hummus, Dips, & Salsa",,"Deli/Hummus, Dips, & Salsa",105455228,https://www.walmart.com/ip/Marketside-Roasted-Garlic-Hummus-10-Oz/105455228?fulfillmentIntent=Pickup,"Marketside Roasted Garlic Hummus, 10 Oz",Marketside,2.67,2.67,10,,2022-09-11 21:20:04,16163805
2,79936,Deli,"Hummus, Dips, & Salsa",,"Deli/Hummus, Dips, & Salsa",128642379,https://www.walmart.com/ip/Marketside-Classic-Hummus-10-Oz/128642379?fulfillmentIntent=Pickup,"Marketside Classic Hummus, 10 Oz",Marketside,2.67,2.67,10,,2022-09-11 21:20:04,16163806
3,79936,Deli,"Hummus, Dips, & Salsa",,"Deli/Hummus, Dips, & Salsa",366126367,https://www.walmart.com/ip/Marketside-Everything-Hummus-10-oz/366126367?fulfillmentIntent=Pickup,"Marketside Everything Hummus, 10 oz",Marketside,2.67,2.67,10,,2022-09-11 21:20:04,16163807

For loading and parsing this data, we'll use pandas, since it's pretty much the de facto standard in Python for handling large datasets and files:

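import re
import string

import pandas as pd

def clean_text(sentence):
    # Lowercase and trim, then strip ASCII punctuation.
    cleaned = sentence.lower().strip()
    cleaned = cleaned.translate(str.maketrans('', '', string.punctuation))
    # Replace any non-ASCII character with a space.
    cleaned = re.sub(r'[^\x00-\x7f]', ' ', cleaned)
    # Collapse every run of whitespace into a single space
    # (a single replace("  ", " ") misses runs of three or more).
    return " ".join(cleaned.split())

def loadDataset():
    df = pd.read_csv("./WMT_Grocery_202209.csv")
    # Skip rows with a missing category or product name.
    df = df.dropna(subset=["CATEGORY", "PRODUCT_NAME"])
    with open("./dataset.txt", 'w', encoding="utf-8") as f:
        for _, row in df.iterrows():
            # One fastText line per product: "__label__<category> <name>"
            f.write(
                "__label__%s %s\n" % (
                    clean_text(row['CATEGORY']).replace(" ", "_"),
                    clean_text(row['PRODUCT_NAME'])
                )
            )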

In the above code, we first load our dataset with pandas, then loop through each row and generate a fastText line item; essentially, we're converting the CSV into fastText's training format.

The format we generate is as follows:

__label__hummus_dips_salsa herdez chipotle salsa cremosa 153 oz
__label__hummus_dips_salsa lays french onion dip 23 oz jar
__label__hummus_dips_salsa yucatan authentic guacamole squeeze 14 oz pouch
__label__hummus_dips_salsa beaver brand hot cream horseradish 12 oz
__label__hummus_dips_salsa marketside smokehouse burnt ends dip 12 oz
__label__hummus_dips_salsa dolci frutta hard chocolate shell 8 oz
__label__hummus_dips_salsa diablo verde salsa medium creamy cilantro sauce 125 oz jar
__label__hummus_dips_salsa goverden spicy guacamole 12oz
__label__hummus_dips_salsa goverden classic guacamole 12oz
__label__energy_drinks red bull energy drink 84 fl oz 12 pack
__label__energy_drinks red bull energy drink 84 fl oz 4 pack
__label__energy_drinks monster energy green original 16 f

We also sanitize the text by stripping out punctuation and making everything lowercase, so that variants like "Hummus," and "hummus" become the same word. This helps fastText label our data more accurately.
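A quick way to check the conversion is to run the cleaning on one row by hand. The sketch below is a self-contained copy of the cleaning logic from the clean_text helper above (collapsing whitespace with split()/join(), which is my choice here), applied to the first sample row of the CSV. to_fasttext_line is a hypothetical helper; it takes a list of labels because fastText's format allows several __label__ prefixes on one line, which is how you would feed it genuinely multi-label examples.

```python
import re
import string

def clean_text(sentence):
    # Lowercase, drop ASCII punctuation, replace non-ASCII characters
    # with spaces, then collapse any run of whitespace into one space.
    cleaned = sentence.lower().strip()
    cleaned = cleaned.translate(str.maketrans("", "", string.punctuation))
    cleaned = re.sub(r"[^\x00-\x7f]", " ", cleaned)
    return " ".join(cleaned.split())

def to_fasttext_line(labels, text):
    # Hypothetical helper: each label becomes a __label__ token (spaces
    # turned into underscores), followed by the cleaned text.
    prefix = " ".join(
        "__label__" + clean_text(label).replace(" ", "_") for label in labels
    )
    return "%s %s" % (prefix, clean_text(text))

# First sample row from the CSV above.
print(to_fasttext_line(["Hummus, Dips, & Salsa"],
                       "Marketside Roasted Red Pepper Hummus, 10 Oz"))
# prints: __label__hummus_dips_salsa marketside roasted red pepper hummus 10 oz
```

Note that removing the "&" leaves a double space in "hummus dips  salsa"; collapsing whitespace before swapping spaces for underscores is what keeps the label as hummus_dips_salsa rather than hummus_dips__salsa.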

Building our classifier

fastText is awesome! As the name suggests, this library lets you train on millions of lines of text quickly, using just your regular old CPU.

I trained a classifier on a machine with an i5 processor and 12 GB of RAM. The dataset contained 10 million lines, and training took under 3 hours, which is relatively fast. Once training is done, classifying a piece of text takes mere milliseconds.

The one downside of fastText is that it's not very accurate with a small dataset; you need to train on a lot of examples (think millions), unlike an LLM, which may not even need fine-tuning depending on the classification task.

Still, if you need to classify hundreds of thousands or even millions of products, an LLM can get expensive very quickly.

Training fastText is really easy: we just call fasttext.train_supervised and pass it the input file we created earlier:

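import fasttext

def trainModel():
    # Train a supervised classifier on the "__label__<category> <text>" file.
    model = fasttext.train_supervised(
        input="dataset.txt",
        epoch=25,          # passes over the full dataset
        wordNgrams=3,      # use word n-grams up to length 3
        bucket=200000,     # hash buckets for word n-gram features
        dim=100,           # size of the word vectors
        loss="ova",        # one-vs-all: an independent score per label
        thread=6           # CPU threads used for training
    )

    model.save_model("./model")
    return model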

I won't go into too much detail on all the arguments passed to the trainer; you can read more about them in the fastText documentation. There are 3 important settings you need to know:

  • epoch: how many times the trainer loops through the entire dataset. Each pass gives the classifier another chance to correct itself and improve its accuracy. Adjust this based on the "avg.loss" value fastText prints in the terminal while it's training, increasing it to try to get the loss as low as possible. Keep in mind that a very low training loss can also mean the model is overfitting, i.e. memorizing the training data rather than generalizing.

  • wordNgrams: the maximum number of consecutive words the model looks at together. With 3, it uses single words, pairs, and triples as features, which helps it pick up meaning from phrases instead of individual words only. I find 3 works nicely for most use cases, but you can tune this based on your dataset and the avg.loss.

  • dim: the size of the word vectors. Under the hood, fastText vectorizes the text: it converts each word into a vector of numbers, which lets it run mathematical operations on the data to compare texts and determine how similar they are. dim controls how many numbers are in each vector; larger values can capture more nuance but make training slower and the model bigger.
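To make wordNgrams more concrete, here's a small sketch of the word n-gram features a setting of 3 implies: every single word, every adjacent pair, and every adjacent triple. It only illustrates the idea; fastText itself doesn't keep these strings around but hashes the multi-word n-grams into a fixed number of slots, which is what the bucket argument controls. word_ngrams is a hypothetical helper, not part of the fastText API.

```python
def word_ngrams(text, max_n):
    # Return every run of 1..max_n consecutive words, shortest first.
    words = text.split()
    grams = []
    for n in range(1, max_n + 1):
        for i in range(len(words) - n + 1):
            grams.append(" ".join(words[i:i + n]))
    return grams

features = word_ngrams("red bull energy drink", 3)
print(features)
# prints: ['red', 'bull', 'energy', 'drink', 'red bull', 'bull energy',
#          'energy drink', 'red bull energy', 'bull energy drink']
```

With wordNgrams=1 the model would only see the four individual words, so the phrase "red bull" carries no more weight than "red" and "bull" separately; higher values let phrases contribute, at the cost of more features to hash and train.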

Predicting labels

Now that we have a trained model, we can easily load this model and run some predictions:

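import fasttext

# Load the model once and reuse it; reloading it on every call is slow.
_model = None

def predict(text):
    global _model
    if _model is None:
        _model = fasttext.load_model("./model")

    # clean_text() is the helper from the data-prep step, so the input is
    # normalized exactly like the training data.
    labels, probabilities = _model.predict(clean_text(text))

    results = []
    for label, probability in zip(labels, probabilities):
        # fastText returns a probability for each label; it reflects the
        # model's confidence, not a measured accuracy.
        results.append({"label": label, "accuracy": float(probability) * 100})

    return results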

If I pass in "robert mondavi private selection chardonnay", I get back the following prediction:

[{'label': '__label__wine', 'accuracy': 16.452647745609283}]
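Because the model was trained with loss="ova" (one-vs-all), each label gets its own independent probability, so a product can be assigned more than one label. fastText's predict can return every label above a cutoff, e.g. model.predict(text, k=-1, threshold=0.5). The sketch below shows the equivalent post-processing in plain Python on a (labels, probabilities) pair shaped like predict's return value; labels_above is a hypothetical helper and the sample values are made up for illustration, not real model output.

```python
def labels_above(prediction, threshold=0.5):
    # prediction mirrors fastText's predict() output: a tuple of label
    # strings and a matching sequence of probabilities.
    labels, probabilities = prediction
    selected = []
    for label, probability in zip(labels, probabilities):
        if probability >= threshold:
            # Strip fastText's "__label__" prefix for display.
            selected.append((label.replace("__label__", "", 1), float(probability)))
    return selected

# Illustrative values only, not real model output.
example = (("__label__wine", "__label__beer", "__label__snacks"), (0.72, 0.55, 0.03))
print(labels_above(example))
# prints: [('wine', 0.72), ('beer', 0.55)]
```

Picking the threshold is a trade-off: a lower value returns more labels per product (higher recall), a higher one returns fewer but more confident labels.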

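As you play around with the predictions, you'll notice some odd variations. As pointed out earlier, fastText is not the most accurate of classifiers, and it all depends on your dataset: sometimes it performs really well and at other times really poorly. Also note that the "accuracy" value above is really the model's probability for that label (about 16% here), not a measure of how accurate the model is.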
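Rather than judging the model by a handful of predictions, it helps to hold out part of dataset.txt and measure it on lines it never saw during training. Here's a minimal sketch, assuming the one-example-per-line file produced earlier: it shuffles the lines with a fixed seed and writes two files. split_dataset, the file names, and the 10% split are my own arbitrary choices.

```python
import random

def split_dataset(path, train_path, valid_path, valid_fraction=0.1, seed=42):
    # Read every non-empty fastText line, making sure each ends in "\n".
    with open(path, encoding="utf-8") as f:
        lines = [line.rstrip("\n") + "\n" for line in f if line.strip()]
    if not lines:
        raise ValueError("no examples found in %s" % path)

    # Shuffle reproducibly so similar products don't cluster in one file.
    random.Random(seed).shuffle(lines)
    n_valid = max(1, int(len(lines) * valid_fraction))

    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(lines[:-n_valid])
    with open(valid_path, "w", encoding="utf-8") as f:
        f.writelines(lines[-n_valid:])

    return len(lines) - n_valid, n_valid

# Usage:
# train_count, valid_count = split_dataset("dataset.txt", "train.txt", "valid.txt")
```

You would then pass input="train.txt" to train_supervised and call model.test("valid.txt"), which returns the number of examples along with precision and recall at k=1. Comparing those numbers across different epoch or wordNgrams settings is a more reliable guide than the training loss alone.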

Alternative approaches

If fastText's accuracy isn't working for you, and using an LLM is too costly at scale, an alternative approach is to use vector embeddings. You could take labeled examples like the ones in the CSV file we used earlier, convert each product name into an embedding vector, and store those vectors in Qdrant or Redis.

Then, for each new product, embed its name and use cosine similarity to find the most similar stored product; that product's label becomes the prediction.
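The lookup step can be sketched without any external services. In a real setup the vectors would come from an embedding model and live in Qdrant or Redis; here, purely so the example runs on its own, a character-trigram count vector stands in for the embedding and a plain Python list stands in for the vector store. The helper names, product names, and categories are illustrative.

```python
import math
from collections import Counter

def embed(text):
    # Stand-in "embedding": counts of overlapping 3-character substrings.
    # A real system would call an embedding model here instead.
    t = text.lower()
    return Counter(t[i:i + 3] for i in range(len(t) - 2))

def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|) over sparse count vectors.
    dot = sum(count * b[gram] for gram, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

# Labeled examples (in practice these would be stored in Qdrant/Redis).
examples = [
    ("Apple Macbook PRO 256GB 2020 edition", "Macbooks"),
    ("Samsung Galaxy S21 128GB", "Phones"),
]
index = [(embed(name), category) for name, category in examples]

def classify(product_name):
    # Return the category of the most similar stored example.
    query = embed(product_name)
    _, best_category = max(index, key=lambda item: cosine_similarity(query, item[0]))
    return best_category

print(classify("Macbook PRO 256GB 13"))
# prints: Macbooks
```

Swapping embed() for a real embedding model and the list for a vector database changes where the vectors come from and how the nearest neighbor is found, but the classification rule stays the same: take the label of the closest labeled example.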

Example:

Product A: Apple Macbook PRO 256GB 2020 edition, Category: Macbooks

Product B: Macbook PRO 256GB 13”

Even though the titles differ, a vector embedding comparison should easily identify them as similar products and categorize Product B under “Macbooks”.

You can probably use an LLM to build the initial dataset of examples and feed that into Qdrant/Redis.

As you verify products, feed the data back into your Qdrant/Redis datastore; over time, the accuracy should keep improving as the store covers more of your catalog.


Kevin Coder | tutorials, thought experiments & tech ramblings
