#' @title Create simulated multistateQTL data for testing purposes
#'
#' @param params list of parameters required to simulate betas and beta errors.
#'   Generated by `qtleEstimate()` or `qtleParams()`.
#' @param nTests number of QTL tests
#' @param nFeatures number of QTL features to simulate tests for; `NULL` means
#'   `nFeatures = nTests`.
#' @param nStates number of states
#' @param global proportion (between 0 and 1) of QTL tests with significant
#'   effects shared across all states
#' @param multi proportion of QTL tests with significant effects shared across a
#'   subset of states.
#' @param unique proportion of QTL tests with significant effects in only one
#'   state
#' @param k number of multi-state clusters or an array with the cluster
#'   assignments (one entry per state).
#' @param betaSd The desired standard deviation or an array of standard deviations
#'   equal to the length of states for sampling beta values for each state.
#' @param lfsr Logical; if `TRUE`, calculate lfsr using `mashr::mash_1by1()`.
#' @param verbose Logical; if `TRUE`, report how many tests of each type are
#'   simulated.
#'
#' @return A simulated `QTLExperiment` object with assays `betas`, `errors`
#'   and, when `lfsr = TRUE`, `lfsrs`. The simulation key (effect type and
#'   per-state effect flags of every test) is stored as rowData.
#'
#' @details Each test is assigned one of four effect types: `"global"`
#' (effect in every state), `"multistate"` (effect in the states of one of
#' `k` state clusters), `"unique"` (effect in a single, randomly chosen state)
#' or `"null"` (no effect). The number of tests of each non-null type is
#' `floor(proportion * nTests)`; all remaining tests are null, so
#' `global + multi + unique` must not exceed 1.
#'
#' @examples
#'
#' qtleSimulate(nTests=100, nStates=5, global=0.1, multi=0.2, unique=0.05)
#'
#' @importFrom stats rnorm
#' @importFrom mashr mash_1by1 mash_set_data
#'
#' @name qtleSimulate
#' @rdname qtleSimulate
#'
#' @export
#'

qtleSimulate <- function(params=qtleParams(), nTests=100, nFeatures = NULL,
    nStates = 5, global=0.5, multi=0, unique=0,
    k = 2, betaSd=0.1, lfsr=TRUE, verbose=TRUE){

    if ( !is(params, "list") )
        stop("params must be a list with names as described in ?qtleParams")

    if(is.null(nFeatures)){ nFeatures <- nTests }

    # Zero-pad ids to the number of digits. format(scientific = FALSE) is used
    # because nchar(1e5) counts the characters of "1e+05", not "100000".
    nDigits <- function(n) nchar(format(n, scientific = FALSE))

    features <- paste0("F", formatC(seq_len(nFeatures),
        width=nDigits(nFeatures), flag="0"))
    features <- sample(features, nTests, replace=TRUE)
    states <- paste0("S", formatC(seq_len(nStates),
        width=nDigits(nStates), flag="0"))

    nGlobal <- floor(global * nTests)
    nMulti <- floor(multi * nTests)
    nUnique <- floor(unique * nTests)
    nNull <- nTests - nGlobal - nMulti - nUnique

    if(nNull < 0){stop("global + multi + unique must equal <= 1")}

    if(verbose){
        # k may be a per-state assignment vector; report its distinct labels.
        nPatterns <- if(length(k) == 1) k else length(unique(k))
        message("Simulating: ",
            if(nGlobal > 0) sprintf('\n  %d global QTL', nGlobal) else "",
            if(nMulti > 0)
                sprintf('\n  %d multi-state QTL with %s different patterns',
                    nMulti, nPatterns) else "",
            if(nUnique > 0) sprintf('\n  %d unique QTL', nUnique) else "",
            if(nNull > 0) sprintf('\n  %d tests with no QTL', nNull) else "")
    }

    types <- c(rep("global", nGlobal), rep(.multistate, nMulti),
        rep("unique", nUnique), rep("null", nNull))

    key <- make.simulation.key(params, features, states, types, k, verbose)

    sim_betas <- simulateBetas(key, params, states, betaSd, verbose)
    sim_error <- simulateErrors(key, params, sim_betas, verbose)

    sim <- QTLExperiment(
        assay = list(
            betas=sim_betas,
            errors=sim_error),
        rowData = key)

    # Gate on the realised count, not the requested proportion: a small
    # `multi` can floor to zero tests, in which case no multistateGroup
    # annotation exists to summarise.
    if(nMulti > 0){
        sim <- annotate.multistate.groups(sim)
    }

    if(lfsr){
        assay(sim, "lfsrs") <- mash_1by1(
            mash_set_data(betas(sim),errors(sim)))$result$lfsr
    }

    return(sim)
}


