#' Projekt: Testings
#' Author: Paul Ritsche
#' Title: Muscle Architecture & Strength
#' last edited: 19.12.2023

library(dplyr)
library(readxl)
library(ggplot2)
library(ggside)
library(ggstatsplot)
library(NormalityAssessment)
library(easystats)
library(see)
library(correlation)
library(GGally)
library(ggExtra)
library(reshape2)
library(MASS)
library(car)
library(lmtest)
library(ppcor)
library(openxlsx)
library(viridis)
library(tidyr)


######################################################
##### Define functions required for the analysis #####
######################################################

#' Split the data into right-leg, left-leg and dominant-leg subsets
#'
#' Builds three subsets of `data`. Every subset keeps the columns "Alter"
#' through "SMM_kg" plus one contiguous range of measurement columns. The
#' ranges are positional (`"first":"last"`), so the result depends on the
#' column order of `data`.
#'
#' @param data A data frame containing the column ranges "Alter":"SMM_kg",
#'   "isom_ext_li":"sum_acsa_33_li", "isom_ext_re":"sum_acsa_33_re" and
#'   "dom_isom_ext":"dom_sum_acsa_33".
#'
#' @return A named list of three data frames. Note the element order
#'   (right first), which matters for positional indexing such as `x[[1]]`:
#' \itemize{
#'   \item \strong{right}: columns "Alter" to "SMM_kg" and "isom_ext_re" to
#'   "sum_acsa_33_re", all converted to numeric.
#'
#'   \item \strong{left}: columns "Alter" to "SMM_kg" and "isom_ext_li" to
#'   "sum_acsa_33_li", all converted to numeric.
#'
#'   \item \strong{dom}: columns "Alter" to "SMM_kg" and "dom_isom_ext" to
#'   "dom_sum_acsa_33", returned WITHOUT numeric conversion.
#' }
#'
#' @examples
#' \dontrun{
#' data <- data.frame(
#'   Alter = c(15, 16),
#'   SMM_kg = c(60, 65),
#'   isom_ext_li = c(10, 12),
#'   sum_acsa_33_li = c(20, 22),
#'   isom_ext_re = c(11, 13),
#'   sum_acsa_33_re = c(21, 23),
#'   dom_isom_ext = c(15, 17),
#'   dom_sum_acsa_33 = c(25, 27)
#' )
#'
#' result <- filter_data(data)
#' result$left   # same as result[[2]]
#' }
#'
#' @importFrom dplyr select mutate
#' @export
filter_data <- function(data) {
  # Coerce every column of a subset to numeric.
  to_numeric <- function(df) {
    dplyr::mutate(df, dplyr::across(dplyr::everything(), as.numeric))
  }

  left <- data %>%
    dplyr::select("Alter":"SMM_kg", "isom_ext_li":"sum_acsa_33_li") %>%
    to_numeric()

  right <- data %>%
    dplyr::select("Alter":"SMM_kg", "isom_ext_re":"sum_acsa_33_re") %>%
    to_numeric()

  # NOTE(review): unlike left/right, the dominant-leg subset is not coerced to
  # numeric - confirm that this is intended.
  dom <- data %>%
    dplyr::select("Alter":"SMM_kg", "dom_isom_ext":"dom_sum_acsa_33")

  # Order is right, left, dom. Names are added for readability; positional
  # access ([[1]], [[2]], [[3]]) keeps working.
  list(right = right, left = left, dom = dom)
}

