#include "tpool.h"

#include <pthread.h>
#include <stdlib.h>
#include <stdio.h>

// Uncomment the next line to enable debug tracing on stderr.
// #define DEBUG

#ifdef DEBUG
// Debug build: forward the format string and arguments to fprintf(stderr, ...).
#define debug_printf(...) fprintf(stderr, __VA_ARGS__)
#else // DEBUG
// Non-debug build: the call expands to nothing, so its arguments are never evaluated.
#define debug_printf(...) 
#endif

// Thread entry points; tpool_init passes the pool pointer as the argument.
void* frontend_routine(void* args);
void* backend_routine(void* args);

// Request construction and request-queue operations.
// The queue functions do no locking themselves; callers in this file hold
// pool->req_lock around them.
Request* request_init(Matrix a, Matrix b, Matrix c, int num_works);
void requests_enq(Tpool* pool, Request* req);
Request* requests_deq(Tpool* pool);

// Work-item construction and work-queue operations.
// The queue functions do no locking themselves; callers in this file hold
// pool->work_lock around them.
Work *work_init(Request *request);
void works_enq(Tpool* pool, Work* work);
Work* works_deq(Tpool* pool);

// In-place transpose of the leading n x n part of mat; returns mat.
Matrix transpose(Matrix mat, int n);

/*
 * Frontend thread: waits for queued requests and splits each one into work
 * items for the backend threads.
 *
 * arg is the Tpool* handed to pthread_create. The loop runs until pool->done
 * is set; requests still queued at that point are not processed here.
 * Always returns NULL.
 *
 * NOTE: the right-hand matrix of each request is transposed in place and is
 * not restored, so the caller's matrix stays transposed after the request.
 */
void* frontend_routine(void* arg) {
    debug_printf("[FRONTEND] hello\n");
    Tpool* pool = (Tpool*)arg;
    for (;;) {
        pthread_mutex_lock(&pool->req_lock);
        while (!pool->done && pool->requests_head == NULL) {
            pthread_cond_wait(&pool->req_empty, &pool->req_lock);
        }

        // Leave the loop with req_lock still held; it is released below.
        if (pool->done) break;
        debug_printf("[FRONTEND] lock acquired\n");

        Request* request = requests_deq(pool);
        pthread_mutex_unlock(&pool->req_lock);

        // Transpose the right-hand matrix so that backends can hand a row of
        // each operand to calculation().
        transpose(request->right, pool->n);

        debug_printf("[FRONTEND] dividing work...\n");
        int n2 = pool->n * pool->n;

        // Always emit at least one chunk: a count of zero would divide by
        // zero below, and a negative count would enqueue nothing, so the
        // request would never be reported as done. For an empty matrix every
        // chunk would end at n2 == 0 and be counted as a completed request,
        // so exactly one chunk is emitted in that case as well.
        int num_works = request->num_works;
        if (num_works < 1 || n2 == 0) num_works = 1;

        int each = n2 / num_works;
        for (int i = 0; i < num_works; i++) {
            Work* work = work_init(request);
            work->start = i * each;
            // The last chunk absorbs the remainder of the division.
            work->end = (i == num_works - 1) ? n2 : work->start + each;
            pthread_mutex_lock(&pool->work_lock);
            works_enq(pool, work);
            pthread_cond_signal(&pool->work_empty);
            pthread_mutex_unlock(&pool->work_lock);
        }
        // Every work item carries its own copy of the matrix pointers, so
        // the request node is no longer needed.
        free(request);
    }
    debug_printf("[FRONTEND] cleanup unlocking...\n");
    pthread_mutex_unlock(&pool->req_lock);
    debug_printf("[FRONTEND] bye\n");
    return NULL;
}

/*
 * Backend (worker) thread: repeatedly takes a work item from the work queue
 * and fills target[i][j] for every flat position in [start, end).
 *
 * args is the Tpool* handed to pthread_create. The loop runs until
 * pool->done is set; work items still queued at that point are not processed
 * here. Always returns NULL.
 *
 * NOTE(review): a request is counted as done when the chunk ending at n*n
 * finishes. Nothing here waits for the request's earlier chunks, which other
 * backends may still be computing, so the count can reach req_sent early.
 * Fixing that needs per-request bookkeeping that is not visible in this file.
 */