#' Generate simulation key
#'
#' Builds one row per simulated test holding its feature, a random variant
#' id, its effect type (`QTL`), a signed mean effect size (`mean_beta`, 0 for
#' null tests) and one logical column per state flagging the states in which
#' the test has a true effect.
#'
#' @param params list of parameters required to simulate betas and beta errors.
#'   Generated by `qtleEstimate()` or `qtleParams()`. Only `betas.sig.shape`
#'   and `betas.sig.rate` are used here.
#' @param features array of feature ids, one per test
#' @param states array of state names
#' @param types array of test types ("global", "multistate", "unique" or
#'   "null"), one per test; shuffled before being assigned to tests.
#' @param k number of multi-state clusters or per-state cluster assignments,
#'   forwarded to `make.ms.simulation.key()`.
#' @param verbose Logical.
#'
#' @return data.frame with columns `feature_id`, `variant_id`, `QTL`, `id`,
#'   `mean_beta`, one logical column per state and, when multi-state tests
#'   are present, `multistateGroup`.
#'
#' @importFrom stats rgamma
#' 
#' @noRd
make.simulation.key <- function(params, features, states, types, k, verbose){

    nTests <- length(features)

    key <- as.data.frame(list(
        feature_id=features,
        # Variant numbers are drawn without replacement from 1000..100000,
        # so ids are unique (at most 99001 tests are supported).
        variant_id=paste0("v", sample(1e3:1e5, nTests)),
        QTL=sample(types, length(types), replace=FALSE)))
    key[, "id"] <- paste(key$feature_id, key$variant_id, sep="|")
    key[, "mean_beta"] <- rgamma(nTests, shape=params$betas.sig.shape,
        rate=params$betas.sig.rate)
    key[key$QTL == "null", "mean_beta"] <- 0

    # Give each mean effect a random sign (each sign with probability 1/2)
    key[, "mean_beta"] <- key[, "mean_beta"] * sample(c(1, -1), nTests,
        replace = TRUE)

    key[, states] <- FALSE

    # Global effects: set TRUE for all states
    key[key$QTL == "global", states] <- TRUE

    # Unique effects: pick one random state per unique test, then set the
    # flag column by column instead of looping over every row of the key.
    uniqueRows <- which(key$QTL == "unique")
    if(length(uniqueRows) > 0){
        chosen <- sample.int(length(states), length(uniqueRows), replace=TRUE)
        for(j in seq_along(states)){
            key[uniqueRows, states[j]] <- chosen == j
        }
    }

    # Multi-state effects: define clusters and set TRUE for one cluster
    if("multistate" %in% types){
        key <- make.ms.simulation.key(key, states, k, verbose)
    }

    return(key)

}


#' Add multi-state clusters to simulation key
#'
#' Partitions the states into clusters and assigns each multi-state test to
#' one cluster. The test's state columns are set to TRUE for exactly the
#' states in its cluster.
#'
#' @param key data.frame with simulation information (see
#'   `make.simulation.key()`); must contain a `QTL` column and one logical
#'   column per state.
#' @param states array of state names
#' @param k number of multi-state clusters, or an array with one cluster label
#'   per state.
#' @param verbose Logical; if `TRUE`, report the number of states per cluster.
#'
#' @return `key` with a `multistateGroup` column ("Group<label>" for
#'   multi-state tests, NA otherwise) and updated state columns.
#'
#' @importFrom stats rgamma setNames
#' 
#' @noRd

make.ms.simulation.key <- function(key, states, k, verbose){

    nStates <- length(states)

    # Checks
    if(length(k) == 1){
        if(k < 2){
            warning("k == 1 results in global, not multi-state effects.")
        } else if (k >= length(states)){
            warning("k >= length(states) results in unique, not multi-state effects.")
        }
        # Every cluster must contain at least one state, so there can be at
        # most nStates clusters; larger k previously crashed in sample().
        k <- min(k, nStates)
    } else{
        if(length(k) != length(states)){
            warning("k = # of desired clusters or an array with length = nStates.")
        }
    }

    # Assign states to clusters - ensure each cluster has at least one state
    if(length(k) == 1){
        ms.clusters <- c(seq_len(k),
            sample(seq_len(k), nStates - k, replace=TRUE))
    } else{
        ms.clusters <- k
    }
    ms.clusters <- paste0("Group", ms.clusters)

    # Use the labels actually present: user-supplied assignments need not be
    # the contiguous integers 1..K.
    groups <- unique(ms.clusters)

    # Assign multi-state tests to multistateGroup
    n_ms <- sum(key$QTL == "multistate")
    key[, "multistateGroup"] <- NA
    key[key$QTL == "multistate", "multistateGroup"] <-
        sample(groups, n_ms, replace=TRUE)

    if(verbose){ message(
        "multistateGroup sizes: ", paste(table(ms.clusters), collapse=", "))}

    # Fill in key: every test in group ki gets that group's state pattern.
    # Values are laid out column-major, i.e. one state column at a time.
    for (ki in groups){
        pattern <- ms.clusters == ki
        ki.which <- which(key$multistateGroup == ki)
        key[ki.which, states] <- rep(pattern, each = length(ki.which))
    }

    return(key)
}