#' Plot Pairwise Correlations and Correlation Matrix
#'
#' Computes a correlation matrix for `data` with `correlation::correlation()`
#' and prints it as a coloured matrix plot. Optionally prints a
#' `ggpairs()` scatterplot matrix of the same data first.
#'
#' @param data A data frame containing numeric columns to be correlated.
#' @param cor_meth A character string specifying the correlation method,
#'   passed on as `method`. Default is "spearman".
#' @param print_pairs A logical value indicating whether to also print the
#'   pairwise plot matrix. Default is FALSE.
#' @param bayes A logical value passed on as `bayesian`. Default is FALSE.
#' @param bayes_test A character string passed on as `bayesian_test`.
#'   Default is "pd".
#'
#' @return The correlation-matrix plot, invisibly. The function is mainly
#'   used for its side effect of plotting.
#'
#' @examples
#' \dontrun{
#' # Sample data frame
#' data <- data.frame(
#'   x1 = rnorm(100),
#'   x2 = rnorm(100),
#'   x3 = rnorm(100)
#' )
#'
#' # Plot correlations using default parameters
#' plot_whole_df(data)
#'
#' # Plot with Bayesian correlation
#' plot_whole_df(data, bayes = TRUE)
#' }
#'
#' @importFrom correlation correlation visualisation_recipe
#' @export
plot_whole_df <- function(data, cor_meth = "spearman", print_pairs = FALSE,
                          bayes = FALSE, bayes_test = "pd") {
  if (isTRUE(print_pairs)) {
    # Pairwise plot matrix with horizontal y labels and vertical x labels
    horizontal <- element_text(angle = 0, hjust = 1, vjust = 0.5)
    vertical <- element_text(angle = 90, hjust = 1, vjust = 0.5)

    pairs_plot <- ggpairs(data) +
      theme(
        axis.text.y.left = horizontal,
        axis.text.x.bottom = vertical,
        axis.text.y.right = horizontal,
        axis.text.x.top = vertical
      )
    # Explicit print() is required to draw from inside a function
    print(pairs_plot)
  }

  # Correlation matrix in summary (wide) form
  cor_result <- correlation::correlation(
    data,
    method = cor_meth,
    bayesian = bayes,
    bayesian_test = bayes_test
  )
  cor_summary <- summary(cor_result)

  # Layers for the correlation plot: sized points, two-colour fill scale,
  # no coefficient labels and an empty title
  layers <- correlation::visualisation_recipe(
    cor_summary,
    show_data = "points",
    scale = list(range = c(10, 20)),
    scale_fill = list(
      high = "#FF5722",
      low = "#673AB7"
    ),
    text = FALSE,
    labs = list(title = "")
  )

  cor_plot <- plot(layers) +
    theme_modern() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))

  # Explicit print() is required to draw from inside a function
  print(cor_plot)

  invisible(cor_plot)
}

