Predict Price of Diamonds

Analysis
R-code
EDA
Modeling
Machine Learning
Author

Ajay Shankar A

Published

December 10, 2023

1 Predicting Diamond Prices

1.1 Introduction

Building a model to predict the price of diamonds using tidymodels.

The diamonds data set is readily available through the ggplot2 package in the tidyverse, and we will use it to predict diamond prices.

The data set records various parameters for each diamond, and each of these parameters may or may not affect the price.

library(tidyverse)
library(plotly)
library(DT)

data("diamonds")
diamonds
# A tibble: 53,940 × 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
 2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
 3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
 4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
 5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
 6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
 7  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
 8  0.26 Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
 9  0.22 Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
10  0.23 Very Good H     VS1      59.4    61   338  4     4.05  2.39
# ℹ 53,930 more rows

The data has over 50,000 observations, which is plenty for modeling.

1.2 Exploring the data

The diamonds data set is available to explore in the ggplot2 package, as mentioned above.

Let's check for NAs before exploring the data.

diamonds %>% map( ~sum(is.na(.))) %>% unlist()
  carat     cut   color clarity   depth   table   price       x       y       z 
      0       0       0       0       0       0       0       0       0       0 

It's good that there are no NAs, but we do have to be careful about zeros in the numeric columns.

diamonds %>% select(carat, x, y, z) %>% arrange(x, y, z)
# A tibble: 53,940 × 4
   carat     x     y     z
   <dbl> <dbl> <dbl> <dbl>
 1  1     0     0     0   
 2  1.14  0     0     0   
 3  1.56  0     0     0   
 4  1.2   0     0     0   
 5  2.25  0     0     0   
 6  0.71  0     0     0   
 7  0.71  0     0     0   
 8  1.07  0     6.62  0   
 9  0.2   3.73  3.68  2.31
10  0.2   3.73  3.71  2.33
# ℹ 53,930 more rows

A diamond cannot have an x (length), y (width), or z (depth) of 0 and still have weight, so these zeros are recording errors. Let's replace them with NA; alternatively, we could remove those rows entirely.

diamonds %>% mutate(x = if_else(x == 0, NA_real_, x),
                    y = if_else(y == 0, NA_real_, y),
                    z = if_else(z == 0, NA_real_, z)) %>% 
  datatable()
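
For reference, we can count how many zeros each dimension column contains (a quick check of my own, not part of the pipeline above):

# Number of zero values in each dimension column
diamonds %>% 
  summarise(across(c(x, y, z), ~ sum(. == 0)))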

Now let's visualize the distribution of the diamonds by carat.

frequency_poly <- diamonds %>% ggplot(aes(carat)) + geom_freqpoly(binwidth = 0.05)

ggplotly(frequency_poly)
Figure 1: Frequency polygon plot

From Figure 1 we can observe that:

- Most of the diamonds are between 0.2 and 1.5 carats.
- There are peaks at whole and common fractional carat values, meaning more diamonds cluster at those weights (the quick count below confirms this).
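
A quick count of the most common carat values (my own check, not from the original post) shows this clustering:

# The most frequent carat values sit at whole and common fractional weights
diamonds %>% 
  count(carat, sort = TRUE) %>% 
  head(10)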

My general understanding is that the weight, i.e. the carat, of a diamond influences the price the most. Let's visualize that.

diamonds %>% ggplot(aes(carat, price)) + geom_hex(bins = 50)

The price appears to grow exponentially with carat, so a log transformation should make the relationship roughly linear. We can confirm this with another graph.

diamonds %>% filter(carat < 2.5) %>% 
  mutate(log_price = log2(price),
         log_carat = log2(carat)) %>% 
  ggplot(aes(log_carat, log_price)) + geom_hex(bins = 50) +
  geom_smooth(method = "lm", formula = y ~ splines::bs(x, 3),
              se = FALSE, linewidth = 1.5)
Figure 2: Log of carat vs log of price (base 2)

Figure 2 shows that once we apply log2() to both price and carat, the relationship looks mostly linear.
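
As a rough numeric check of that linearity (my own sketch, not part of the original analysis), a straight-line fit on the log2 scale gives the slope of the relationship:

# Simple linear fit of log2(price) on log2(carat)
log_fit <- lm(log2(price) ~ log2(carat), data = diamonds)
coef(log_fit)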

diamonds %>% filter(carat <= 2.5) %>% ggplot(aes(carat, price)) +
  geom_point(alpha = 0.1, aes(color = price)) +
  geom_smooth(method = "lm", formula = y ~ splines::bs(x, 3),
              se = FALSE, linewidth = 1.5) +
  scale_color_viridis_c()

We can see that the price jumps when the weight reaches or just exceeds whole and common fractional carat values such as 0.5, 1.0, 1.5 and 2.
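
To make that jump concrete, here is a small check of my own (the cut points 0.9 and 1.1 are arbitrary) comparing diamonds just under and just over one carat:

# Median price just under vs at/over the 1-carat mark
diamonds %>% 
  filter(between(carat, 0.9, 1.1)) %>% 
  mutate(group = if_else(carat < 1, "0.90-0.99 ct", "1.00-1.10 ct")) %>% 
  group_by(group) %>% 
  summarise(n = n(), median_price = median(price))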

library(patchwork)

# Boxplot of price for a categorical parameter, with levels reordered by price;
# the point marks the mean price of each level
plot_parameter <- function(param){
  ggplot(diamonds, aes(fct_reorder({{param}}, price), price)) +
    geom_boxplot() + stat_summary(fun = mean, geom = "point") +
    labs(x = as_label(substitute(param)))
}

(plot_parameter(cut) + plot_parameter(color)) /
  (plot_parameter(clarity))

Low-quality diamonds with a Fair cut and low-quality color seem to have surprisingly high prices, most likely because those stones tend to be heavier. So now let's use tidymodels to model the data, including a rand_forest model.
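
A quick way to see this confounding (my own check, not from the original post) is to compare the average weight within each cut:

# Fair-cut diamonds tend to be heavier, which pushes their prices up
diamonds %>% 
  group_by(cut) %>% 
  summarise(mean_carat = mean(carat), median_price = median(price))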

1.3 Building a model

Most parameters in the data carry information about price, so we keep nearly all of the columns. The one exception is depth, which we drop because it is essentially derived from x, y and z; price and carat are also log2-transformed, as suggested by the EDA above.
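
As a quick check that depth is largely redundant (my own sketch; the ggplot2 documentation defines depth as the total depth percentage, 2 * z / (x + y)), we can recompute it from the dimensions:

# depth can be recomputed (approximately) from x, y and z,
# so it adds little information beyond those columns
diamonds %>% 
  filter(x > 0, y > 0) %>% 
  mutate(depth_calc = 2 * z / (x + y) * 100) %>% 
  select(depth, depth_calc) %>% 
  head()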

library(tidymodels)
set.seed(2023)

diamonds_2 <- diamonds %>% select(-depth) %>% 
  mutate(price = log2(price), carat = log2(carat))

diamonds_split <- initial_split(diamonds_2, strata = carat, prop = 0.8)
diamonds_split
<Training/Testing/Total>
<43150/10790/53940>
diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

I am stratifying on carat because the diamonds are not evenly distributed across weights, and diamonds of every weight should be well represented in both the training and testing sets.
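
To verify the stratification behaved as intended (a quick check of my own, not in the original post), we can compare the carat distribution in the two sets; remember carat is already on the log2 scale here.

# Compare carat quantiles (log2 scale) between training and testing sets
bind_rows(train = diamonds_train, test = diamonds_test, .id = "set") %>% 
  group_by(set) %>% 
  summarise(q25    = quantile(carat, 0.25),
            median = median(carat),
            q75    = quantile(carat, 0.75))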

diamonds_folds <- vfold_cv(diamonds_train, strata = carat)
diamonds_folds
#  10-fold cross-validation using stratification 
# A tibble: 10 × 2
   splits               id    
   <list>               <chr> 
 1 <split [38833/4317]> Fold01
 2 <split [38833/4317]> Fold02
 3 <split [38834/4316]> Fold03
 4 <split [38835/4315]> Fold04
 5 <split [38835/4315]> Fold05
 6 <split [38835/4315]> Fold06
 7 <split [38836/4314]> Fold07
 8 <split [38836/4314]> Fold08
 9 <split [38836/4314]> Fold09
10 <split [38837/4313]> Fold10

I think rand_forest will work better on this data set, but let's compare both linear regression and random forest models.

lm_spec <- linear_reg() %>% set_engine("glm")
lm_spec
Linear Regression Model Specification (regression)

Computational engine: glm 
rf_spec <- rand_forest(trees = 1000) %>% 
  set_mode("regression") %>% 
  set_engine("ranger")
rf_spec
Random Forest Model Specification (regression)

Main Arguments:
  trees = 1000

Computational engine: ranger 

We still need to preprocess parts of the data, such as normalizing the numeric predictors, creating dummy variables for the ordinal factors, and adding a spline term for carat, which can be done with the recipes package.

base_recp <- 
  recipe(price ~ ., data = diamonds_train) %>% 
  step_normalize(all_numeric_predictors())

ind_recp <- base_recp %>% 
  step_dummy(all_nominal_predictors())

spline_recp <- ind_recp %>% 
  step_bs(carat)
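
To see what these steps actually produce (a quick preview of my own using prep() and bake(), not shown in the original post), we can bake the spline recipe on a few training rows:

# Preview the fully processed predictors from the spline recipe
spline_recp %>% 
  prep() %>% 
  bake(new_data = head(diamonds_train))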

Next let's start putting together a tidymodels workflow_set(), a helper object that crosses our recipes with our model specifications and manages the modeling pipelines like Lego blocks that fit together.

diamonds_set <- 
  workflow_set(
    list(base_recp, ind_recp, spline_recp),
    list(lm_spec, rf_spec))

diamonds_set
# A workflow set/tibble: 6 × 4
  wflow_id             info             option    result    
  <chr>                <list>           <list>    <list>    
1 recipe_1_linear_reg  <tibble [1 × 4]> <opts[0]> <list [0]>
2 recipe_1_rand_forest <tibble [1 × 4]> <opts[0]> <list [0]>
3 recipe_2_linear_reg  <tibble [1 × 4]> <opts[0]> <list [0]>
4 recipe_2_rand_forest <tibble [1 × 4]> <opts[0]> <list [0]>
5 recipe_3_linear_reg  <tibble [1 × 4]> <opts[0]> <list [0]>
6 recipe_3_rand_forest <tibble [1 × 4]> <opts[0]> <list [0]>

Let's now fit the workflows to the resamples. A single call to workflow_map() runs fit_resamples() on each of the six recipe and model combinations.

doParallel::registerDoParallel()

diamonds_rs <- 
  workflow_map(
    diamonds_set,
    "fit_resamples",
    resamples = diamonds_folds
  )
diamonds_rs
# A workflow set/tibble: 6 × 4
  wflow_id             info             option    result   
  <chr>                <list>           <list>    <list>   
1 recipe_1_linear_reg  <tibble [1 × 4]> <opts[1]> <rsmp[+]>
2 recipe_1_rand_forest <tibble [1 × 4]> <opts[1]> <rsmp[+]>
3 recipe_2_linear_reg  <tibble [1 × 4]> <opts[1]> <rsmp[+]>
4 recipe_2_rand_forest <tibble [1 × 4]> <opts[1]> <rsmp[+]>
5 recipe_3_linear_reg  <tibble [1 × 4]> <opts[1]> <rsmp[+]>
6 recipe_3_rand_forest <tibble [1 × 4]> <opts[1]> <rsmp[+]>

1.4 Evaluating a model

We can evaluate the models using the autoplot() and collect_metrics() functions.

autoplot(diamonds_rs)

In the plot the gap between rand_forest and linear_reg looks very large, but the metrics table shows the actual difference is fairly small.

collect_metrics(diamonds_rs)
# A tibble: 12 × 9
   wflow_id         .config preproc model .metric .estimator  mean     n std_err
   <chr>            <chr>   <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
 1 recipe_1_linear… Prepro… recipe  line… rmse    standard   0.194    10 9.25e-4
 2 recipe_1_linear… Prepro… recipe  line… rsq     standard   0.983    10 1.58e-4
 3 recipe_1_rand_f… Prepro… recipe  rand… rmse    standard   0.137    10 1.17e-3
 4 recipe_1_rand_f… Prepro… recipe  rand… rsq     standard   0.991    10 1.33e-4
 5 recipe_2_linear… Prepro… recipe  line… rmse    standard   0.194    10 9.25e-4
 6 recipe_2_linear… Prepro… recipe  line… rsq     standard   0.983    10 1.58e-4
 7 recipe_2_rand_f… Prepro… recipe  rand… rmse    standard   0.140    10 1.44e-3
 8 recipe_2_rand_f… Prepro… recipe  rand… rsq     standard   0.991    10 1.65e-4
 9 recipe_3_linear… Prepro… recipe  line… rmse    standard   0.183    10 1.01e-3
10 recipe_3_linear… Prepro… recipe  line… rsq     standard   0.984    10 1.67e-4
11 recipe_3_rand_f… Prepro… recipe  rand… rmse    standard   0.137    10 1.06e-3
12 recipe_3_rand_f… Prepro… recipe  rand… rsq     standard   0.991    10 1.19e-4

From the metrics table we can see that recipe_1_rand_forest seems to perform the best: it has the lowest RMSE (essentially tied with recipe_3_rand_forest) and the highest R², while using the simplest recipe.
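
We can double-check that ranking with rank_results() from workflowsets (a quick check on my part; ranking by RMSE is my choice, not from the original post):

# Rank all six workflows by cross-validated RMSE
rank_results(diamonds_rs, rank_metric = "rmse") %>% 
  filter(.metric == "rmse") %>% 
  select(wflow_id, mean, std_err, rank)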

final_fit <-
  extract_workflow(diamonds_rs, "recipe_1_rand_forest") %>%
  fit(diamonds_train)


ranger_model <- extract_fit_parsnip(final_fit)
ranger_model
parsnip model object

Ranger result

Call:
 ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~1000,      num.threads = 1, verbose = FALSE, seed = sample.int(10^5,          1)) 

Type:                             Regression 
Number of trees:                  1000 
Sample size:                      43150 
Number of independent variables:  8 
Mtry:                             2 
Target node size:                 5 
Variable importance mode:         none 
Splitrule:                        variance 
OOB prediction error (MSE):       0.01859615 
R squared (OOB):                  0.9913472 

Let's use the final fitted model to predict on the test set.

final_predic <- predict(object = final_fit,
                        new_data = diamonds_test)

final_predic
# A tibble: 10,790 × 1
   .pred
   <dbl>
 1  8.85
 2  8.64
 3  8.66
 4  8.72
 5  8.70
 6  8.62
 7  8.87
 8  8.73
 9  8.83
10  8.63
# ℹ 10,780 more rows
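
Note that these predictions are on the log2 scale, since the outcome was transformed before splitting. Below is a minimal sketch (my addition, not from the original post) for computing test-set metrics with yardstick and back-transforming to dollar prices; the object name test_results is illustrative.

# Bind the predictions to the true (log2) prices from the test set
test_results <- diamonds_test %>% 
  select(price) %>% 
  bind_cols(final_predic)

# RMSE, R-squared and MAE on the log2 scale
test_results %>% metrics(truth = price, estimate = .pred)

# Back-transform to dollars with 2^x
test_results %>% 
  mutate(price_dollars = 2^price,
         pred_dollars  = 2^.pred) %>% 
  head()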