#' Simulate beta values
#'
#' Draws one beta per test and state. For states where the key flags a true
#' effect, betas are normal around the test's `mean_beta`. Everywhere else
#' they are replaced by random-sign draws from the null gamma distribution.
#'
#' @param key data.frame containing information about how to simulate
#'   multistateQTL. Generated by `make.simulation.key()`; uses `id`,
#'   `mean_beta` and the logical state columns.
#' @param params list of parameters required to simulate betas and beta errors.
#'   Generated by `qtleEstimate()` or `qtleParams()`. Uses `betas.null.shape`
#'   and `betas.null.rate`.
#' @param states array of state names
#' @param betaSd The desired standard deviation or an array of sd equal to the
#'   length of states for sampling beta values for each state.
#' @param verbose Logical.
#'
#' @return numeric matrix (tests x states) with rownames `key$id` and
#'   colnames `states`.
#'
#' @noRd
simulateBetas <- function(key, params, states, betaSd, verbose){

    nStates <- length(states)

    if(!(length(betaSd) %in% c(1, nStates))){
        warning("betaSd should be length 1 or equal to the number of states.")
    }

    # Simulate betas around each test's mean_beta. Draws are generated test
    # by test (all states of test 1, then test 2, ...), with betaSd recycled
    # across states, and filled into the matrix by row.
    mean_beta <- key[, "mean_beta"]
    sim_betas <- matrix(
        rnorm(length(mean_beta) * nStates,
            mean=rep(mean_beta, each=nStates), sd=betaSd),
        ncol=nStates, byrow=TRUE,
        dimnames=list(key$id, states))

    # Replace betas for all test/state pairs without a QTL effect with
    # random-sign draws from the null gamma distribution
    null_mask <- as.matrix(key[, states, drop=FALSE]) == FALSE
    null_betas <- rgamma(sum(null_mask),
        shape=params$betas.null.shape,
        rate=params$betas.null.rate)
    null_betas <- null_betas * sample(c(1, -1), length(null_betas), replace=TRUE)

    sim_betas[null_mask] <- null_betas

    return(sim_betas)
}


#' Simulate error values
#'
#' Errors are the absolute betas scaled by a simulated coefficient of
#' variation (CV). CVs come from the `cv.sig.*` gamma distribution for
#' test/state pairs with a QTL effect and from the `cv.null.*` gamma
#' distribution for pairs without one.
#'
#' @param key data.frame containing information about how to simulate
#'   multistateQTL. Generated by `make.simulation.key()`.
#' @param params list of parameters required to simulate betas and beta errors.
#'   Generated by `qtleEstimate()` or `qtleParams()`.
#' @param sim_betas Output from `simulateBetas()`; its colnames select the
#'   state columns of `key`.
#' @param verbose Logical (currently unused).
#'
#' @return numeric matrix of the same shape and dimnames as `sim_betas`.
#' 
#' @noRd
simulateErrors <- function(key, params, sim_betas, verbose){

    n_row <- nrow(sim_betas)
    n_col <- ncol(sim_betas)

    # CVs for effects, drawn for every test/state pair first
    cvs <- matrix(
        rgamma(n_row * n_col, shape=params$cv.sig.shape, rate=params$cv.sig.rate),
        nrow=n_row, ncol=n_col)

    # Pairs flagged FALSE in the key get CVs from the null distribution
    is_null <- key[, colnames(sim_betas)] == FALSE
    cvs[is_null] <- rgamma(sum(is_null), shape=params$cv.null.shape,
        rate=params$cv.null.rate)

    # error = CV * |beta|
    return(cvs * abs(sim_betas))
}


#' Annotate the colData with the multistate group
#'
#' Looks up, for each state, the multi-state group whose pattern includes it
#' and stores that label in `colData(sim)$multistateGroup`.
#'
#' @param sim Simulated QTLExperiment object to annotate. Its rowData must
#'   contain logical state columns (names starting with "S") and a
#'   `multistateGroup` column, as produced by `make.ms.simulation.key()`.
#'
#' @return `sim` with a `multistateGroup` column added to its colData.
#'
#' @importFrom dplyr select filter %>% distinct starts_with arrange
#' @importFrom tidyr pivot_longer
#' @importFrom collapse na_omit
#' @importFrom SummarizedExperiment colData<-
#' 
#' @noRd

annotate.multistate.groups <- function(sim){

    # One row per group -> long (group, state, flag) -> keep the states each
    # group is active in, ordered by state name.
    clusters <- as.data.frame(rowData(sim)) %>%
        distinct(multistateGroup, .keep_all = TRUE) %>%
        dplyr::select(starts_with("S"), multistateGroup) %>%
        na_omit() %>%
        pivot_longer(-multistateGroup) %>%
        dplyr::filter(value == TRUE) %>%
        arrange(name)

    # NOTE(review): assumes every state belongs to exactly one group and that
    # the columns of `sim` are in sorted state-name order, so that the rows
    # of `clusters` line up with colData rows - confirm.
    # `[[` yields a plain vector; `clusters[, "multistateGroup"]` on a tibble
    # would return a one-column tibble instead.
    colData(sim)[, "multistateGroup"] <- clusters[["multistateGroup"]]

    return(sim)

}