#' Perform Welch t-tests on paired lists of data frames and export the results
#'
#' For every list position, the columns of `data1[[k]]` and `data2[[k]]` are
#' compared by column index with an unpaired, unequal-variance `t.test()`.
#' For each column the group means and SDs, the mean difference with its
#' confidence interval, the percentage difference relative to the average of
#' both means, and a standardized mean difference (`cohens_d()` with
#' `pooled_sd = FALSE`) are collected into one table.
#'
#' The table is appended to `file` as comma-separated values. The header row
#' is written only when `file` does not exist yet or is empty, so repeated
#' calls accumulate rows under a single header. Files produced by earlier,
#' space-separated versions of this function should be removed first.
#'
#' @param data1 A list of data frames/tibbles (group 1).
#' @param data2 A list of data frames/tibbles of the same length as `data1`
#'   (group 2). Columns are matched to `data1` by position.
#' @param file Path of the output file. Defaults to "ttest_values.csv" in the
#'   working directory.
#'
#' @return The table of comparisons that was written, invisibly.
#'
#' @examples
#' \dontrun{
#' ttest(list1, list2)
#' }
#'
#' @export
ttest <- function(data1, data2, file = "ttest_values.csv") {
  if (length(data1) != length(data2)) {
    stop("`data1` and `data2` must be lists of the same length.", call. = FALSE)
  }

  # Values are rounded to one decimal and printed with at least two, which is
  # the layout of the existing result tables.
  fmt <- function(value) format(round(value, digits = 1), nsmall = 2)

  # One result row for a single pair of numeric vectors.
  compare_column <- function(x, y, label) {
    test <- t.test(x, y, var.equal = FALSE, paired = FALSE)

    m1 <- unname(test$estimate[1])
    m2 <- unname(test$estimate[2])
    sd1 <- sd(x, na.rm = TRUE)
    sd2 <- sd(y, na.rm = TRUE)

    md <- m1 - m2
    md_ci <- test$conf.int
    # Difference as a percentage of the average of both group means
    mdp <- (md / ((m1 + m2) / 2)) * 100
    # NOTE(review): cohens_d() is expected to come from effectsize, attached
    # via library(easystats) - confirm.
    smd <- cohens_d(x = x, y = y, pooled_sd = FALSE)

    # check.names = FALSE keeps headers such as "MD (95% CI)" as written
    # instead of converting them to syntactic names.
    data.frame(
      "Comparison" = label,
      "Mean group1 (SD)" = paste0(
        fmt(m1), " (", format(round(sd1, digits = 1)), ")"
      ),
      "Mean Group2 (SD)" = paste0(
        fmt(m2), " (", format(round(sd2, digits = 1)), ")"
      ),
      "MD (95% CI)" = paste0(
        fmt(md), " (", fmt(md_ci[1]), ", ", fmt(md_ci[2]), ")"
      ),
      "MDP %" = round(mdp, 1),
      "SMD (95%CI)" = paste0(
        fmt(smd$Cohens_d), " (", fmt(smd$CI_low), ", ", fmt(smd$CI_high), ")"
      ),
      check.names = FALSE
    )
  }

  rows <- list()
  for (idx in seq_along(data1)) {
    group1 <- as.data.frame(data1[[idx]])
    group2 <- as.data.frame(data2[[idx]])

    if (ncol(group2) < ncol(group1)) {
      stop(
        "List element ", idx, " of `data2` has fewer columns than `data1`.",
        call. = FALSE
      )
    }

    labels <- colnames(group1)
    rows <- c(
      rows,
      lapply(seq_len(ncol(group1)), function(col) {
        # [[col]] extracts plain vectors for t.test()/sd()/cohens_d()
        compare_column(group1[[col]], group2[[col]], labels[col])
      })
    )
  }

  if (length(rows) == 0L) {
    return(invisible(data.frame()))
  }

  results <- do.call(rbind, rows)

  # Write the header only for a new or empty file so that repeated calls do
  # not scatter header lines through the data.
  needs_header <- !file.exists(file) || file.size(file) == 0
  write.table(
    results,
    file = file,
    append = !needs_header,
    sep = ",",
    qmethod = "double",
    col.names = needs_header,
    row.names = FALSE
  )

  invisible(results)
}

#' Calculate and Format Correlations
#'
#' For every combination of a column in `select_cols` and a column in
#' `select2_cols`, this function calls `correlation::correlation()` with
#' `partial = TRUE` and `method = "spearman"`, passing the first column as
#' `select` and the second column together with `confounders` as `select2`.
#' From the result it keeps the row for the requested pair and formats the
#' coefficient and its confidence bounds.
#'
#' @param data A data frame containing the variables to be correlated.
#' @param select_cols A vector of column names in `data` for the first set of variables.
#' @param select2_cols A vector of column names in `data` for the second set of variables.
#' @param confounders A vector of column names in `data` that are passed along
#'   with each `select2_cols` entry as confounders. Any number is supported.
#'
#' @return A list of two character vectors of equal length, ordered with
#'   `select_cols` as the outer and `select2_cols` as the inner loop:
#'   `parameter` ("x vs y" labels) and `correlation`
#'   ("coefficient (lower CI, upper CI)"). If no result row is found for a
#'   pair, a warning is raised and its `correlation` entry is `NA`, so both
#'   vectors always stay aligned.
#' @examples
#' \dontrun{
#' data <- data.frame(x1 = rnorm(100), x2 = rnorm(100), x3 = rnorm(100), x4 = rnorm(100))
#' calculate_correlation(data, c("x1", "x2"), c("x3"), c("x4"))
#' }
calculate_correlation <- function(data, select_cols, select2_cols, confounders) {
  n_pairs <- length(select_cols) * length(select2_cols)
  formatted_para <- character(n_pairs)
  formatted_cor <- character(n_pairs)
  pos <- 0L

  for (col in select_cols) {
    for (col2 in select2_cols) {
      pos <- pos + 1L

      # Partial correlation of col with col2 and the confounders
      cor_data <- correlation::correlation(
        data,
        select = col,
        select2 = c(col2, confounders),
        partial = TRUE,
        method = "spearman"
      )

      # Locate the row of interest by name rather than probing fixed row
      # positions, which breaks on short results and limits the number of
      # confounders.
      row <- match(col2, cor_data$Parameter2)

      if (is.na(row)) {
        warning(
          "No correlation result found for ", col, " vs ", col2, ".",
          call. = FALSE
        )
        formatted_para[pos] <- sprintf("%s vs %s", col, col2)
        formatted_cor[pos] <- NA_character_
        next
      }

      # The coefficient column is looked up by exact name. For Spearman the
      # package appears to label it `rho`, so `cor_data$r` only resolves
      # through partial matching of `$` on data frames.
      coef_col <- intersect(c("r", "rho"), names(cor_data))
      if (length(coef_col) == 0L) {
        stop("No coefficient column (`r` or `rho`) in the correlation result.",
             call. = FALSE)
      }
      coefficient <- cor_data[[coef_col[1]]]

      formatted_para[pos] <- sprintf(
        "%s vs %s", cor_data$Parameter1[row], cor_data$Parameter2[row]
      )
      formatted_cor[pos] <- sprintf(
        "%.2f (%.2f, %.2f)",
        coefficient[row], cor_data$CI_low[row], cor_data$CI_high[row]
      )
    }
  }

  # Names are informational; positional access ([[1]], [[2]]) keeps working.
  list(parameter = formatted_para, correlation = formatted_cor)
}