void* backend_routine(void* args) {
    // NOTE(review): arithmetic on pthread_t only compiles where it is an
    // integer type; the value is used for debug output only.
    pthread_t tid = pthread_self() % 1021;
    (void)tid; // only read by debug_printf, which may expand to nothing
    debug_printf("[BACKEND:%ld] hello\n", tid);
    Tpool* pool = (Tpool*)args;
    for (;;) {
        debug_printf("[BACKEND:%ld] lock acquiring...\n", tid);
        pthread_mutex_lock(&pool->work_lock);
        while (!pool->done && pool->workers_head == NULL) {
            pthread_cond_wait(&pool->work_empty, &pool->work_lock);
        }

        // Leave the loop with work_lock still held; it is released below.
        if (pool->done) break;
        debug_printf("[BACKEND:%ld] lock acquired\n", tid);
        Work* work = works_deq(pool);
        pthread_mutex_unlock(&pool->work_lock);

        debug_printf("[BACKEND:%ld] n = %d, start = %d, end = %d\n", tid, pool->n, work->start, work->end);
        for (int pos = work->start; pos < work->end; pos++) {
            // Flat position -> (row of left, row of transposed right).
            int i = pos / pool->n, j = pos % pool->n;
            work->target[i][j] = calculation(pool->n, work->left[i], work->right[j]);
        }

        // end == n^2 marks the last chunk of a request. Read it before
        // freeing: every work item is released, not only the last one.
        int last_chunk = (work->end == pool->n * pool->n);
        free(work);

        if (last_chunk) {
            pthread_mutex_lock(&pool->count_lock);
            pool->req_done++;
            if (pool->req_done == pool->req_sent)
                pthread_cond_broadcast(&pool->sync);
            pthread_mutex_unlock(&pool->count_lock);
        }
    }
    debug_printf("[BACKEND:%ld] cleanup unlocking...\n", tid);
    pthread_mutex_unlock(&pool->work_lock);
    debug_printf("[BACKEND:%ld] bye\n", tid);
    return NULL;
}


/*
 * Creates a thread pool for n x n matrices with one frontend thread and
 * num_threads backend threads.
 *
 * Returns the new pool, to be released with tpool_destroy, or NULL when an
 * allocation or a thread creation fails. On failure every thread that was
 * already started is stopped and joined, and all resources are released.
 */
struct tpool *tpool_init(int num_threads, int n) {
    Tpool* pool = malloc(sizeof *pool); // deallocate in tpool_destroy
    if (pool == NULL) return NULL;

    pool->n = n;
    pool->done = 0;
    pool->req_sent = 0;
    pool->req_done = 0;
    pool->num_threads = num_threads;
    pool->requests_head = NULL;
    pool->requests_tail = NULL;
    pool->workers_head = NULL;
    pool->workers_tail = NULL;
    pool->backend = NULL;
    pthread_mutex_init(&pool->req_lock, NULL);
    pthread_mutex_init(&pool->work_lock, NULL);
    pthread_mutex_init(&pool->count_lock, NULL);
    pthread_cond_init(&pool->req_empty, NULL);
    pthread_cond_init(&pool->work_empty, NULL);
    pthread_cond_init(&pool->sync, NULL);

    int started = 0;       // backend threads created so far
    int have_frontend = 0; // set once the frontend thread exists

    // Only allocate when there is something to store, so that a NULL result
    // always means failure (malloc(0) may legitimately return NULL).
    if (num_threads > 0) {
        pool->backend = malloc(sizeof *pool->backend * (size_t)num_threads); // deallocate in tpool_destroy
        if (pool->backend == NULL) goto fail;
    }

    if (pthread_create(&pool->frontend, NULL, frontend_routine, pool) != 0) goto fail;
    have_frontend = 1;

    for (; started < num_threads; started++) {
        if (pthread_create(&pool->backend[started], NULL, backend_routine, pool) != 0) goto fail;
    }

