## Regression using the embeddings

Regression means predicting a number rather than a category. Here we predict a review's score from the embedding of its text. We split the dataset into a training set and a test set for all of the following tasks, so that we can realistically evaluate performance on unseen data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).

We're predicting the score of the review, which is a number between 1 and 5 (1-star being negative and 5-star positive).

In [2]:
import ast

import pandas as pd
import numpy as np

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error

df = pd.read_csv('output/embedded_1k_reviews.csv')
# Embeddings are stored as stringified lists; literal_eval parses them without executing arbitrary code
df['babbage_similarity'] = df.babbage_similarity.apply(ast.literal_eval).apply(np.array)

# Hold out 20% of the reviews for evaluation
X_train, X_test, y_train, y_test = train_test_split(list(df.babbage_similarity.values), df.Score, test_size=0.2, random_state=42)

# Fit a random forest on the embeddings and predict scores for the held-out reviews
rfr = RandomForestRegressor(n_estimators=100)
rfr.fit(X_train, y_train)
preds = rfr.predict(X_test)

mse = mean_squared_error(y_test, preds)
mae = mean_absolute_error(y_test, preds)

print(f"Babbage similarity embedding performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}")

Babbage similarity embedding performance on 1k Amazon reviews: mse=0.38, mae=0.39


In [26]:
# Baseline: predict the mean test score for every review
bmse = mean_squared_error(y_test, np.repeat(y_test.mean(), len(y_test)))
bmae = mean_absolute_error(y_test, np.repeat(y_test.mean(), len(y_test)))
print(f"Dummy mean prediction performance on Amazon reviews: mse={bmse:.2f}, mae={bmae:.2f}")

Dummy mean prediction performance on Amazon reviews: mse=1.77, mae=1.04
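
The baseline above takes its mean from the test labels themselves. A stricter variant learns the mean from the training labels only; one way to do that is scikit-learn's `DummyRegressor`. The sketch below uses synthetic scores as a stand-in for the review data (an assumption made so the block runs on its own), so its numbers are not comparable to the ones above.

```python
import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Synthetic stand-in scores (assumption): random integers from 1 to 5.
rng = np.random.default_rng(0)
y_train_demo = rng.integers(1, 6, size=800)
y_test_demo = rng.integers(1, 6, size=200)

# DummyRegressor ignores the features, so a zero placeholder is enough.
X_train_demo = np.zeros((len(y_train_demo), 1))
X_test_demo = np.zeros((len(y_test_demo), 1))

# strategy="mean" always predicts the mean of the training targets.
baseline = DummyRegressor(strategy="mean")
baseline.fit(X_train_demo, y_train_demo)
baseline_preds = baseline.predict(X_test_demo)

bmse_demo = mean_squared_error(y_test_demo, baseline_preds)
bmae_demo = mean_absolute_error(y_test_demo, baseline_preds)
print(f"Train-mean baseline on synthetic scores: mse={bmse_demo:.2f}, mae={bmae_demo:.2f}")
```

With the real data, the same estimator could be fit on `X_train, y_train` and scored on `X_test, y_test`.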


We can see that the embeddings predict the scores with a mean absolute error of 0.39 stars per prediction, compared with 1.04 for the baseline that always predicts the mean score. This is roughly equivalent to predicting two out of three reviews exactly and missing the remaining one by one star.
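
To check that rule of thumb, the arithmetic can be worked through on a hypothetical error pattern in which two of every three predictions are exact and one is off by a single star. This is only an illustration of the comparison, not how the regressor's errors are actually distributed, since it outputs continuous values.

```python
# Hypothetical error pattern: out of 3 reviews, 2 are predicted exactly
# and 1 is off by one star.
abs_errors = [0, 0, 1]

mae_example = sum(abs_errors) / len(abs_errors)
mse_example = sum(e ** 2 for e in abs_errors) / len(abs_errors)

print(f"mae={mae_example:.2f}, mse={mse_example:.2f}")
# prints: mae=0.33, mse=0.33
```

Both values land close to the measured mae=0.39 and mse=0.38, which is why the comparison is a reasonable approximation.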

You could also train a classifier to predict the score as a discrete label, or use the embeddings within an existing ML model to encode free-text features.
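
As a minimal sketch of the classifier option, the block below treats the star rating as a class label and fits a `RandomForestClassifier`. The "embeddings" and labels are synthetic stand-ins (an assumption, as are all the `*_demo` names) so that the block runs without the CSV; with the real data, `X_train` and `y_train` from the split above could be used instead.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-ins (assumption): 500 random 16-dimensional "embeddings"
# with star ratings 1-5 loosely tied to the first dimension.
rng = np.random.default_rng(42)
X_demo = rng.normal(size=(500, 16))
y_demo = np.clip(np.round(3 + 1.5 * X_demo[:, 0]), 1, 5).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42
)

# Each star rating is treated as a separate class.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_tr, y_tr)
label_preds = clf.predict(X_te)

acc = accuracy_score(y_te, label_preds)
print(f"accuracy on synthetic data: {acc:.2f}")
```

Unlike the regressor, the classifier only ever outputs one of the ratings seen during training, which suits use cases that need a discrete star value.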