#' Perform Regression Diagnostics
#'
#' Fits a simple linear model (`dependent ~ independent`) for every pair of
#' variables from the two input vectors and runs a Breusch-Pagan test
#' (H0: homoscedasticity) and a Harvey-Collier test (H0: linearity) on it.
#' Only models for which at least one of the two tests has p < 0.05 are
#' reported: for those, the formula, a component + residual plot and both
#' test results are printed, and the model is stored in the return value.
#'
#' @param data A dataframe containing the variables of interest.
#' @param dependent_list A character vector listing the names of dependent variables.
#' @param independent_list A character vector listing the names of independent variables.
#'
#' @return A named list with one element per flagged pair, named
#'   "<dependent>_<independent>". Each element is a list with `model`,
#'   `bp_test` and `hc_test`. Pairs that pass both tests are not included,
#'   so the list is empty when nothing is flagged.
#'
#' @examples
#' \dontrun{
#' results <- perform_regression_diagnostics(mydata, c("y1", "y2"), c("x1", "x2", "x3"))
#' }
#'
#' @importFrom car crPlots
#' @importFrom lmtest bptest harvtest
perform_regression_diagnostics <- function(data, dependent_list, independent_list) {
  separator <- "\n------------------------------------\n"

  # Names starting with a digit are not syntactic in a formula, so they are
  # wrapped in backticks; all other names are used unchanged.
  quote_name <- function(name) {
    if (grepl("^[0-9]", name)) paste0("`", name, "`") else name
  }

  results <- list()

  for (dep_var in dependent_list) {
    cat(separator)
    cat("Dependent Variable:", dep_var)
    cat(separator)

    safe_dep_var <- quote_name(dep_var)

    for (ind_var in independent_list) {
      formula_str <- paste(safe_dep_var, "~", quote_name(ind_var))
      model <- lm(as.formula(formula_str), data = data)

      # Breusch-Pagan test, H0 = homoscedasticity
      bp_test <- lmtest::bptest(model)

      # Harvey-Collier test, H0 = linearity
      hc_test <- lmtest::harvtest(model)

      # Report and keep only models that violate at least one assumption
      if (bp_test$p.value < 0.05 || hc_test$p.value < 0.05) {
        cat("Regression model for:", formula_str, "\n")

        # Component + residual plot
        car::crPlots(model, main = formula_str)

        cat("Breusch-Pagan Test:\n")
        print(bp_test)

        cat("Harvey-Collier Test:\n")
        print(hc_test)

        results[[paste(dep_var, ind_var, sep = "_")]] <- list(
          model = model,
          bp_test = bp_test,
          hc_test = hc_test
        )

        cat(separator)
      }
    }
  }

  results
}