    return pool;

fail:
    // Ask the threads that did start to exit, then wait for them.
    pthread_mutex_lock(&pool->req_lock);
    pthread_mutex_lock(&pool->work_lock);
    pool->done = 1;
    pthread_cond_broadcast(&pool->work_empty);
    pthread_cond_broadcast(&pool->req_empty);
    pthread_mutex_unlock(&pool->work_lock);
    pthread_mutex_unlock(&pool->req_lock);
    for (int i = 0; i < started; i++) {
        pthread_join(pool->backend[i], NULL);
    }
    if (have_frontend) pthread_join(pool->frontend, NULL);

    free(pool->backend);
    pthread_mutex_destroy(&pool->req_lock);
    pthread_mutex_destroy(&pool->work_lock);
    pthread_mutex_destroy(&pool->count_lock);
    pthread_cond_destroy(&pool->req_empty);
    pthread_cond_destroy(&pool->work_empty);
    pthread_cond_destroy(&pool->sync);
    free(pool);
    return NULL;
}

/*
 * Queues one request built from a, b, c and num_works (see request_init),
 * wakes the frontend thread, and counts the request as sent.
 *
 * Returns immediately; use tpool_synchronize to wait for completion.
 * If the request cannot be created, nothing is queued or counted and a
 * message is written to stderr.
 */
void tpool_request(struct tpool *pool, Matrix a, Matrix b, Matrix c,
                   int num_works) {
    Request* request = request_init(a, b, c, num_works);
    if (request == NULL) {
        // Do not enqueue or count a request that does not exist; otherwise
        // tpool_synchronize would wait for work that can never finish.
        fprintf(stderr, "[%s] could not create request\n", __func__);
        return;
    }

    pthread_mutex_lock(&pool->req_lock);
    requests_enq(pool, request);
    // req_sent is guarded by count_lock, taken here while req_lock is held.
    pthread_mutex_lock(&pool->count_lock);
    pool->req_sent++;
    pthread_mutex_unlock(&pool->count_lock);
    pthread_cond_broadcast(&pool->req_empty);
    pthread_mutex_unlock(&pool->req_lock);
    debug_printf("[%s] enqueue\n", __func__);
}

/*
 * Blocks until req_done has caught up with req_sent, then resets both
 * counters to zero.
 *
 * Returns immediately when nothing is outstanding, including the case where
 * no request has been sent since the last synchronize.
 */
void tpool_synchronize(struct tpool *pool) {
    debug_printf("[%s] synching...\n", __func__);
    pthread_mutex_lock(&pool->count_lock);
    // Wait only while requests are still outstanding. Also requiring the
    // counters to be non-zero would block forever when nothing was sent,
    // because no backend would ever signal pool->sync.
    while (pool->req_sent != pool->req_done) {
        pthread_cond_wait(&pool->sync, &pool->count_lock);
    }
    debug_printf("[%s] req_sent = %d, req_done = %d\n", __func__, pool->req_sent, pool->req_done);
    pool->req_sent = 0;
    pool->req_done = 0;
    pthread_mutex_unlock(&pool->count_lock);
    debug_printf("[%s] synced\n", __func__);
}

void tpool_destroy(struct tpool *pool) {
    pthread_mutex_lock(&pool->req_lock);
    pthread_mutex_lock(&pool->work_lock);
    pool->done = 1;
    pthread_cond_broadcast(&pool->work_empty);
    pthread_cond_broadcast(&pool->req_empty);
    pthread_mutex_unlock(&pool->work_lock);
    pthread_mutex_unlock(&pool->req_lock);

    debug_printf("[%s] destroy signal broadcasted\n", __func__);

    for (int i = 0; i < pool->num_threads; i++) {
        debug_printf("[%s] joinning thread %d/%d: id = %ld\n", __func__, i, pool->num_threads, pool->backend[i] % 1021);
        pthread_join(pool->backend[i], NULL);
    }
    pthread_join(pool->frontend, NULL);

    debug_printf("[%s] destroying...\n", __func__);
    free(pool->backend);
    pthread_mutex_destroy(&pool->req_lock);
    pthread_mutex_destroy(&pool->work_lock);
    pthread_cond_destroy(&pool->req_empty);
    pthread_cond_destroy(&pool->work_empty);
    free(pool);
}

