## ----include = FALSE----------------------------------------------------------
# Vignette chunk defaults: merge source and output into one block and prefix
# printed output with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----setup--------------------------------------------------------------------
library(ggplot2)
library(dplyr)
library(tibble)
library(purrr)
library(patchwork)
library(masc)

## -----------------------------------------------------------------------------
# Simulate MASC choices over a grid of attribute-weight differences and
# sampling-noise levels, and summarise how attention tracks attribute weight.
#
# For each (noise level, weight difference d) pair, rMASC() is run for
# `n_trials` trials with 2 options and 2 attributes and weights
# w = c(0.5 + d/2, 0.5 - d/2), so that w[1] - w[2] = d and sum(w) = 1.
# A weight that would be <= 0 (|d| >= 1) is replaced by 0.01 (and the other
# by 0.99) so that both weights stay strictly positive.
#
# Arguments:
#   n_trials        Number of simulated trials per condition.
#   weight_diffs    Weight differences w[1] - w[2] to test.
#   noise_levels    Values passed to rMASC() as `sigma`.
#   max_fixations   Passed to rMASC() as `max_steps`; also caps how many
#                   fixation positions enter the development curve.
#   alpha           Passed to rMASC() as `alpha` (search sensitivity).
#   delta           Passed to rMASC() as `delta` (threshold increase).
#   dev_weight_diff Weight difference at which the fixation-development
#                   curve is recorded (matched with tolerance 0.001).
#
# Returns a list of two plain data frames:
#   weight_att  Columns noise_level, weight_diff, attention_diff: the mean
#               over trials of p(fix attribute 1) - p(fix attribute 2).
#               Trials with no fixations are ignored in that mean.
#   fix_dev     Columns fixation_num, att1_prop, noise_level: the running
#               proportion of fixations on attribute 1 up to each fixation
#               position, averaged over the trials that reach that position.
#               Empty (0 x 0) when no element of `weight_diffs` matches
#               `dev_weight_diff`.
simulate_weight_attention_relationship <- function(
    n_trials = 100,                      # Number of trials per condition
    weight_diffs = seq(0, 1, by = 0.05), # Weight differences to test
    noise_levels = seq(0.5, 3, by = 0.5),# Noise levels to test
    max_fixations = 100,                 # Maximum number of fixations
    alpha = 10,                          # High search sensitivity (as in MATLAB)
    delta = 0.05,                        # Threshold increase (as in MATLAB)
    dev_weight_diff = 0.5                # Weight diff for the development curve
) {
  # Map fixation indices to attribute indices.
  # NOTE(review): assumes rMASC numbers the 2 x 2 cells so that indices 1-2
  # belong to attribute 1 and 3-4 to attribute 2 -- confirm against masc.
  attribute_of <- function(fix_seq) ceiling(fix_seq / 2)

  # Rows are collected in lists and bound once at the end; growing a data
  # frame inside the nested loops would copy it on every iteration.
  att_rows <- list()
  dev_rows <- list()

  for (noise in noise_levels) {
    for (w_diff in weight_diffs) {
      # Weights sum to 1 and differ by w_diff; zero weights are not allowed,
      # so the extremes are clamped to 0.99 / 0.01.
      w1 <- 0.5 + w_diff / 2
      w2 <- 1 - w1
      if (w2 <= 0) {
        w1 <- 0.99
        w2 <- 0.01
      } else if (w1 <= 0) {
        w1 <- 0.01
        w2 <- 0.99
      }
      weights <- c(w1, w2)
      stopifnot(abs(sum(weights) - 1) < 1e-10)

      sim <- rMASC(
        n = n_trials,
        n_options = 2,
        n_attributes = 2,
        w = weights,
        sigma = noise,
        alpha = alpha,
        delta = delta,
        max_steps = max_fixations
      )

      # Per-trial difference in the share of fixations on each attribute.
      trial_att_diffs <- map_dbl(sim$raw, function(trial) {
        atts <- attribute_of(trial$fix_sequence)
        if (length(atts) == 0) {
          # No fixations: the proportions are undefined (0 / 0).
          return(NA_real_)
        }
        (sum(atts == 1) - sum(atts == 2)) / length(atts)
      })

      att_rows[[length(att_rows) + 1L]] <- tibble(
        noise_level = noise,
        weight_diff = w_diff,
        attention_diff = mean(trial_att_diffs, na.rm = TRUE)
      )

      if (abs(w_diff - dev_weight_diff) < 0.001) {
        trial_fix_data <- map_dfr(sim$raw, function(trial) {
          # Only the first `max_fixations` fixations enter the curve.
          atts <- attribute_of(head(trial$fix_sequence, max_fixations))
          n_fix <- length(atts)
          # Running share of attribute-1 fixations after each fixation;
          # seq_len() yields zero rows for an empty sequence.
          tibble(
            trial = rep(trial$trial, n_fix),
            fixation_num = seq_len(n_fix),
            att1_prop = cumsum(atts == 1) / seq_len(n_fix)
          )
        })

        # Average over the trials that reach each fixation position.
        dev_rows[[length(dev_rows) + 1L]] <- trial_fix_data %>%
          group_by(fixation_num) %>%
          summarize(att1_prop = mean(att1_prop)) %>%
          mutate(noise_level = noise)
      }
    }
  }

  list(
    weight_att = as.data.frame(bind_rows(att_rows)),
    fix_dev = as.data.frame(bind_rows(dev_rows))
  )
}