##########################
##### START ANALYSIS #####
##########################

##### LOAD DATA #####
# Read the measurement sheet. The working directory is set first because the
# relative file names used for reading and writing below resolve against it.
setwd(".../data")
data1 <- read_excel("2021_U15-21.xlsx")
# View() opens the data viewer, which is only meaningful in an interactive
# session; skip it when the script runs non-interactively.
if (interactive()) {
  View(data1)
}


# Team labels used to split the data. Each value is compared against the
# `Team` column of data1.
teams <- c("U15", "U16", "U17", "U18", "U21")

# One data frame per team, in the same order as `teams`
filtered_teams <- lapply(teams, function(team) {
  dplyr::filter(data1, Team == team)
})

# Write each team's rows to its own CSV file ("data1_<team>.csv").
# invisible() suppresses the list of NULLs that lapply() would print.
invisible(lapply(seq_along(teams), function(i) {
  write.csv(filtered_teams[[i]], paste0("data1_", teams[i], ".csv"), row.names = FALSE)
}))

# Split every team (and the full data set) into leg-specific subsets.
# filter_data() returns the subsets in the order right, left, dom, i.e.
# x[[1]] = right leg, x[[2]] = left leg, x[[3]] = dominant leg.
U15 <- filter_data(filtered_teams[[1]])
U16 <- filter_data(filtered_teams[[2]])
U17 <- filter_data(filtered_teams[[3]])
U18 <- filter_data(filtered_teams[[4]])
U21 <- filter_data(filtered_teams[[5]])
all <- filter_data(data1)


##### NORMALITY ASSESSMENT #####
# The normality assessment app loads its data from files, so team-specific
# data sets have to be saved as separate files once before using it.

# Columns to test: everything from the 10th column onwards.
# NOTE(review): the first nine columns are assumed to be non-measurement
# fields - confirm against the Excel sheet.
cols <- colnames(filtered_teams[[1]])
# Negative indexing yields an empty vector when there are fewer than ten
# columns, whereas 10:length(cols) would count backwards.
cols <- cols[-seq_len(9)]

# Shapiro-Wilk test per team and column. H0 = data are normally distributed,
# so TRUE (p < 0.05) marks a column as NOT normally distributed.
normality <- lapply(filtered_teams, function(df) {
  vapply(
    df[, cols],
    function(x) shapiro.test(x)$p.value < 0.05,
    logical(1)
  )
})

# Print the non-normal columns for each team
for (team in seq_along(teams)) {
  non_normal_cols <- cols[which(normality[[team]])]
  if (length(non_normal_cols) > 0) {
    cat(paste0("Team ", teams[team], " non-normal columns: "), paste(non_normal_cols, collapse = ", "), "\n")
  } else {
    cat(paste0("Team ", teams[team], " has all columns normally distributed."), "\n")
  }
}

# Test normality and visualize the data using the Shiny app
runNormalityAssessmentApp()


##### Diagnostics ######
# Check homoscedasticity and linearity for every strength/morphology pair of
# the U18 left-leg subset (second element of the filter_data() result).
# These are the same assumptions as for linear regression.
perform_regression_diagnostics(
  data = U18[[2]],
  dependent_list = c(
    "isom_ext_li", "X60_ext_li", "X240_ext_li",
    "rel_isom_ext_li", "rel_60_ext_li", "rel_240_ext_li"
  ),
  independent_list = c(
    "VL_PA_li_66", "VL_PA_li_50", "VL_PA_li_33",
    "VL_FL_li_66", "VL_FL_li_50", "VL_FL_li_33",
    "VL_MT_li_66", "VL_MT_li_50", "VL_MT_li_33",
    "acsa_66_li_rf", "acsa_50_li_rf", "acsa_33_li_rf",
    "acsa_66_li_vl", "acsa_50_li_vl", "acsa_33_li_vl",
    "sum_acsa_66_li", "sum_acsa_50_li", "sum_acsa_33_li"
  )
)

