Introduction

Motor vehicle accidents are an important part of traffic safety research. Analyzing the factors contributing to accidents and accident severity is critical for enhancing road safety standards. In this post, traffic accident data patterns will be explored and studied using machine-learning analysis techniques. Data processing, visualization, and modelling will be conducted using the R language for statistical computing and visualization.

Preliminaries

We’ll start by loading the main R libraries that will be used for data processing, visualization, and modelling.

# Load libraries
library(tidyverse)
library(scales)
library(lubridate)
library(plotly)
library(gridExtra)
library(tidytext)
library(modelr)
library(caret)
library(ROSE)
library(glmnet)
library(rpart)
library(rpart.plot)
library(randomForest)

Let’s create a plot theme to streamline generation of graphical presentations of the data using the ggplot2 package.

my_theme <- function(){ 
    
  theme_bw() %+replace%    # replace elements we want to change
    
  theme(
    plot.title = element_text(
      color = 'black',
      size = 12,
      face = 'bold',
      hjust = 0
     ),
     plot.subtitle = element_text(
       color = 'black',
       size = 12,
       face = 'bold',
       hjust = 0
     ),
     axis.title = element_text(
       color = 'black',
       size = 11
     ),
     axis.text = element_text(
       color = 'black',
       size = 11
     )
  )
}

Load Data

The data consist of traffic accident records across the contiguous United States from January 2016 through March 2023. The data were collected from several providers, including multiple APIs that stream traffic event data. These APIs broadcast traffic events captured by a variety of entities, such as the US and state departments of transportation, law enforcement agencies, traffic cameras, and traffic sensors within the road networks. The dataset currently contains about 7.7 million accident records. For more information about this dataset, please visit here.

# Load data
dat <- read_csv(
  'data/US_Accidents_March23.csv',
  col_types = cols(.default = col_character())
) %>% 
type_convert()

Let’s inspect the general structure of the data.

# Inspect data
str(dat)
## spc_tbl_ [7,728,394 × 46] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ ID                   : chr [1:7728394] "A-1" "A-2" "A-3" "A-4" ...
##  $ Source               : chr [1:7728394] "Source2" "Source2" "Source2" "Source2" ...
##  $ Severity             : num [1:7728394] 3 2 2 3 2 3 2 3 2 3 ...
##  $ Start_Time           : POSIXct[1:7728394], format: "2016-02-08 05:46:00" "2016-02-08 06:07:59" ...
##  $ End_Time             : POSIXct[1:7728394], format: "2016-02-08 11:00:00" "2016-02-08 06:37:59" ...
##  $ Start_Lat            : num [1:7728394] 39.9 39.9 39.1 39.7 39.6 ...
##  $ Start_Lng            : num [1:7728394] -84.1 -82.8 -84 -84.2 -84.2 ...
##  $ End_Lat              : num [1:7728394] NA NA NA NA NA NA NA NA NA NA ...
##  $ End_Lng              : num [1:7728394] NA NA NA NA NA NA NA NA NA NA ...
##  $ Distance(mi)         : num [1:7728394] 0.01 0.01 0.01 0.01 0.01 0.01 0 0.01 0 0.01 ...
##  $ Description          : chr [1:7728394] "Right lane blocked due to accident on I-70 Eastbound at Exit 41 OH-235 State Route 4." "Accident on Brice Rd at Tussing Rd. Expect delays." "Accident on OH-32 State Route 32 Westbound at Dela Palma Rd. Expect delays." "Accident on I-75 Southbound at Exits 52 52B US-35. Expect delays." ...
##  $ Street               : chr [1:7728394] "I-70 E" "Brice Rd" "State Route 32" "I-75 S" ...
##  $ City                 : chr [1:7728394] "Dayton" "Reynoldsburg" "Williamsburg" "Dayton" ...
##  $ County               : chr [1:7728394] "Montgomery" "Franklin" "Clermont" "Montgomery" ...
##  $ State                : chr [1:7728394] "OH" "OH" "OH" "OH" ...
##  $ Zipcode              : chr [1:7728394] "45424" "43068-3402" "45176" "45417" ...
##  $ Country              : chr [1:7728394] "US" "US" "US" "US" ...
##  $ Timezone             : chr [1:7728394] "US/Eastern" "US/Eastern" "US/Eastern" "US/Eastern" ...
##  $ Airport_Code         : chr [1:7728394] "KFFO" "KCMH" "KI69" "KDAY" ...
##  $ Weather_Timestamp    : POSIXct[1:7728394], format: "2016-02-08 05:58:00" "2016-02-08 05:51:00" ...
##  $ Temperature(F)       : num [1:7728394] 36.9 37.9 36 35.1 36 37.9 34 34 33.3 37.4 ...
##  $ Wind_Chill(F)        : num [1:7728394] NA NA 33.3 31 33.3 35.5 31 31 NA 33.8 ...
##  $ Humidity(%)          : num [1:7728394] 91 100 100 96 89 97 100 100 99 100 ...
##  $ Pressure(in)         : num [1:7728394] 29.7 29.6 29.7 29.6 29.6 ...
##  $ Visibility(mi)       : num [1:7728394] 10 10 10 9 6 7 7 7 5 3 ...
##  $ Wind_Direction       : chr [1:7728394] "Calm" "Calm" "SW" "SW" ...
##  $ Wind_Speed(mph)      : num [1:7728394] NA NA 3.5 4.6 3.5 3.5 3.5 3.5 1.2 4.6 ...
##  $ Precipitation(in)    : num [1:7728394] 0.02 0 NA NA NA 0.03 NA NA NA 0.02 ...
##  $ Weather_Condition    : chr [1:7728394] "Light Rain" "Light Rain" "Overcast" "Mostly Cloudy" ...
##  $ Amenity              : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Bump                 : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Crossing             : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Give_Way             : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Junction             : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ No_Exit              : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Railway              : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Roundabout           : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Station              : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Stop                 : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Traffic_Calming      : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Traffic_Signal       : logi [1:7728394] FALSE FALSE TRUE FALSE TRUE FALSE ...
##  $ Turning_Loop         : logi [1:7728394] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Sunrise_Sunset       : chr [1:7728394] "Night" "Night" "Night" "Night" ...
##  $ Civil_Twilight       : chr [1:7728394] "Night" "Night" "Night" "Day" ...
##  $ Nautical_Twilight    : chr [1:7728394] "Night" "Night" "Day" "Day" ...
##  $ Astronomical_Twilight: chr [1:7728394] "Night" "Day" "Day" "Day" ...
##  - attr(*, "problems")=<externalptr>

Let’s use the skimr package to get a more detailed summary of the data.

# Summary of data using skimr
skimr::skim(dat)
Data summary
Name                     dat
Number of rows           7728394
Number of columns        46
Column type frequency:
  character              17
  logical                13
  numeric                13
  POSIXct                3
Group variables          None

Variable type: character

skim_variable          n_missing  complete_rate  min  max  empty  n_unique  whitespace
ID                             0           1.00    3    9      0   7728394           0
Source                         0           1.00    7    7      0         3           0
Description                    5           1.00    2  678      0   3761570           0
Street                     10869           1.00    1   59      0    249364           0
City                         253           1.00    3   32      0     13678           0
County                         0           1.00    3   30      0      1871           0
State                          0           1.00    2    2      0        49           0
Zipcode                     1915           1.00    5   10      0    825094           0
Country                        0           1.00    2    2      0         1           0
Timezone                    7808           1.00   10   11      0         4           0
Airport_Code               22635           1.00    4    4      0      2045           0
Wind_Direction            175206           0.98    1    8      0        24           0
Weather_Condition         173459           0.98    3   35      0       144           0
Sunrise_Sunset             23246           1.00    3    5      0         2           0
Civil_Twilight             23246           1.00    3    5      0         2           0
Nautical_Twilight          23246           1.00    3    5      0         2           0
Astronomical_Twilight      23246           1.00    3    5      0         2           0

Variable type: logical

skim_variable    n_missing  complete_rate  mean  count
Amenity                  0              1  0.01  FAL: 7632060, TRU: 96334
Bump                     0              1  0.00  FAL: 7724880, TRU: 3514
Crossing                 0              1  0.11  FAL: 6854631, TRU: 873763
Give_Way                 0              1  0.00  FAL: 7691812, TRU: 36582
Junction                 0              1  0.07  FAL: 7157052, TRU: 571342
No_Exit                  0              1  0.00  FAL: 7708849, TRU: 19545
Railway                  0              1  0.01  FAL: 7661415, TRU: 66979
Roundabout               0              1  0.00  FAL: 7728145, TRU: 249
Station                  0              1  0.03  FAL: 7526493, TRU: 201901
Stop                     0              1  0.03  FAL: 7514023, TRU: 214371
Traffic_Calming          0              1  0.00  FAL: 7720796, TRU: 7598
Traffic_Signal           0              1  0.15  FAL: 6584622, TRU: 1143772
Turning_Loop             0              1  0.00  FAL: 7728394

Variable type: numeric

skim_variable      n_missing  complete_rate    mean     sd       p0      p25     p50     p75     p100  hist
Severity                   0           1.00    2.21   0.49     1.00     2.00    2.00    2.00     4.00  ▁▇▁▂▁
Start_Lat                  0           1.00   36.20   5.08    24.55    33.40   35.82   40.08    49.00  ▂▇▇▆▂
Start_Lng                  0           1.00  -94.70  17.39  -124.62  -117.22  -87.77  -80.35   -67.11  ▆▁▃▇▅
End_Lat              3402762           0.56   36.26   5.27    24.57    33.46   36.18   40.18    49.08  ▃▇▇▇▂
End_Lng              3402762           0.56  -95.73  18.11  -124.55  -117.75  -88.03  -80.25   -67.11  ▇▁▃▇▅
Distance(mi)               0           1.00    0.56   1.78     0.00     0.00    0.03    0.46   441.75  ▇▁▁▁▁
Temperature(F)        163853           0.98   61.66  19.01   -89.00    49.00   64.00   76.00   207.00  ▁▁▇▁▁
Wind_Chill(F)        1999019           0.74   58.25  22.39   -89.00    43.00   62.00   75.00   207.00  ▁▁▇▁▁
Humidity(%)           174144           0.98   64.83  22.82     1.00    48.00   67.00   84.00   100.00  ▁▃▆▇▇
Pressure(in)          140679           0.98   29.54   1.01     0.00    29.37   29.86   30.03    58.63  ▁▁▇▁▁
Visibility(mi)        177098           0.98    9.09   2.69     0.00    10.00   10.00   10.00   140.00  ▇▁▁▁▁
Wind_Speed(mph)       571233           0.93    7.69   5.42     0.00     4.60    7.00   10.40  1087.00  ▇▁▁▁▁
Precipitation(in)    2203586           0.71    0.01   0.11     0.00     0.00    0.00    0.00    36.47  ▇▁▁▁▁

Variable type: POSIXct

skim_variable      n_missing  complete_rate  min                  max                  median               n_unique
Start_Time                 0           1.00  2016-01-14 20:18:33  2023-03-31 23:30:00  2020-11-11 08:40:22   5801064
End_Time                   0           1.00  2016-02-08 06:37:08  2023-03-31 23:59:00  2020-11-11 15:56:29   6463024
Weather_Timestamp     120228           0.98  2016-01-14 19:51:00  2023-03-31 23:53:00  2020-11-10 13:53:00    941331

Let’s use the plot_missing() function from the DataExplorer package to examine the data for missing values.

# Inspect missing data using DataExplorer
DataExplorer::plot_missing(dat)

Prep Data

1. Drop unnecessary variables

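In R, missing values are represented by the symbol NA (not available). Let’s remove variables with a high proportion of NA values. According to the summary above, the end coordinates are missing for about 44% of records, precipitation for about 29%, and wind chill for about 26%. Assuming that the values are missing at random, dropping these variables should not adversely affect our analyses.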

These are the variables to drop:

# Create list of variables to drop
drop_na_cols <- c(
  'End_Lat', 'End_Lng', 'Wind_Chill(F)', 'Precipitation(in)'
)
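
The drop list above could also be derived from the data instead of being typed by hand. The sketch below is a minimal illustration on a toy data frame; the names toy, na_share, and high_na_cols and the 25% cutoff are assumptions made for this example, not part of the original analysis. Based on the skimr summary above, a 25% cutoff would select the same four variables.

```r
# Toy data frame standing in for `dat` (hypothetical values)
toy <- data.frame(
  End_Lat     = c(NA, NA, 39.1, NA),
  Temperature = c(36.9, 37.9, NA, 35.1),
  Severity    = c(3, 2, 2, 3)
)

# Share of missing values in each column
na_share <- colMeans(is.na(toy))

# Columns whose missing share exceeds an assumed cutoff of 25%
high_na_cols <- names(na_share)[na_share > 0.25]
high_na_cols
```

On the full data, the same two lines applied to dat would return candidate columns for drop_na_cols.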

A few variables like “ID”, “Source”, and “Timezone” provide little or no insight about traffic accidents. Let’s remove these variables from the data.

# Create list of non useful variables
not_useful <- c(
  'ID', 'Source', 'Timezone', 'Airport_Code', 'Weather_Timestamp', 
  'Wind_Direction', 'Description'
)

# Drop high NA variables and non useful variables
dat_drop <- dat %>% select(-all_of(drop_na_cols), -all_of(not_useful))

2. Rename variables

Let’s rename a few variables to avoid potential naming convention errors during the analyses.

dat_drop <-  dat_drop %>%
  rename(
    Distance = 'Distance(mi)',
    Temperature = 'Temperature(F)', 
    Humidity = 'Humidity(%)', 
    Pressure = 'Pressure(in)',
    Visibility = 'Visibility(mi)',
    Wind_Speed = 'Wind_Speed(mph)'
)
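
The steps that follow operate on a data frame named dat_time and later use the columns Duration, Year, Month, Day, Hour, and Wday, none of which are created in the code shown so far. Here is a minimal sketch of how such variables could be derived, assuming they come from Start_Time and End_Time. The helper name add_time_vars and the exact definitions are assumptions, chosen to match how the columns are used later (Hour as a two-digit string, Wday with 1 = Sunday, Duration taken to be in minutes).

```r
library(dplyr)
library(lubridate)

# Hypothetical helper: derive time variables from Start_Time and End_Time.
# Column names match those used later in the post; the definitions are assumptions.
add_time_vars <- function(df) {
  df %>%
    mutate(
      Duration = as.numeric(difftime(End_Time, Start_Time, units = 'mins')),
      Year  = format(Start_Time, '%Y'),
      Month = format(Start_Time, '%m'),
      Day   = format(Start_Time, '%d'),
      Hour  = format(Start_Time, '%H'),
      Wday  = as.character(wday(Start_Time))  # 1 = Sunday, 7 = Saturday
    )
}

# Toy example (hypothetical record)
toy <- tibble(
  Start_Time = as.POSIXct('2016-02-08 05:46:00', tz = 'UTC'),
  End_Time   = as.POSIXct('2016-02-08 11:00:00', tz = 'UTC')
)
toy_time <- add_time_vars(toy)

# In the pipeline this would read: dat_time <- add_time_vars(dat_drop)
```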

4. Drop missing weather condition

When the weather condition is missing, there is a good chance that other weather-related variables will also be missing. Let’s check whether this is the case.

dat_time %>%
  filter(is.na(Weather_Condition)) %>%
  select(Temperature:Weather_Condition) %>%
  head(10) 
## # A tibble: 10 × 6
##    Temperature Humidity Pressure Visibility Wind_Speed Weather_Condition
##          <dbl>    <dbl>    <dbl>      <dbl>      <dbl> <chr>            
##  1        48.2       93     29.5         10        9.2 <NA>             
##  2        NA         NA     NA           NA       NA   <NA>             
##  3        95         20     29.9         10        6.9 <NA>             
##  4        91.4       28     29.9         10       15   <NA>             
##  5        NA         NA     NA           NA       NA   <NA>             
##  6        NA         NA     NA           NA       NA   <NA>             
##  7        NA         NA     NA           NA       NA   <NA>             
##  8        NA         NA     NA           NA       NA   <NA>             
##  9        NA         NA     NA           NA       NA   <NA>             
## 10        NA         NA     NA           NA       NA   <NA>

This holds for most of the records shown (7 of 10), so let’s remove all records where the weather condition is NA.

dat_weather <- dat_time %>% filter(!is.na(Weather_Condition))
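
The next step operates on a data frame named dat_add, which is not created in the code shown here. The final model formula later in the post contains none of the street-level address fields or raw timestamps, so one plausible intermediate step is to drop them. The sketch below is an assumption inferred from that formula; the column list drop_location and the toy data are hypothetical.

```r
library(dplyr)

# Assumed list of address and timestamp columns to drop (inferred, not from the post)
drop_location <- c('Street', 'City', 'County', 'Zipcode', 'Country',
                   'Start_Time', 'End_Time')

# Toy data frame standing in for dat_weather (hypothetical values)
toy_weather <- tibble(
  Street = 'I-70 E', City = 'Dayton', County = 'Montgomery',
  Zipcode = '45424', Country = 'US', State = 'OH',
  Start_Time = as.POSIXct('2016-02-08 05:46:00', tz = 'UTC'),
  End_Time   = as.POSIXct('2016-02-08 11:00:00', tz = 'UTC'),
  Severity = 3
)

# any_of() skips names that are absent instead of raising an error
toy_add <- toy_weather %>% select(-any_of(drop_location))

# In the pipeline this would read:
# dat_add <- dat_weather %>% select(-any_of(drop_location))
```

State is kept because it is needed later to subset the data on California.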

6. Modify variable type

Let’s convert severity to a categorical variable and logical variables to character.

dat_add <- dat_add %>% 
  mutate(
    Severity = as.character(Severity)
  ) %>% 
  mutate_if(is.logical, as.character)

7. Handle missing values in continuous variables

Let’s replace missing continuous variables with the mean.

# Replace missing continuous variables with the mean.
dat_mean <- dat_add %>%
  mutate_if(is.numeric, ~ replace_na(., mean(., na.rm = TRUE)))

# Inspect continuous variables after replacement
summary(dat_mean %>% select_if(is.numeric))
##     Duration           Start_Lat       Start_Lng          Distance       
##  Min.   :      1.2   Min.   :24.55   Min.   :-124.62   Min.   :  0.0000  
##  1st Qu.:     31.4   1st Qu.:33.38   1st Qu.:-117.22   1st Qu.:  0.0000  
##  Median :     74.8   Median :35.79   Median : -87.79   Median :  0.0290  
##  Mean   :    445.3   Mean   :36.18   Mean   : -94.71   Mean   :  0.5577  
##  3rd Qu.:    125.0   3rd Qu.:40.10   3rd Qu.: -80.38   3rd Qu.:  0.4600  
##  Max.   :2812939.0   Max.   :49.00   Max.   : -67.11   Max.   :441.7500  
##   Temperature        Humidity         Pressure       Visibility    
##  Min.   :-89.00   Min.   :  1.00   Min.   : 0.00   Min.   :  0.00  
##  1st Qu.: 49.00   1st Qu.: 48.00   1st Qu.:29.37   1st Qu.: 10.00  
##  Median : 64.00   Median : 67.00   Median :29.86   Median : 10.00  
##  Mean   : 61.68   Mean   : 64.84   Mean   :29.54   Mean   :  9.09  
##  3rd Qu.: 76.00   3rd Qu.: 84.00   3rd Qu.:30.03   3rd Qu.: 10.00  
##  Max.   :207.00   Max.   :100.00   Max.   :58.63   Max.   :140.00  
##    Wind_Speed      
##  Min.   :   0.000  
##  1st Qu.:   4.600  
##  Median :   7.000  
##  Mean   :   7.686  
##  3rd Qu.:  10.000  
##  Max.   :1087.000
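
In current versions of dplyr, mutate_if() is superseded by across(). The sketch below shows the same mean imputation written with across() on a toy tibble; the names toy and toy_mean are hypothetical.

```r
library(dplyr)
library(tidyr)

# Toy data with one numeric and one character column (hypothetical values)
toy <- tibble(
  Temperature = c(36.9, NA, 36.0),
  Weather_Condition = c('Light Rain', NA, 'Overcast')
)

# Replace NA in numeric columns with the column mean; other columns are untouched
toy_mean <- toy %>%
  mutate(across(where(is.numeric), ~ replace_na(.x, mean(.x, na.rm = TRUE))))
toy_mean
```

Note that the mean is sensitive to extreme values such as the maximum temperature of 207 and the maximum wind speed of 1087 in the summary above. Replacing mean() with median() in the same expression is a common alternative.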

8. Handle missing values in categorical variables

Let’s inspect the processed data for NA values associated with the categorical variables.

# Inspect for NA values in categorical variables.
DataExplorer::plot_missing(dat_mean)

Let’s remove those few records that include NA’s.

# Remove rows with NA's using drop_na()
dat_final <- dat_mean %>% drop_na()

Visualize Data

Let’s explore the final processed data to better understand the characteristics of the data and patterns that may be contained in the data.

# Load processed data
dat <- dat_final %>%
mutate(
  Severity = factor(Severity),
  Year = factor(Year),
  Wday = factor(Wday)
) %>%
mutate_if(is.logical, factor) %>%
mutate_if(is.character, factor)

1. Top 10 states

Let’s create a map showing the 10 states that have the highest number of accidents.

# Transform data from the maps package into a data frame
states <- map_data('state') %>%
  as_tibble() %>%
  select(long, lat, group, region)

# Load states with abbreviations from csv file
states_abb <- read_csv('data/states.csv') %>%
  mutate(State = tolower(State)) %>%
  select(State, Code) %>%
  rename(State_full = State)

# Get accident count by states
accident_count <- dat %>%
  count(State) %>%
  left_join(states_abb, by = c('State' = 'Code'))

# Add accident count to states data frame
states <- states %>%
  left_join(accident_count, by = c('region' = 'State_full'))

# Get top 10 states
top_10 <- accident_count %>%
  arrange(desc(n)) %>%
  head(10)

top_10 <- top_10$State %>% unlist()

top_10_map <- states %>%
  filter(State %in% top_10)

top_10_label <- top_10_map %>%
  group_by(region, State) %>%
  summarize(long = mean(long), lat = mean(lat)) %>%
  ungroup()

# Create map
ggplot(states, aes(x = long, y = lat, group = group)) +
  geom_polygon(aes(fill = n), color = '#636363', linewidth = 0.1) +
  geom_polygon(data = top_10_map, color = 'red', fill = NA, linewidth = 0.8) +
  scale_fill_gradient(low = '#fee5d9', high = '#de2d26',
                      name = 'Accident Count',
                      labels = unit_format(unit = 'K', scale = 1e-03)) +
  ggrepel::geom_label_repel(data = top_10_label, aes(label = State, group = 1)) +
  theme_minimal() +
  coord_quickmap() +
  labs(title = 'Accident distribution in the U.S.',
       x = 'Longitude',
       y = 'Latitude')

Now, let’s create a bar chart showing the accident count for these 10 states.

# Create bar chart showing accident count for top 10 states.
# Bar labels use the state names joined from states.csv, so they always match the data.
accident_count %>% 
  filter(State %in% top_10) %>%
  ggplot(aes(reorder(str_to_title(State_full), n), n)) +
  geom_col() +
  geom_label(aes(label = n), nudge_y = -30000) +
  labs(title = 'Top 10 States with the most accidents',
       x = NULL,
       y = 'Number of accidents') +
  scale_y_continuous(breaks = seq(0, 1800000, 200000),
                     labels = unit_format(unit = 'K', scale = 1e-03)) +
  coord_flip() +
  my_theme()

2. Accident distance

The “Distance” variable in the data is the length of the road extent affected by the accident. Let’s inspect distance and severity levels to see if there is a relationship between the two variables.

# Create bar chart showing distance as a function of severity
dat %>%
  group_by(Severity) %>%
  summarize(prop = mean(Distance)) %>%
  ggplot(aes(Severity, prop, fill = !Severity %in% c(3, 4))) +
    geom_col() +
    labs(title = 'More severe accidents tend to affect longer road distance',
         y = 'Average affected distance (mi)') +
    scale_fill_discrete(name = 'Severity',
                        labels = c('More Severe: 3 or 4', 'Less Severe: 1 or 2')) +
  my_theme()

3. Accident count

A plot of the severity level distribution in each year is shown below. Based on the plot, the proportion of severity level 2 accidents increased between 2018 and 2023 while the proportion of severity level 3 accidents decreased during the same period. Most of the accidents are classified as either level 2 or level 3.

# Create bar chart showing severity level distribution in each year
dat %>%
  group_by(Year, Severity) %>%
  count() %>%
  group_by(Year) %>%
  mutate(sum = sum(n)) %>%
  mutate(Proportion = n / sum) %>%
  ggplot(aes(Severity, Proportion)) +
  geom_col(aes(fill = Year), position = 'dodge') +
  labs(title = 'Severity proportion changes by year',
       x = 'Severity',
       y = 'Proportion') +
  scale_y_continuous(labels = percent) +
  my_theme()

4. Temporal behavior

A plot of accident count by day and month is shown below. Based on the plot, accident count increases between July and December and decreases between January and July. More accidents occur during weekdays compared to weekends.

# Create line plot of accident count by month
g_top <- dat %>%
  count(Month) %>%
  ggplot(aes(Month, n)) +
  geom_line(aes(group = 1)) +
  geom_point() +
  labs(title = 'Pattern between accident counts and month & day of the week',
       x = NULL,
       y = 'Count') +
  scale_x_discrete(labels = c('Jan', 'Feb', 'Mar', 'Apr', 'May',
                              'Jun', 'Jul', 'Aug', 'Sep', 'Oct',
                              'Nov', 'Dec')) +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  my_theme()

# Create bar chart of accident count by day and month
g_bottom <- dat %>%
  ggplot(aes(Month, fill = Wday)) +
  geom_bar(position = 'dodge') +
  scale_fill_manual(values = c('deepskyblue1', 'coral1', 'coral1', 'coral1',
                               'coral1', 'coral1', 'deepskyblue1'),
                    name = 'Day of the week',
                    labels = c('Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat')) +
  guides(fill = guide_legend(nrow = 1)) +
  scale_x_discrete(labels = c('Jan', 'Feb', 'Mar', 'Apr', 'May',
                              'Jun', 'Jul', 'Aug', 'Sep', 'Oct',
                              'Nov', 'Dec')) +
  labs(y = 'Count') +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  my_theme() +
  theme(legend.position = 'bottom')

# Combine plots
grid.arrange(g_top, g_bottom, heights = c(1/4, 3/4))

A plot of accident count by hour of the day is shown below. Based on the plot, most accidents occur between 7 am and 8 am and between 4 pm and 5 pm.

# Create bar chart of hourly accident count
left <- dat %>%
  ggplot(aes(Hour, fill = !Hour %in% c('07', '08', '16', '17'))) +
    geom_bar(show.legend = F) +
    labs(title = 'Hourly Distribution of Accidents',
         x = 'Hour',
         y = 'No of Accidents') +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  my_theme()

# Create frequency plot of hourly accident count
right <- dat %>%
  ggplot(aes(Hour, color = Wday %in% c('1', '7'),
         group = Wday %in% c('1', '7'))) +
  geom_freqpoly(stat = 'count') +
  scale_color_discrete(name = 'Is weekend?',
                       labels = c('No', 'Yes')) +
  labs(title = '',
       y = NULL) +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  my_theme()

# Combine plots
grid.arrange(left, right, widths = c(1/2, 1/2))

5. Weather condition

The plot below shows, for each severity level, the proportion of accidents by weather condition. It appears that more accidents occur in fair weather than in any other condition.

# Get accident proportion for each severity level
weather <- dat %>%
  group_by(Severity) %>%
  count(Weather_Condition) %>%
  mutate(n = n / sum(n)) %>%
  filter(n > 0.02) %>%
  ungroup()
weather <- weather$Weather_Condition

# Create plot showing accident proportion for each severity level
dat %>%
  filter(Weather_Condition %in% weather) %>%
  group_by(Severity) %>%
  count(Weather_Condition) %>%
  mutate(n = n / sum(n)) %>%
  ungroup() %>%
  ggplot(aes(reorder_within(Weather_Condition, n, Severity), n)) +
  geom_col(aes(fill = !Weather_Condition == 'Fair'), show.legend = F) +
  facet_wrap(~ Severity, scales = 'free_y') +
  coord_flip() +
  scale_x_reordered() +
  scale_y_continuous(breaks = seq(0, 0.6, 0.1),
                     labels = percent) +
  labs(title = 'Weather condition on accident severity',
       x = 'Weather Condition',
       y = 'Proportion') +
  my_theme()

Pre-process Data for Modelling

Data preprocessing is the process in which the data to be used for machine learning are manipulated (e.g., transformed, encoded, etc.) so that they can be parsed and analyzed quickly. Additionally, irrelevant or redundant data are deleted or modified to support the predictive analysis.

1. Select one State

Let’s subset our data on the state of California to manage computational overhead.

# Subset data on California
dat_CA <- dat %>% filter(State == 'CA') %>% select(-State)

2. Drop weather condition levels

Several weather condition categories have only a few records, which can be problematic when splitting the dataset for modelling. For example, some levels may appear in the training dataset but not in the test dataset, or vice versa.

Let’s create a plot of accident proportion by weather condition.

# Create plot of accident proportion by weather condition
ggplot(dat_CA, aes(Weather_Condition, after_stat(prop), group = Severity)) +
  geom_bar(aes(fill = Severity), position = 'dodge') +
  scale_y_continuous(labels = percent) +
  labs(title = 'Weather condition has categories with few observations',
       x = 'Weather Condition',
       y = 'Proportion') +
  my_theme() +
  theme(axis.text.x = element_text(size = 9, angle = 60, vjust = 0.6))

Let’s remove weather condition categories with fewer than 50 records.

dat_CA %>%
  count(Weather_Condition) %>%
  filter(n < 50) %>%
  select(Weather_Condition, n) %>%
  data.frame()
##               Weather_Condition  n
## 1                  Blowing Sand  2
## 2                  Blowing Snow 10
## 3               Drizzle / Windy  4
## 4                   Dust Whirls  1
## 5                     Duststorm  3
## 6                          Hail  7
## 7             Heavy Rain Shower  2
## 8                   Heavy Smoke  1
## 9            Heavy Snow / Windy 29
## 10        Heavy T-Storm / Windy  5
## 11 Heavy Thunderstorms and Rain 21
## 12        Light Drizzle / Windy 11
## 13           Light Freezing Fog 22
## 14          Light Freezing Rain  5
## 15                   Light Hail  2
## 16                   Light Haze  6
## 17    Light Rain Shower / Windy 11
## 18           Light Rain Showers 29
## 19            Light Snow Shower  4
## 20           Light Snow Showers  1
## 21      Light Snow with Thunder  1
## 22           Light Thunderstorm  1
## 23                 Mist / Windy 13
## 24                  Partial Fog 29
## 25                  Rain Shower 34
## 26          Rain Shower / Windy  3
## 27                 Sand / Windy  1
## 28                   Small Hail  4
## 29                Smoke / Windy 34
## 30                 Snow / Windy 49
## 31             Snow and Thunder 14
## 32                      Squalls  4
## 33              Squalls / Windy  9
## 34              T-Storm / Windy 17
## 35              Thunder / Windy  4
## 36             Thunder and Hail  5
## 37                 Thunderstorm 21
## 38                 Volcanic Ash 12
## 39      Widespread Dust / Windy 13
## 40           Wintry Mix / Windy  1

# Get weather condition categories with less than 50 records
drop_weather <- dat_CA %>%
  count(Weather_Condition) %>%
  filter(n < 50) %>%
  select(Weather_Condition)
drop_weather <- drop_weather$Weather_Condition %>% unlist()

# Remove weather condition categories with less than 50 records
dat_CA <- dat_CA %>% 
  filter(!(Weather_Condition %in% drop_weather)) %>% 
  mutate(Weather_Condition = factor(Weather_Condition))
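
Removing rare categories also removes the records that carry them. An alternative, sketched here as an assumption rather than as the approach used in this post, is to merge rare levels into a single 'Other' level with forcats::fct_lump_min(), which keeps every record. The toy vector weather is hypothetical.

```r
library(forcats)

# Toy weather vector (hypothetical counts)
weather <- factor(c(rep('Fair', 60), rep('Cloudy', 55), rep('Hail', 7), 'Dust Whirls'))

# Levels with fewer than 50 records are merged into 'Other'
weather_lumped <- fct_lump_min(weather, min = 50, other_level = 'Other')
table(weather_lumped)
```

Whether lumping or dropping is preferable depends on whether the rare conditions are expected to behave like one another. With only a handful of records each, either choice has little effect on the model.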

3. Group 4 severity levels into 2 levels

Most of the accidents are classified as either level 2 or level 3 as shown in the plot below.

# Create plot of accident count by accident severity 
ggplot(dat_CA, aes(Severity, fill = !Severity %in% c(3, 4))) +
  geom_bar() +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  scale_fill_discrete(name = 'Severity',
                      labels = c('Severe: 3 or 4', 'Not Severe: 1 or 2')) +
  labs(title = 'Unbalanced severity levels',
       y = 'Count') +
  my_theme()

Let’s group the 4 severity levels into 2 levels. Level 1 and level 2 will be grouped as “Not Severe”, and level 3 and level 4 will be grouped as “Severe”.

# Group accident severity levels
dat_label <- dat_CA %>%
  mutate(Status =
    factor(
      ifelse(Severity == '3' | Severity == '4', 'Severe', 'Not Severe'), 
      levels = c('Not Severe', 'Severe')
    )
  ) %>%
  select(-Severity)

Let’s create a plot of accident count by accident severity for the 2 levels.

# Create plot of accident count by accident severity for the 2 levels
ggplot(dat_label, aes(Status, fill = !Status == 'Severe')) +
  geom_bar() +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  scale_fill_discrete(name = 'Severity',
                      labels = c('Severe', 'Not Severe')) +
  labs(title = 'More balanced severity levels',
        x = 'Severity',
        y = 'Count') +
  my_theme()

4. Near Zero-Variance Predictors

A few variables have zero or near-zero variance (NZV). A zero-variance variable is one whose values are all the same constant, and an NZV variable is one where almost all values are the same and only a few differ from that value.

The variables with NZV in our dataset include the following:

nzv <- nearZeroVar(dat_label, saveMetrics = TRUE)
nzv[nzv$nzv,]
##                   freqRatio percentUnique zeroVar  nzv
## Visibility         21.07858  3.705935e-03   FALSE TRUE
## Amenity           134.63999  1.176487e-04   FALSE TRUE
## Bump             1688.83698  1.176487e-04   FALSE TRUE
## Give_Way          802.01181  1.176487e-04   FALSE TRUE
## No_Exit           919.39848  1.176487e-04   FALSE TRUE
## Railway           100.48505  1.176487e-04   FALSE TRUE
## Roundabout      28812.15254  1.176487e-04   FALSE TRUE
## Station            41.66366  1.176487e-04   FALSE TRUE
## Stop               30.53067  1.176487e-04   FALSE TRUE
## Traffic_Calming  1155.44626  1.176487e-04   FALSE TRUE
## Turning_Loop        0.00000  5.882436e-05    TRUE TRUE
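
To make the two metric columns concrete: freqRatio is the count of the most common value divided by the count of the second most common value, and percentUnique is the number of distinct values as a percentage of the number of records. The sketch below recomputes both for a single toy vector in base R. The helper nzv_metrics and the vector bump are hypothetical; the cutoffs are caret's defaults for nearZeroVar() (freqCut = 95/5, uniqueCut = 10).

```r
# Hypothetical helper reproducing the two metrics for a single vector
nzv_metrics <- function(x) {
  counts <- sort(table(x), decreasing = TRUE)
  # A constant vector has no second value; caret reports a ratio of 0 in that case
  freq_ratio <- if (length(counts) > 1) counts[[1]] / counts[[2]] else 0
  percent_unique <- 100 * length(counts) / length(x)
  c(freqRatio = freq_ratio, percentUnique = percent_unique)
}

# Toy logical predictor: 990 FALSE and 10 TRUE (hypothetical values)
bump <- c(rep(FALSE, 990), rep(TRUE, 10))
m <- nzv_metrics(bump)

# Flag as NZV when the ratio is high and the share of distinct values is low
is_nzv <- m[['freqRatio']] > 95 / 5 && m[['percentUnique']] <= 10
```

This also explains why Turning_Loop shows a freqRatio of 0 in the output above: all of its values are FALSE, so it is flagged through the zeroVar column instead.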

Because NZV variables may be problematic for model fitting and prediction, let’s remove these variables.

# Remove NZV variables
nzv_cols <- rownames(nzv[nzv$nzv,])
dat_label <- dat_label %>%
  select(-all_of(nzv_cols))

5. Partition

Let’s split the dataset into 2 partitions: a training dataset (70%) to build the various models and a test dataset (30%) to document final performance of each model.

# Set seed for reproducibility
set.seed(1515)

# Partition the data
dat_parts <- resample_partition(dat_label, c(train = 0.7, test = 0.3))
train_set <- as_tibble(dat_parts$train)
test_set <- as_tibble(dat_parts$test)
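
resample_partition() assigns rows at random, so the share of each class can differ slightly between the two partitions. If matching class proportions matter, a stratified split is one alternative. The sketch below uses caret::createDataPartition() on toy data; the data frame toy and the derived names are hypothetical, and this is not the split used in the rest of the post.

```r
library(caret)

set.seed(1515)

# Toy data with an imbalanced two-level outcome (hypothetical values)
toy <- data.frame(
  Status = factor(c(rep('Not Severe', 80), rep('Severe', 20))),
  x = rnorm(100)
)

# Sample about 70% of the rows within each Status level
train_idx <- createDataPartition(toy$Status, p = 0.7, list = FALSE)
toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]

# Class shares in each partition
prop.table(table(toy_train$Status))
prop.table(table(toy_test$Status))
```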

Build Models

1. Sampling

A plot of accident count by accident severity for the training dataset is shown below. The dataset is highly imbalanced, which may compromise the process of learning. With highly imbalanced datasets, a model tends to focus on the prevalent class and ignore the rare events.

# Create plot of accident count by accident severity for the training dataset
ggplot(train_set, aes(Status)) +
  geom_bar(aes(fill = Status)) +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  labs(title = 'Unbalanced severity levels',
       y = 'Count') +
  my_theme()

Let’s balance the data by combining oversampling and undersampling. This will also reduce the data size to a scale that is more manageable. The ovun.sample() function in the ROSE package creates possibly balanced samples by random over-sampling of minority examples, under-sampling of majority examples, or a combination of the two.

# Apply oversampling and undersampling to the training dataset
new_train <- ovun.sample(
  Status ~ ., 
  data = train_set, 
  method = 'both', p = 0.5, N = 90000, seed = 1
)$data %>%
  as_tibble()
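
To illustrate the idea behind combined resampling, here is a conceptual sketch in base R. It is not how ROSE implements ovun.sample() internally (for example, ROSE does not fix the class counts exactly), and the toy data and all names are hypothetical. N is the target size and p is the target share of the minority class, mirroring the arguments used above.

```r
set.seed(1)

# Toy imbalanced training data: 900 majority rows, 100 minority rows (hypothetical)
toy <- data.frame(
  Status = factor(c(rep('Not Severe', 900), rep('Severe', 100))),
  x = rnorm(1000)
)

N <- 400   # target size of the resampled data
p <- 0.5   # target share of the minority class

n_minor <- round(N * p)
n_major <- N - n_minor

minor_rows <- which(toy$Status == 'Severe')
major_rows <- which(toy$Status == 'Not Severe')

# Over-sample the minority class (with replacement) and
# under-sample the majority class (without replacement)
idx <- c(
  sample(minor_rows, n_minor, replace = TRUE),
  sample(major_rows, n_major, replace = FALSE)
)
balanced <- toy[idx, ]
table(balanced$Status)
```

Because minority rows are drawn with replacement, the balanced data contain duplicated records. This is one reason to resample only the training dataset and leave the test dataset untouched.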

Let’s create a plot of accident count by accident severity for the new training dataset.

# Create plot of accident count by accident severity for the new training dataset
ggplot(new_train, aes(Status)) +
  geom_bar(aes(fill = Status)) +
  scale_y_continuous(labels = unit_format(unit = 'K', scale = 1e-03)) +
  labs(title = 'Balanced severity levels',
       y = 'Count') +
  my_theme()

2. Logistic regression

Because our response variable has 2 levels, “Severe” and “Not Severe”, it’s reasonable to choose logistic regression as one of our models. Logistic regression is a supervised machine learning algorithm mainly used for classification tasks where the goal is to predict the probability that an observation belongs to a given class.

We’ll perform stepwise logistic regression to reduce the complexity of the model without compromising its accuracy. Stepwise logistic regression automatically selects a reduced set of predictor variables for building the best performing logistic regression model. Let’s use the Akaike information criterion (AIC) for the variable selection process.

# Perform logistic regression
model_aic <- glm(Status ~ ., data = new_train, family = 'binomial')
model_aic <- step(model_aic)
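
For intuition about what step() compares, the AIC of a fitted model is 2k minus twice its log-likelihood, where k is the number of estimated parameters. Lower values are better, so a variable is dropped when removing it lowers the AIC. The sketch below checks this formula on a toy logistic regression fitted to the built-in mtcars data, not the accident data; the names fit, ll, k, and aic_manual are hypothetical.

```r
# Toy logistic regression on the built-in mtcars data (not the accident data)
fit <- glm(am ~ wt + hp, data = mtcars, family = 'binomial')

# AIC = 2k - 2 * log-likelihood, where k is the number of estimated parameters
ll <- logLik(fit)
k <- attr(ll, 'df')
aic_manual <- 2 * k - 2 * as.numeric(ll)

# Compare with the value reported by R
all.equal(aic_manual, AIC(fit))
```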

Based on the AIC, the following variables are dropped from the regression model:

# These variables are dropped:
model_aic$anova[2:nrow(model_aic$anova), c(1, 6)] %>%
  as_tibble() %>%
  mutate(Step = str_sub(Step, start = 3)) %>%
  rename('Variables to drop' = Step)
## # A tibble: 3 × 2
##   `Variables to drop`      AIC
##   <chr>                  <dbl>
## 1 Day                   96012.
## 2 Civil_Twilight        96010.
## 3 Astronomical_Twilight 96008.

The final model formula is as follows:

# The final formula based on AIC value
model_aic$call
## glm(formula = Status ~ Year + Month + Hour + Wday + Duration + 
##     Start_Lat + Start_Lng + Distance + Temperature + Humidity + 
##     Pressure + Wind_Speed + Weather_Condition + Crossing + Junction + 
##     Traffic_Signal + Sunrise_Sunset + Nautical_Twilight, family = "binomial", 
##     data = new_train)

Let’s make predictions on the test dataset.

# Make predictions on test dataset
test_pred <- model_aic %>%
  predict(test_set, type = 'response') %>%
  data.frame() %>%
  bind_cols(test_set$Status) %>%
  rename(
    pred = 1,
    Status = 2
  ) %>%
  dplyr::select(Status, pred)

Let’s transform probability to response variable levels using a value of 0.5 as the cutoff.

# Transform probability to response variable levels
test_pred <- test_pred %>%
  mutate(pred = ifelse(pred > 0.5, 'Severe', 'Not Severe'))
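
The cutoff step relies on glm() modelling the probability of the second factor level. The toy sketch below, on made-up data, shows this under the assumption that Status has its levels ordered as “Not Severe”, “Severe” (consistent with caret reporting “Not Severe” as the positive, i.e. first, class). The names toy, fit_toy, p_severe, and pred_class are hypothetical.

```r
# Toy data: the second factor level ('Severe') is what glm() treats as "success"
toy <- data.frame(
  x = 1:8,
  Status = factor(
    c('Not Severe', 'Not Severe', 'Severe', 'Not Severe',
      'Severe', 'Not Severe', 'Severe', 'Severe'),
    levels = c('Not Severe', 'Severe')
  )
)

fit_toy <- glm(Status ~ x, data = toy, family = 'binomial')

# Fitted values are estimates of P(Status == 'Severe')
p_severe <- predict(fit_toy, type = 'response')

# Apply the 0.5 cutoff to turn probabilities into class labels
pred_class <- ifelse(p_severe > 0.5, 'Severe', 'Not Severe')
table(Predicted = pred_class, Actual = toy$Status)
```

If the factor levels were in the opposite order, the labels in the ifelse() call would need to be swapped as well.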

We can review the performance of the logistic regression through a confusion matrix. The elements of the confusion matrix can be used to compute three key classification metrics: accuracy, sensitivity, and specificity. Accuracy measures how often the model correctly predicts the class of new observations. Sensitivity measures how well the model identifies instances of the positive class, which in our case is an accident level of “Not Severe.” Specificity measures how well the model identifies instances of the negative class, which in our case is an accident level of “Severe.”

cm_lr <- confusionMatrix(table(test_pred$pred, test_pred$Status))
tibble(
  'Accuracy'      = cm_lr$overall[[1]],
  'Sensitivity'   = cm_lr$byClass[[1]],
  'Specificity'   = cm_lr$byClass[[2]],
  'Positive term' = cm_lr$positive
)
## # A tibble: 1 × 4
##   Accuracy Sensitivity Specificity `Positive term`
##      <dbl>       <dbl>       <dbl> <chr>          
## 1    0.727       0.723       0.748 Not Severe

3. LASSO regression

Because we have so many variables in our dataset, it’s possible that some of the variables’ coefficients may be near zero in the final best model. Let’s try LASSO (Least Absolute Shrinkage and Selection Operator) logistic regression next. LASSO is a type of regularization that adds a penalty proportional to the sum of the absolute values of the coefficients. This penalty can shrink some coefficients to exactly zero, which effectively removes those variables from the model.

The glmnet() function of the glmnet package expects the predictors as a numeric matrix x and the response as a vector y. The model.matrix() function of the base stats package is useful for creating x. Not only does it produce a matrix corresponding to the predictor variables, but it also automatically transforms any qualitative variables into dummy variables. The latter property is important because glmnet() can only take numerical, quantitative inputs.
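
To make the dummy-variable expansion concrete, here is a minimal sketch on a made-up data frame. The names df_toy, temp, and weather are hypothetical and unrelated to the accident data.

```r
# Toy data with one numeric and one categorical predictor
df_toy <- data.frame(
  y       = c(0, 1, 0, 1),
  temp    = c(10, 20, 30, 40),
  weather = factor(c('Clear', 'Rain', 'Snow', 'Clear'))
)

# model.matrix() expands the factor into 0/1 columns;
# [, -1] drops the intercept column, which glmnet() adds on its own
x_toy <- model.matrix(y ~ ., data = df_toy)[, -1]
colnames(x_toy)
x_toy
```

The factor with three levels becomes two indicator columns, with the first level (“Clear”) serving as the reference category. The same call is applied to the full training data next.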

# Build the predictor matrix (dropping the intercept column) and the response vector
x_train <- model.matrix(Status ~ ., data = new_train)[, -1]
y_train <- new_train$Status

# Fit lasso model on training data
model_lasso <- glmnet(
  x_train, y_train,
  family = 'binomial',
  alpha = 1
)

The plot below shows how the variables’ coefficients change along the regularization path.
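
A coefficient-path plot of this kind can be produced with the plot method for glmnet objects. The sketch below uses simulated data rather than the accident data, so x_sim, y_sim, and fit_sim are illustrative names, and the choice of xvar = 'lambda' is an assumption about how one might want to draw it.

```r
library(glmnet)

set.seed(1)
n <- 200
x_sim <- matrix(rnorm(n * 5), nrow = n,
                dimnames = list(NULL, paste0('x', 1:5)))
# Only x1 and x2 carry signal; x3 to x5 are pure noise
y_sim <- rbinom(n, 1, plogis(2 * x_sim[, 1] - 1.5 * x_sim[, 2]))

# Fit the whole LASSO path (one model per lambda value)
fit_sim <- glmnet(x_sim, y_sim, family = 'binomial', alpha = 1)

# One curve per coefficient, drawn against log(lambda)
plot(fit_sim, xvar = 'lambda', label = TRUE)
```

Reading the plot from right to left, coefficients enter the model one by one as the penalty weakens.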

To get the best LASSO model, we need to find the best tuning parameter lambda. We will use cross validation to find the best lambda.

# Perform cross validation
model_lambda <- cv.glmnet(x_train, y_train, family = 'binomial')

Now, let’s fit the LASSO model using the best lambda.

# Fit the LASSO model with the best lambda
lambda <- model_lambda$lambda.min
model_lasso <- glmnet(
  x_train, y_train,
  family = 'binomial',
  alpha = 1, lambda = lambda
)
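
To see which coefficients the penalty has set to zero at the selected lambda, the fitted coefficients can be inspected directly. The following is a minimal sketch on simulated data (x_sim, y_sim, and cv_sim are illustrative names), not the accident model itself.

```r
library(glmnet)

set.seed(1)
n <- 200
x_sim <- matrix(rnorm(n * 5), nrow = n,
                dimnames = list(NULL, paste0('x', 1:5)))
# Only x1 and x2 carry signal; x3 to x5 are pure noise
y_sim <- rbinom(n, 1, plogis(2 * x_sim[, 1] - 1.5 * x_sim[, 2]))

# Cross-validate over the lambda path (alpha = 1, i.e. LASSO, is the default)
cv_sim <- cv.glmnet(x_sim, y_sim, family = 'binomial', nfolds = 5)

# lambda.min minimizes the cross-validated error;
# lambda.1se is a larger, more conservative choice within one standard error
c(min = cv_sim$lambda.min, one_se = cv_sim$lambda.1se)

# Coefficients at lambda.min; entries printed as '.' are exactly zero
coefs <- coef(cv_sim, s = 'lambda.min')
coefs
```

For the accident data, the analogous call would be coef(model_lasso), since that model was fitted at a single lambda.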

Let’s make predictions on the test data using our new LASSO model.

# Make predictions on test data
x_test <- model.matrix(Status ~ ., data = test_set)[,-1]
test_pred <- model_lasso %>%
  predict(s = lambda, newx = x_test, type = 'response') %>%
  data.frame() %>%
  bind_cols(test_set$Status) %>%
  rename(
    pred = 1,
    Status = 2
  ) %>%
  dplyr::select(Status, pred)

Let’s transform probability to response variable levels using a value of 0.5 as the cutoff.

# Transform probability to response variable levels
test_pred <- test_pred %>%
  mutate(pred = ifelse(pred > 0.5, 'Severe', 'Not Severe'))

Let’s review the performance of LASSO regression using the confusion matrix.

cm_lasso <- confusionMatrix(table(test_pred$pred, test_pred$Status))
tibble(
  'Accuracy'      = cm_lasso$overall[[1]],
  'Sensitivity'   = cm_lasso$byClass[[1]],
  'Specificity'   = cm_lasso$byClass[[2]],
  'Positive term' = cm_lasso$positive
)
## # A tibble: 1 × 4
##   Accuracy Sensitivity Specificity `Positive term`
##      <dbl>       <dbl>       <dbl> <chr>          
## 1    0.727       0.723       0.748 Not Severe
cm_lasso
## Confusion Matrix and Statistics
## 
##             
##              Not Severe Severe
##   Not Severe     307967  21189
##   Severe         117927  62910
##                                          
##                Accuracy : 0.7272         
##                  95% CI : (0.726, 0.7284)
##     No Information Rate : 0.8351         
##     P-Value [Acc > NIR] : 1              
##                                          
##                   Kappa : 0.3224         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.7231         
##             Specificity : 0.7480         
##          Pos Pred Value : 0.9356         
##          Neg Pred Value : 0.3479         
##              Prevalence : 0.8351         
##          Detection Rate : 0.6039         
##    Detection Prevalence : 0.6454         
##       Balanced Accuracy : 0.7356         
##                                          
##        'Positive' Class : Not Severe     
## 
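
As a check on how the reported metrics follow from the table, the sketch below recomputes them by hand from the four counts printed in the LASSO confusion matrix above. The object name cm is illustrative.

```r
# Counts copied from the LASSO confusion matrix above
# (rows = predicted class, columns = actual class)
cm <- matrix(
  c(307967, 117927,   # actual Not Severe: predicted Not Severe, predicted Severe
     21189,  62910),  # actual Severe:     predicted Not Severe, predicted Severe
  nrow = 2,
  dimnames = list(
    Predicted = c('Not Severe', 'Severe'),
    Actual    = c('Not Severe', 'Severe')
  )
)

accuracy    <- sum(diag(cm)) / sum(cm)
# Positive class is 'Not Severe': share of actual Not Severe predicted correctly
sensitivity <- cm['Not Severe', 'Not Severe'] / sum(cm[, 'Not Severe'])
# Negative class is 'Severe': share of actual Severe predicted correctly
specificity <- cm['Severe', 'Severe'] / sum(cm[, 'Severe'])
# No-information rate: accuracy from always predicting the majority class
nir         <- max(colSums(cm)) / sum(cm)

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, nir = nir), 4)
```

The rounded values agree with the Accuracy, Sensitivity, Specificity, and No Information Rate lines in the printed output.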

The performance of LASSO logistic regression is nearly identical to that of the stepwise logistic regression model above.

4. Decision tree

A decision tree is a supervised learning technique that can be used for both classification and regression, though it is most often used for classification. It is a tree-structured classifier in which internal nodes test features of the dataset, branches represent decision rules, and each leaf node represents an outcome.

Next, let’s build a decision tree model using the rpart package.

# Build decision tree model
model_tree <- rpart(
  Status ~ ., data = new_train,
  method = 'class',
  minsplit = 20, cp = 0.001
)

Usually we can plot the decision tree to see all the nodes. To achieve higher accuracy, we set a low complexity parameter (cp = 0.001), which allows many splits and makes the final tree quite complicated and hard to read when plotted.

# Plot tree
rpart.plot(model_tree, box.palette = 'RdBu', shadow.col = 'grey')
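
To get a feel for how cp controls tree size, the sketch below fits a tree with the same minsplit and cp settings on the kyphosis data that ships with rpart, then prunes it back. This is only an illustration under those assumptions; fit_full, best_cp, and fit_pruned are hypothetical names, and pruning is not a step performed on the accident model here.

```r
library(rpart)

set.seed(1)  # the cross-validated error in the cp table depends on random folds
fit_full <- rpart(
  Kyphosis ~ Age + Number + Start, data = kyphosis,
  method = 'class',
  minsplit = 20, cp = 0.001
)

# Each row: a candidate cp, the number of splits, and the cross-validated error
fit_full$cptable

# Prune back to the cp value with the lowest cross-validated error
best_cp <- fit_full$cptable[which.min(fit_full$cptable[, 'xerror']), 'CP']
fit_pruned <- prune(fit_full, cp = best_cp)

# Number of nodes before and after pruning
c(full = nrow(fit_full$frame), pruned = nrow(fit_pruned$frame))
```

A larger cp yields a smaller tree that is easier to plot, usually at some cost in accuracy.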

Let’s make predictions on the test data using the decision tree model.

# Make predictions on test data
test_pred <- model_tree %>%
  predict(test_set, type = 'class') %>%
  data.frame() %>%
  bind_cols(test_set$Status) %>%
  rename(
    pred = 1,
    Status = 2
  ) %>%
  dplyr::select(Status, pred)

Let’s review the performance of the decision tree model by using the confusion matrix.

cm_tree <- confusionMatrix(table(test_pred$pred, test_pred$Status))
tibble(
  'Accuracy'      = cm_tree$overall[[1]],
  'Sensitivity'   = cm_tree$byClass[[1]],
  'Specificity'   = cm_tree$byClass[[2]],
  'Positive term' = cm_tree$positive
)
## # A tibble: 1 × 4
##   Accuracy Sensitivity Specificity `Positive term`
##      <dbl>       <dbl>       <dbl> <chr>          
## 1    0.817       0.806       0.873 Not Severe
cm_tree
## Confusion Matrix and Statistics
## 
##             
##              Not Severe Severe
##   Not Severe     343306  10692
##   Severe          82588  73407
##                                          
##                Accuracy : 0.8171         
##                  95% CI : (0.816, 0.8182)
##     No Information Rate : 0.8351         
##     P-Value [Acc > NIR] : 1              
##                                          
##                   Kappa : 0.5055         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.8061         
##             Specificity : 0.8729         
##          Pos Pred Value : 0.9698         
##          Neg Pred Value : 0.4706         
##              Prevalence : 0.8351         
##          Detection Rate : 0.6732         
##    Detection Prevalence : 0.6941         
##       Balanced Accuracy : 0.8395         
##                                          
##        'Positive' Class : Not Severe     
## 

Better performance is achieved with the decision tree model compared to the previous two logistic regression models. Additionally, it takes much less time to train a decision tree model than to fit a logistic regression model.

5. Random forest

In general, a single decision tree tends to have high accuracy on training data but much lower accuracy on test data, which is a sign of overfitting. A random forest alleviates this by growing many trees, each on a bootstrap sample of the training data (sampling with replacement) and with a random subset of predictors considered at each split, and then combining their votes. The observations left out of each bootstrap sample provide an out-of-bag (OOB) error estimate, which we can use to tune the model.
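
The out-of-bag idea can be illustrated with base R alone. In the sketch below (n_obs, boot_idx, and oob_frac are illustrative names), one bootstrap sample is drawn and the share of observations it never picks is measured. That share is close to exp(-1), or roughly 37%, for large samples.

```r
set.seed(1515)
n_obs <- 10000

# One bootstrap sample: draw n_obs row indices with replacement
boot_idx <- sample(n_obs, size = n_obs, replace = TRUE)

# Rows never drawn are "out of bag" for the tree grown on this sample
oob_frac <- 1 - length(unique(boot_idx)) / n_obs

c(observed = oob_frac, theoretical = exp(-1))
```

Each tree is evaluated on its own out-of-bag rows, so the OOB error behaves like a built-in validation estimate.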

Let’s build a random forest model using the randomForest package.

# Build random forest model with mtry = 6
set.seed(1515)
model_rf <- randomForest(
  Status ~ ., data = new_train,
  mtry = 6, ntree = 500,
  importance = TRUE
)

Let’s tune mtry by searching for the optimal value (with respect to the OOB error estimate).

# Tune mtry using tuneRF
x <- dplyr::select(new_train, -Status)  # all predictors, wherever Status sits
y <- new_train$Status
rf_mtry <- tuneRF(
  x, y,
  mtryStart = 4,
  ntreeTry = 500,
  stepFactor = 0.8,
  improve = 0.01,
  trace = TRUE,
  plot = TRUE
)

It appears that 19 is the optimal value for mtry. Next, let’s tune the number of trees (ntree).

# Manual tuning of ntree
set.seed(1515)
tuning <- vector(length = 4)
j <- 1
for (i in c(200, 500, 1000, 2000)) {
  rf_ntree <- randomForest(
    Status ~ ., data = new_train,
    mtry = 19, ntree = i
  )
  conf <- rf_ntree$confusion[,-ncol(rf_ntree$confusion)]
  oob <- 1 - (sum(diag(conf))/sum(conf))
  tuning[j] <- oob
  j <- j + 1
}

plot(
  x = c(200, 500, 1000, 2000),
  y = tuning,
  xlab = 'ntree value', ylab = 'OOB Error',
  type = 'b'
)
dev.off()
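
As a possible shortcut, a fitted randomForest object stores the cumulative OOB error after each tree in its err.rate component, so a single fit with the largest ntree can approximate the same curve. The sketch below demonstrates this on the built-in iris data; rf_toy and oob_curve are illustrative names, and the values will differ from those of separately fitted forests.

```r
library(randomForest)

set.seed(1515)
rf_toy <- randomForest(Species ~ ., data = iris, ntree = 500)

# err.rate has one row per tree; the 'OOB' column is the running OOB error
oob_curve <- rf_toy$err.rate[, 'OOB']

# OOB error after 100, 200, and 500 trees, taken from the same fit
oob_curve[c(100, 200, 500)]

plot(
  x = seq_along(oob_curve),
  y = oob_curve,
  xlab = 'ntree value', ylab = 'OOB Error',
  type = 'l'
)
```

This avoids refitting the forest for every candidate ntree, which matters when each fit is expensive.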

Let’s build a random forest model using mtry = 19 and ntree = 500.

# Build random forest model with mtry = 19
model_rf <- randomForest(
  Status ~ ., data = new_train,
  mtry = 19, ntree = 500,
  importance = TRUE
)
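
Because the model is fitted with importance = TRUE, variable importance scores are stored alongside it. The sketch below shows how they could be inspected, using the built-in iris data as a stand-in; rf_toy and imp are illustrative names.

```r
library(randomForest)

set.seed(1515)
rf_toy <- randomForest(Species ~ ., data = iris, ntree = 500, importance = TRUE)

# One row per predictor; columns include MeanDecreaseAccuracy and MeanDecreaseGini
imp <- importance(rf_toy)
imp[order(imp[, 'MeanDecreaseGini'], decreasing = TRUE), ]

# Dot chart of the same scores
varImpPlot(rf_toy)
```

For the accident model, the equivalent calls would be importance(model_rf) and varImpPlot(model_rf).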

Let’s make predictions on the test data using the random forest model.

# Make predictions on test data
test_pred <- model_rf %>%
  predict(test_set, type = 'class') %>%
  data.frame() %>%
  bind_cols(test_set$Status) %>%
  rename(
    pred = 1,
    Status = 2
  ) %>%
  dplyr::select(Status, pred)

Let’s review the performance of the random forest model by using the confusion matrix.

cm_rf <- confusionMatrix(table(test_pred$pred, test_pred$Status))
tibble(
  'Accuracy'      = cm_rf$overall[[1]],
  'Sensitivity'   = cm_rf$byClass[[1]],
  'Specificity'   = cm_rf$byClass[[2]],
  'Positive term' = cm_rf$positive
)
## # A tibble: 1 × 4
##   Accuracy Sensitivity Specificity `Positive term`
##      <dbl>       <dbl>       <dbl> <chr>          
## 1    0.837       0.821       0.917 Not Severe
cm_rf
## Confusion Matrix and Statistics
## 
##             
##              Not Severe Severe
##   Not Severe     349617   6965
##   Severe          76277  77134
##                                           
##                Accuracy : 0.8368          
##                  95% CI : (0.8358, 0.8378)
##     No Information Rate : 0.8351          
##     P-Value [Acc > NIR] : 0.0006062       
##                                           
##                   Kappa : 0.5547          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.8209          
##             Specificity : 0.9172          
##          Pos Pred Value : 0.9805          
##          Neg Pred Value : 0.5028          
##              Prevalence : 0.8351          
##          Detection Rate : 0.6855          
##    Detection Prevalence : 0.6992          
##       Balanced Accuracy : 0.8690          
##                                           
##        'Positive' Class : Not Severe      
## 

Summary

The performance of each predictive model is shown below. The random forest model is the most accurate, improving on the decision tree, and it is the only model whose accuracy (0.8368) exceeds the no-information rate of 0.8351 reported in the confusion matrices. However, the computational cost of training and tuning a random forest is significantly greater than that of training a decision tree.

# Combine summary of results obtained for each model
rbind(
  'Logistic Regression' = data.frame(
    'Accuracy'          = cm_lr$overall[[1]],
    'Sensitivity'       = cm_lr$byClass[[1]],
    'Specificity'       = cm_lr$byClass[[2]],
    'Positive term'     = cm_lr$positive
  ),
  'LASSO Regression' = data.frame(
    'Accuracy'       = cm_lasso$overall[[1]],
    'Sensitivity'    = cm_lasso$byClass[[1]],
    'Specificity'    = cm_lasso$byClass[[2]],
    'Positive term'  = cm_lasso$positive
  ),
  'Decision Tree'   = data.frame(
    'Accuracy'      = cm_tree$overall[[1]],
    'Sensitivity'   = cm_tree$byClass[[1]],
    'Specificity'   = cm_tree$byClass[[2]],
    'Positive term' = cm_tree$positive
  ),
  'Random Forest'   = data.frame(
    'Accuracy'      = cm_rf$overall[[1]],
    'Sensitivity'   = cm_rf$byClass[[1]],
    'Specificity'   = cm_rf$byClass[[2]],
    'Positive term' = cm_rf$positive
  )
) %>%
mutate(across(where(is.numeric), \(x) round(x, 4)))
##                     Accuracy Sensitivity Specificity Positive.term
## Logistic Regression   0.7274      0.7232      0.7483    Not Severe
## LASSO Regression      0.7272      0.7231      0.7480    Not Severe
## Decision Tree         0.8171      0.8061      0.8729    Not Severe
## Random Forest         0.8368      0.8209      0.9172    Not Severe