# Build a colour palette for the noise levels: a linear RGB gradient from
# light green (194, 218, 184) to dark green (1, 50, 32), as in the MATLAB
# figure. Returns a character vector of hex codes named by noise level, so it
# can be passed straight to scale_color_manual() with factor(noise_level).
create_color_palette <- function(noise_levels) {
  n_levels <- length(noise_levels)
  light_rgb <- c(194, 218, 184) / 255
  dark_rgb <- c(1, 50, 32) / 255

  # Interpolate one colour channel (1 = red, 2 = green, 3 = blue).
  ramp <- function(channel) {
    seq(light_rgb[channel], dark_rgb[channel], length.out = n_levels)
  }

  palette <- rgb(ramp(1), ramp(2), ramp(3))
  names(palette) <- noise_levels
  palette
}

# Build the two-panel figure for the weight/attention simulation.
#
# Arguments:
#   results           List as returned by
#                     simulate_weight_attention_relationship(): `fix_dev`
#                     (fixation_num, att1_prop, noise_level) and `weight_att`
#                     (noise_level, weight_diff, attention_diff).
#   colors            Colour vector named by noise level (e.g. from
#                     create_color_palette()); names must match the levels of
#                     factor(noise_level).
#   max_fixation_num  Number of leading fixation positions shown in panel A.
#
# Returns a patchwork object (panel A | panel B) with a title, a subtitle
# and a single shared legend at the bottom.
plot_weight_attention_results <- function(results, colors,
                                          max_fixation_num = 20) {
  # Panel A: running proportion of fixations on the more important attribute,
  # recorded at one weight difference (0.5 in the default simulation).
  p1 <- results$fix_dev %>%
    filter(fixation_num <= max_fixation_num) %>%
    ggplot(aes(x = fixation_num, y = att1_prop, color = factor(noise_level))) +
    geom_line() +
    geom_point() +
    scale_color_manual(values = colors, name = "Sampling Noise") +
    labs(
      x = "Fixation Number",
      y = "p(Fix)_Most Important"
    ) +
    theme_classic() +
    theme(panel.grid.minor = element_blank()) +
    ylim(0, 1)

  # Panel B: weight difference vs. attention difference.
  p2 <- results$weight_att %>%
    ggplot(aes(
      x = weight_diff, y = attention_diff, color = factor(noise_level)
    )) +
    geom_line() +
    geom_point() +
    # Identity line: attention difference equal to weight difference.
    geom_abline(intercept = 0, slope = 1, color = "gray70") +
    # Weight difference used for panel A in the default simulation.
    geom_vline(xintercept = 0.50, color = "gray70") +
    scale_color_manual(values = colors, name = "Sampling Noise") +
    labs(
      x = "Weight_Att1 - Weight_Att2",
      y = "p(Fix)_Att1 - p(Fix)_Att2"
    ) +
    theme_classic() +
    theme(panel.grid.minor = element_blank()) +
    # Zoom with coord_cartesian() instead of xlim()/ylim(): scale limits turn
    # out-of-range values (e.g. slightly negative attention differences)
    # into NA and drop them, breaking the lines.
    coord_cartesian(xlim = c(-0.01, 1.01), ylim = c(-0.01, 1.01))

  # Both panels share the same colour scale, so collect their legends into a
  # single legend placed at the bottom of the combined figure.
  p1 + p2 +
    plot_layout(widths = c(1, 1), guides = "collect") +
    plot_annotation(
      title = "Attribute Weights and Attention with Varying Sampling Noise",
      subtitle = "MASC Model Simulation",
      theme = theme(
        plot.title = element_text(size = 16, face = "bold"),
        plot.subtitle = element_text(size = 12)
      )
    ) &
    theme(legend.position = "bottom")
}

## -----------------------------------------------------------------------------
# Simulation grid: six sampling-noise levels crossed with nine weight
# differences from 0 to 1 in steps of 0.125. The value 0.5 lies on this grid,
# which is where the fixation-development curve is recorded.
set.seed(2025)
weight_diffs <- seq(0, 1, by = 0.125)
noise_levels <- seq(0.5, 3, by = 0.5)

# One colour per noise level, shared by both panels.
color_palette <- create_color_palette(noise_levels)

# Run the simulation (slow: 200 trials for each of the 54 conditions).
results <- simulate_weight_attention_relationship(
  n_trials = 200,
  weight_diffs = weight_diffs,
  noise_levels = noise_levels,
  alpha = 10,   # high search sensitivity, as in the MATLAB original
  delta = 0.05  # threshold increase, as in the MATLAB original
)

## ----fig.width=12, fig.height=8, out.width="100%"-----------------------------
# Build the two-panel figure from the simulation output and render it.
# An explicit print() is used so the plot is also drawn when the script is
# run via source().
fig9_plot <- plot_weight_attention_results(
  results = results,
  colors = color_palette
)
print(fig9_plot)