##### Participant characteristics #####
# Descriptive columns summarised per team
cols <- c("Alter", "MatRatio", "Körpergrösse_cm", "Gewicht_kg", "SMM_kg", "Sitzhöhe_cm", "Schritthöhe_cm")

# Per-team mean and SD of every column, ignoring missing values.
# Each list element is a named numeric vector in the order of `cols`.
means_list <- lapply(filtered_teams, function(df) {
  vapply(df[, cols], function(x) mean(x, na.rm = TRUE), numeric(1))
})
sds_list <- lapply(filtered_teams, function(df) {
  vapply(df[, cols], function(x) sd(x, na.rm = TRUE), numeric(1))
})

# Print one block per team with the mean and SD of every column,
# rounded to two decimals
for (team_idx in seq_along(teams)) {
  team_means <- means_list[[team_idx]]
  team_sds <- sds_list[[team_idx]]

  cat(sprintf("\nTeam: %s\n", teams[team_idx]))

  for (col_idx in seq_along(cols)) {
    cat(sprintf("Column %s: \n", cols[col_idx]))
    cat(sprintf("\tMean: %s\n", round(team_means[col_idx], 2)))
    cat(sprintf("\tSD: %s\n", round(team_sds[col_idx], 2)))
  }
}





##### ASSOCIATION ANALYSIS #####

# Correlations between strength and morphology for the U15 team, computed
# separately for the left leg (U15[[2]]), the right leg (U15[[1]]) and the
# dominant leg (U15[[3]]).

# Confounders passed to every correlation call
adj_confounders <- c("Gewicht_kg", "Körpergrösse_cm", "MatRatio")

# Left leg
strength_li <- c(
  "isom_ext_li", "X60_ext_li", "X240_ext_li",
  "rel_isom_ext_li", "rel_60_ext_li", "rel_240_ext_li"
)
morph_li <- c(
  "VL_PA_li_66", "VL_PA_li_50", "VL_PA_li_33",
  "VL_FL_li_66", "VL_FL_li_50", "VL_FL_li_33",
  "VL_MT_li_66", "VL_MT_li_50", "VL_MT_li_33",
  "acsa_66_li_rf", "acsa_50_li_rf", "acsa_33_li_rf",
  "acsa_66_li_vl", "acsa_50_li_vl", "acsa_33_li_vl",
  "sum_acsa_66_li", "sum_acsa_50_li", "sum_acsa_33_li"
)
results_li <- calculate_correlation(
  U15[[2]],
  select_cols = strength_li,
  select2_cols = morph_li,
  confounders = adj_confounders
)

# Right leg
strength_re <- c(
  "isom_ext_re", "X60_ext_re", "X240_ext_re",
  "rel_isom_ext_re", "rel_60_ext_re", "rel_240_ext_re"
)
morph_re <- c(
  "VL_PA_re_66", "VL_PA_re_50", "VL_PA_re_33",
  "VL_FL_re_66", "VL_FL_re_50", "VL_FL_re_33",
  "VL_MT_re_66", "VL_MT_re_50", "VL_MT_re_33",
  "acsa_66_re_rf", "acsa_50_re_rf", "acsa_33_re_rf",
  "acsa_66_re_vl", "acsa_50_re_vl", "acsa_33_re_vl",
  "sum_acsa_66_re", "sum_acsa_50_re", "sum_acsa_33_re"
)
results_re <- calculate_correlation(
  U15[[1]],
  select_cols = strength_re,
  select2_cols = morph_re,
  confounders = adj_confounders
)