/*
 * Allocates a request for target c from left operand a and right operand b,
 * to be split into num_works work items.
 *
 * The node is returned unlinked (next == NULL). It is freed by
 * frontend_routine once the request has been split into work items.
 * Returns NULL if the allocation fails.
 */
Request* request_init(Matrix a, Matrix b, Matrix c, int num_works) {
    Request* ret = malloc(sizeof *ret); // deallocate in frontend_routine
    if (ret == NULL) return NULL;

    ret->left = a;
    ret->right = b;
    ret->target = c;
    ret->num_works = num_works;
    // requests_deq reads next to find the new head, so it must not be left
    // uninitialized.
    ret->next = NULL;
    return ret;
}

/*
 * Appends req to the tail of the pool's request queue.
 *
 * Does no locking itself; the caller must hold pool->req_lock.
 * A NULL req is ignored.
 */
void requests_enq(Tpool* pool, Request* req) {
    if (req == NULL) return;

    // The new node becomes the tail, so it must terminate the list;
    // requests_deq relies on next == NULL to detect an empty queue.
    req->next = NULL;
    if (pool->requests_head == NULL || pool->requests_tail == NULL) {
        // Queue is empty
        pool->requests_head = pool->requests_tail = req;
    } else {
        // Queue is not empty, add to the tail
        pool->requests_tail->next = req;
        pool->requests_tail = req;
    }
}

/*
 * Detaches and returns the request at the front of the pool's request queue.
 *
 * Does no locking itself and does not check for an empty queue: the caller
 * must hold pool->req_lock and must have verified that requests_head is not
 * NULL.
 */
Request* requests_deq(Tpool* pool) {
    Request* front = pool->requests_head;
    Request* rest = front->next;

    // Advance the head; when nothing is left, the tail must be cleared too.
    pool->requests_head = rest;
    if (rest == NULL) {
        pool->requests_tail = NULL;
    }

    return front;
}

/*
 * Allocates a work item that copies the three matrix pointers of request.
 *
 * The item is returned unlinked (next == NULL); start and end are left for
 * the caller to fill in. The allocation is not checked.
 */
Work *work_init(Request *request) {
    Work* item = malloc(sizeof *item); // deallocate in backend_routine
    item->next = NULL;
    item->target = request->target;
    item->left = request->left;
    item->right = request->right;
    return item;
}

/*
 * Appends work to the tail of the pool's work queue.
 *
 * Does no locking itself; callers in this file hold pool->work_lock.
 */
void works_enq(Tpool* pool, Work* work) {
    // The appended node always ends the list.
    work->next = NULL;

    if (pool->workers_tail != NULL && pool->workers_head != NULL) {
        // Non-empty queue: link behind the current tail.
        pool->workers_tail->next = work;
        pool->workers_tail = work;
        return;
    }

    // Empty queue: the node is both head and tail.
    pool->workers_tail = work;
    pool->workers_head = work;
}

/*
 * Detaches and returns the work item at the front of the pool's work queue.
 *
 * Does no locking itself and does not check for an empty queue: the caller
 * must hold pool->work_lock and must have verified that workers_head is not
 * NULL.
 */
Work* works_deq(Tpool* pool) {
    Work* first = pool->workers_head;

    pool->workers_head = first->next;
    // Taking the only element leaves the queue empty, so drop the tail too.
    if (first->next == NULL) {
        pool->workers_tail = NULL;
    }

    return first;
}

/*
 * Transposes the leading n x n part of mat in place and returns mat.
 *
 * Each off-diagonal pair is exchanged exactly once; the element above the
 * diagonal passes through an int temporary.
 */
Matrix transpose(Matrix mat, int n) {
    for (int row = 1; row < n; row++) {
        for (int col = 0; col < row; col++) {
            // (col, row) lies above the diagonal, (row, col) below it.
            int upper = mat[col][row];
            mat[col][row] = mat[row][col];
            mat[row][col] = upper;
        }
    }
    return mat;
}
