#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include "thread_tool.h"

/*
 * idle: thread body that never terminates.
 *
 * Each time it runs it prints one status line, sleeps for one second and
 * then calls thread_yield(). The args parameter is not used; the thread is
 * registered with a NULL argument pointer.
 */
void idle(int id, int *args) {
    (void)args; /* unused: thread_setup() receives NULL below */
    thread_setup(id, NULL);

    while (1) {
        printf("thread %d: idle\n", current_thread->id);
        sleep(1);
        thread_yield();
    }
}

/*
 * fibonacci: print F_1 .. F_n, one term per scheduling turn.
 *
 * args[0] supplies n.
 *
 * The loop counter (i), the limit (n) and the two most recent terms
 * (f_prev, f_cur) are kept in current_thread rather than in locals. They
 * are read again after thread_yield() returns.
 *
 * After printing F_n the thread calls thread_exit(); otherwise it yields.
 *
 * NOTE(review): i starts at 1 and only grows, so when n < 1 the exit test
 * never matches and the loop does not terminate.
 */
void fibonacci(int id, int *args) {
    thread_setup(id, args);

    debug_printf("thread %d: fib\n", current_thread->id);
    debug_printf("FIBONACCI: hi\n");
    debug_printf("FIBONACCI: args[0] = %d\n", current_thread->args[0]);
    current_thread->n = current_thread->args[0];
    debug_printf("FIBONACCI: n = %d\n", current_thread->n);

    for (current_thread->i = 1; ; ++current_thread->i) {
        if (current_thread->i > 2) {
            /* slide the window: (prev, cur) -> (cur, prev + cur) */
            int sum = current_thread->f_prev + current_thread->f_cur;
            current_thread->f_prev = current_thread->f_cur;
            current_thread->f_cur = sum;
        } else {
            /* seed values: F_1 = F_2 = 1 */
            current_thread->f_prev = 1;
            current_thread->f_cur = 1;
        }

        printf("thread %d: F_%d = %d\n", current_thread->id, current_thread->i,
               current_thread->f_cur);
        sleep(1);

        if (current_thread->i != current_thread->n) {
            thread_yield();
        } else {
            thread_exit();
        }
    }
}

/*
 * pm: print the alternating sum pm(k) = 1 - 2 + 3 - 4 + ... +/- k for
 * k = 1 .. p_n, one value per scheduling turn.
 *
 * args[0] supplies p_n. The counter (p_i), the limit (p_n) and the running
 * sum (p_cur) live in current_thread because they are read again after
 * thread_yield() returns.
 *
 * After printing pm(p_n) the thread calls thread_exit(); otherwise it yields.
 *
 * NOTE(review): if p_n < 1 the exit test never matches and the loop does not
 * terminate.
 */
void pm(int id, int *args) {
    thread_setup(id, args);
    debug_printf("PM: hi\n");
    current_thread->p_n = current_thread->args[0];

    for (current_thread->p_i = 1; ; ++current_thread->p_i) {
        if (current_thread->p_i != 1) {
            /* odd terms are added, even terms are subtracted */
            int step = (current_thread->p_i % 2 != 0) ? current_thread->p_i
                                                     : -current_thread->p_i;
            current_thread->p_cur += step;
        } else {
            current_thread->p_cur = 1;
        }

        printf("thread %d: pm(%d) = %d\n", current_thread->id, current_thread->p_i, current_thread->p_cur);
        sleep(1);

        if (current_thread->p_i != current_thread->p_n) {
            thread_yield();
        } else {
            thread_exit();
        }
    }
}

/*
 * enroll: thread body that chooses between the pj class and the sw class.
 *
 * args layout (copied into current_thread):
 *   args[0] -> dp  weight for the pj class
 *   args[1] -> ds  weight for the sw class
 *   args[2] -> s   value passed to thread_sleep()
 *   args[3] -> b   value passed to thread_awake()
 *
 * Phases, each ending with a thread_yield():
 *   1. read_lock(): snapshot the remaining quotas q_p / q_s into the TCB.
 *   2. read_unlock(): report priorities p_p = dp * q_p and p_s = ds * q_s.
 *   3. write_lock(): pick a class and decrement that class's quota.
 *      - The higher priority wins.
 *      - On a tie, the higher weight wins; if the weights are also equal,
 *        sw is chosen.
 *      - If the chosen class has no seats left (quota == 0), the other
 *        class is taken.
 *   4. write_unlock(), then thread_exit().
 */
void enroll(int id, int *args) {
    thread_setup(id, args);
    debug_printf("ENROLL: hi\n");
    current_thread->dp = current_thread->args[0];
    current_thread->ds = current_thread->args[1];
    current_thread->s  = current_thread->args[2];
    current_thread->b  = current_thread->args[3];

    printf("thread %d: sleep %d\n", current_thread->id, current_thread->s);
    thread_sleep(current_thread->s);
    thread_awake(current_thread->b);

    /* Phase 1: snapshot quotas into the TCB so they outlive the yield. */
    read_lock();
    current_thread->cur_qp = q_p;
    current_thread->cur_qs = q_s;
    printf("thread %d: acquire read lock\n", current_thread->id);
    sleep(1);
    thread_yield();

    /* Phase 2: report priorities computed from the snapshot. */
    read_unlock();
    int pp = current_thread->dp * current_thread->cur_qp; // pj class
    int ps = current_thread->ds * current_thread->cur_qs; // sw class
    printf("thread %d: release read lock, p_p = %d, p_s = %d\n", current_thread->id, pp, ps);
    sleep(1);
    thread_yield();

    /* Phase 3: choose a class and take a seat. */
    write_lock();
    /*
     * pp/ps were assigned before the thread_yield() above. Every other value
     * this function needs across a yield is stored in current_thread, which
     * suggests the context switch does not preserve this stack frame (TODO
     * confirm against thread_tool.h). Non-volatile locals are also
     * indeterminate after a longjmp. Recompute from the TCB fields so the
     * decision never reads stale locals.
     */
    pp = current_thread->dp * current_thread->cur_qp;
    ps = current_thread->ds * current_thread->cur_qs;

    int is_pj = (pp > ps) ? 1 : 0;
    if (pp == ps) {
        is_pj = (current_thread->dp > current_thread->ds) ? 1 : 0; // break ties
    }
    /* if the chosen class is full, choose the other one */
    /* NOTE(review): if both q_p and q_s are 0, the chosen quota goes negative. */
    if (is_pj && q_p == 0) {
        is_pj = 0;
    } else if (!is_pj && q_s == 0) {
        is_pj = 1;
    }
    printf("thread %d: acquire write lock, enroll in %s\n", current_thread->id, (is_pj) ? "pj_class" : "sw_class");
    if (is_pj) {
        q_p--;
    } else {
        q_s--;
    }
    sleep(1);
    thread_yield();

    /* Phase 4: release and finish. */
    write_unlock();
    printf("thread %d: release write lock\n", current_thread->id);
    sleep(1);
    thread_exit();
}