# Dominant leg
strength_dom <- c(
  "dom_isom_ext", "dom_60_ext", "dom_240_ext",
  "rel_dom_isom_ext", "rel_dom_60_ext", "rel_dom_240_ext"
)
morph_dom <- c(
  "dom_VL_PA_66", "dom_VL_PA_50", "dom_VL_PA_33",
  "dom_VL_FL_66", "dom_VL_FL_50", "dom_VL_FL_33",
  "dom_VL_MT_66", "dom_VL_MT_50", "dom_VL_MT_33",
  "dom_acsa_66_rf", "dom_acsa_50_rf", "dom_acsa_33_rf",
  "dom_acsa_66_vl", "dom_acsa_50_vl", "dom_acsa_33_vl",
  "dom_sum_acsa_66", "dom_sum_acsa_50", "dom_sum_acsa_33"
)
results_dom <- calculate_correlation(
  U15[[3]],
  select_cols = strength_dom,
  select2_cols = morph_dom,
  confounders = adj_confounders
)

# Combine the three result sets into one table. The first list entry of each
# result holds the pair labels and the second the formatted correlations;
# the labels of the right-leg run are used for the `parameter` column.
comp_vals <- data.frame(
  "parameter" = results_re[[1]],
  "right" = results_re[[2]],
  "left" = results_li[[2]],
  "dominant" = results_dom[[2]]
)

write.csv(comp_vals, "correlations_U15_adj_all.csv")

##### GROUP DIFFERENCES #####

# Compare U18 against U21 column by column; ttest() writes its result file
# into the working directory, which is therefore set first.
setwd(".../data")
ttest(data1 = U18, data2 = U21)

##### Plotting #####

# Correlation overview of the first U15 subset
plot_whole_df(data = U15[[1]])

##### HEATMAPS #####

# Load the table holding the selected correlations
data <- read.xlsx(".../data/correlations_heatmap.xlsx")

# Strength task to display; matched against the `strength` column
strength_task <- "isom_ext"

# Value columns: names ending in "_right"/"_left" are treated as
# correlations, names ending in "_width" as CI widths
correlation_cols <- grep("_right$|_left$", names(data), value = TRUE)
ci_width_cols <- grep("_width$", names(data), value = TRUE)

# Reshape to long format: one row per parameter, team, leg and measure.
# all_of() is required because the column names come from external character
# vectors; passing such vectors bare is deprecated in tidyselect.
data_long <- data %>%
  dplyr::filter(strength == strength_task) %>%
  pivot_longer(
    cols = all_of(c(correlation_cols, ci_width_cols)),
    names_to = "team_leg",
    values_to = "value"
  ) %>%
  mutate(
    # Strip the leg/width suffix to recover the team label
    team = gsub("(_right$|_left$|_right_width$|_left_width$)", "", team_leg),
    leg = ifelse(grepl("_right", team_leg), "right", "left"),
    measure = ifelse(grepl("_width", team_leg), "width", "correlation"),
    team_leg = interaction(team, leg, measure)
  )

width_data <- dplyr::filter(data_long, measure == "width")
correlation_data <- dplyr::filter(data_long, measure == "correlation")

p <- ggplot(data_long) +
  # Grey background tile whose width is taken from the "_width" value
  geom_tile(
    data = width_data,
    aes(x = team, y = morph, width = value),
    fill = "grey", alpha = 0.5, color = "white", show.legend = FALSE,
    height = 0.7, linewidth = 2
  ) +
  # Foreground tile: fill colour and width both follow the correlation value
  geom_tile(
    data = correlation_data,
    aes(x = team, y = morph, fill = value, width = value),
    show.legend = TRUE, height = 0.9
  ) +
  # Correlation values as text
  geom_text(
    data = correlation_data,
    aes(x = team, y = morph, label = sprintf("%.2f", value)),
    color = "black", size = 5, vjust = 0.5, fontface = "bold"
  ) +
  scale_fill_viridis(name = "Correlation", direction = -1, option = "plasma") +
  facet_wrap(~leg, scales = "free_x") +
  theme_modern() +
  labs(
    title = paste("Strength Task:", strength_task),
    x = "Team",
    y = "Morphologic Parameter"
  ) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.title = element_text(size = 20, face = "bold"),
    axis.text = element_text(size = 18),
    axis.text.x.top = element_text(size = 18),
    strip.text = element_text(size = 20, face = "bold")
  )

# Explicit print() so the plot is also drawn when the script is source()d
print(p)


