#include "common.h"
#include "gui.h" 
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
#include <tice.h>    
#include <keypadc.h> 

/* Capacity of each static workspace array used by simplify_expression():
 * the current term list, the next-round term list, and the prime list. */
#define MAX_QM_TERMS 1024  

/* One (possibly combined) term in the Quine-McCluskey tables. */
typedef struct {
    char bits[MAX_VARS + 1]; /* one char per variable, 'A' first: '0', '1', or '-' (eliminated); NUL-terminated */
    bool used;               /* set when this term was merged with another term in the current round */
} Term;

/*
 * Render the low num_vars bits of val into out as a string of '0'/'1'
 * characters, most significant of those bits first, then NUL-terminate it.
 * out must have room for num_vars + 1 characters.
 */
void int_to_bits(int val, int num_vars, char* out) {
    int pos = num_vars;

    out[pos] = '\0';
    /* Fill from the right: the least significant bit lands in the last slot. */
    while (pos-- > 0) {
        out[pos] = (char)('0' + (val & 1));
        val >>= 1;
    }
}

/*
 * Compare the first num_vars characters of t1 and t2.
 * Returns the index of the single differing position when they differ in
 * exactly one place; returns -1 when they are identical or differ in two or
 * more places.
 */
int diff_index(const char* t1, const char* t2, int num_vars) {
    int found = -1;

    for (int k = 0; k < num_vars; k++) {
        if (t1[k] == t2[k]) {
            continue;
        }
        if (found != -1) {
            return -1; /* second mismatch: not a one-position difference */
        }
        found = k;
    }
    return found;
}

/*
 * Report whether any of the first count entries of list has a bits string
 * equal (by strcmp) to bits. A count of zero or less always yields false.
 */
bool list_contains(Term* list, int count, const char* bits) {
    const Term* entry = list;

    for (int left = count; left > 0; left--, entry++) {
        if (strcmp(entry->bits, bits) == 0) {
            return true;
        }
    }
    return false;
}

bool simplify_expression(TruthTable* table, char* result) {
    static Term terms[MAX_QM_TERMS];
    static Term next_terms[MAX_QM_TERMS];
    static char primes[MAX_QM_TERMS][MAX_VARS + 1];
    
    char msg[40]; 
    draw_status("Starting optimization...", 224); 

    int count = 0;
    int next_count = 0;
    
    bool all_zero = true;
    bool all_one = true;
    int limit = 1 << table->num_vars;
    
    for(int i=0; i<limit; i++) {
        if(table->output[i]) all_zero = false;
        else all_one = false;
    }
    
    if(all_zero) { strcpy(result, "0"); return true; }
    if(all_one)  { strcpy(result, "1"); return true; }

    for (int i = 0; i < limit; i++) {
        if (os_GetCSC() == sk_Clear) {
            strcpy(result, "ABORTED!");
            return false;
        }

        if (table->output[i]) {
            if (count >= MAX_QM_TERMS) {
                strcpy(result, "ERROR: Memory!");
                return false;
            }
            int_to_bits(i, table->num_vars, terms[count].bits);
            terms[count].used = false;
            count++;
        }
    }
    
    sprintf(msg, "Minterms: %d", count);
    draw_status(msg, 0);

    bool changed = true;
    int primes_count = 0;
    int pass = 1;

    while (changed) {
        if (os_GetCSC() == sk_Clear) {
            strcpy(result, "ABORTED!");
            return false; 
        }

        sprintf(msg, "QM Round %d (Terms: %d)", pass, count);
        draw_status(msg, 224); 
        pass++;

        changed = false;
        next_count = 0;
        
        for(int i=0; i<count; i++) terms[i].used = false;

        for (int i = 0; i < count; i++) {
            for (int j = i + 1; j < count; j++) {
                int diff = diff_index(terms[i].bits, terms[j].bits, table->num_vars);
                
                if (diff != -1) {
                    terms[i].used = true;
                    terms[j].used = true;
                    changed = true;

                    char new_bits[MAX_VARS + 1];
                    strcpy(new_bits, terms[i].bits);
                    new_bits[diff] = '-';

                    if (!list_contains(next_terms, next_count, new_bits)) {
                        if(next_count < MAX_QM_TERMS) {
                            strcpy(next_terms[next_count].bits, new_bits);
                            next_terms[next_count].used = false;
                            next_count++;
                        }
                    }
                }
            }
        }

        for (int i = 0; i < count; i++) {
            if (!terms[i].used) {
                bool exists = false;
                for(int p=0; p<primes_count; p++) {
                    if(strcmp(primes[p], terms[i].bits) == 0) exists = true;
                }
                if(!exists && primes_count < MAX_QM_TERMS) {
                    strcpy(primes[primes_count], terms[i].bits);
                    primes_count++;
                }
            }
        }

        if (changed) {
            if (next_count == 0) break;
            count = next_count;
            for(int i=0; i<count; i++) {
                strcpy(terms[i].bits, next_terms[i].bits);
                terms[i].used = false;
            }
        }
    }

    draw_status("Generating output...", 0);
    result[0] = '\0'; 

    for (int i = 0; i < primes_count; i++) {
        if (i > 0) strcat(result, " + ");

        char* bits = primes[i];
        bool first_char = true; 
        bool is_term_one = true; 

        for (int j = 0; j < table->num_vars; j++) {
            if (bits[j] != '-') {
                is_term_one = false;
                if (!first_char) strcat(result, "*"); 
                
                char var_name[2] = { 'A' + j, '\0' };
                strcat(result, var_name);

                if (bits[j] == '0') strcat(result, "!");

                first_char = false;
            }
        }
        if(is_term_one) strcat(result, "1");
    }
    
    return true; 
}