import React, { useEffect, useState } from 'react';
import axios from 'axios';
import { Line, Bar } from 'react-chartjs-2';
import { API_URL } from '../../config';
import { Box } from '@mui/material';

import {
    Chart as ChartJS,
    CategoryScale,
    LinearScale,
    BarElement,
    LineElement,
    PointElement,
    Title,
    Tooltip,
    Legend,
} from 'chart.js';

// Register the Chart.js pieces imported above (scales, bar/line/point elements,
// and the title/tooltip/legend plugins) so the Bar and Line charts rendered in
// this file can use them.
ChartJS.register(
    CategoryScale,
    LinearScale,
    BarElement,
    LineElement,
    PointElement,
    Title,
    Tooltip,
    Legend,
);

const ResNet50Metrics = ({model}) => {
    const [confusionMatrix, setConfusionMatrix] = useState([]);
    const [categoryAccuracy, setCategoryAccuracy] = useState({});
    const [epochAccuracy, setEpochAccuracy] = useState({});
    const [loading, setLoading] = useState(true);
    const [error, setError] = useState(null);

    useEffect(() => {
        const fetchData = async () => {
            try {
                const [
                    confusionMatrixResponse,
                    categoryAccuracyResponse,
                    epochAccuracyResponse
                ] = await Promise.all([
                    axios.get(`${API_URL}/api/metrics/confusion_matrix/${model}`),
                    axios.get(`${API_URL}/api/metrics/category_accuracy/${model}`),
                    axios.get(`${API_URL}/api/metrics/epoch_accuracy/${model}`),
                ]);

                setConfusionMatrix(confusionMatrixResponse.data.confusion_matrix);
                setCategoryAccuracy(categoryAccuracyResponse.data.category_accuracy);
                setEpochAccuracy(epochAccuracyResponse.data.epoch_accuracy);
                setLoading(false);
            } catch (error) {
                setError(error);
                setLoading(false);
            }
        };

        fetchData();
    }, [model]);

    if (loading) {
        return <div>Loading...</div>;
    }

    if (error) {
        return <div>Error: {error.message}</div>;
    }

    const categoryLabels = Object.keys(categoryAccuracy);
    const categoryData = Object.values(categoryAccuracy);

    const epochLabels = epochAccuracy.epochs;
    const trainData = epochAccuracy.train_accuracy;
    const testData = epochAccuracy.test_accuracy;

    const categoryChartData = {
        labels: categoryLabels,
        datasets: [
            {
                label: 'Category Accuracy',
                data: categoryData,
                backgroundColor: 'rgba(75, 192, 192, 0.6)',
            },
        ],
    };

    const epochChartData = {
        labels: epochLabels,
        datasets: [
            {
                label: 'Training Accuracy',
                data: trainData,
                borderColor: 'rgba(75, 192, 192, 1)',
                fill: false,
            },
            {
                label: 'Testing Accuracy',
                data: testData,
                borderColor: 'rgba(153, 102, 255, 1)',
                fill: false,
            },
        ],
    };

    const renderCell = (cell, rowIndex, cellIndex) => {
        let bgColor = '';
        let text = '';
        if (rowIndex === cellIndex) {
            bgColor = 'green';
            text = 'True Positive';
        } else if (rowIndex > cellIndex) {
            bgColor = 'red';
            text = 'False Negative';
        } else {
            bgColor = 'orange';
            text = 'False Positive';
        }

        return (
            <td key={cellIndex} style={{ backgroundColor: bgColor, color: 'white', textAlign: 'center', padding: '10px' }}>
                {cell} <br /> <small>{text}</small>
            </td>
        );
    };

    return (
        <Box sx={{
            maxWidth: '1000px',
            margin: 'auto'
        }}>
            <h1>{model.charAt(0).toUpperCase() + model.slice(1)} Metrics</h1>
            
            <Box id="metrics-category-accuracy" sx={{
                backgroundColor: 'white',
                padding: '10px',
                borderRadius: '10px',
                marginBottom: '10px'
            }}>
                <h2>Category Accuracy</h2>
                <Bar data={categoryChartData} />
            </Box> 
            <Box id="metrics-epoch-accuracy" sx={{
                backgroundColor: 'white',
                padding: '10px',
                borderRadius: '10px',
                marginBottom: '10px'
            }}>
                <h2>Epoch Accuracy</h2>
                <Line data={epochChartData} />
            </Box>
            <Box id="metrics-confusion-matrix" sx={{
                backgroundColor: 'white',
                padding: '10px',
                borderRadius: '10px',
                marginBottom: '10px'
            }}>
                <h2>Confusion Matrix</h2>
                <table>
                    <thead>
                        <tr>
                            <th></th>
                            {confusionMatrix.map((_, index) => (
                                <th key={index}>Predicted {index}</th>
                            ))}
                        </tr>
                    </thead>
                    <tbody>
                        {confusionMatrix.map((row, rowIndex) => (
                            <tr key={rowIndex}>
                                <td>Actual {rowIndex}</td>
                                {row.map((cell, cellIndex) => renderCell(cell, rowIndex, cellIndex))}
                            </tr>
                        ))}
                    </tbody>
                </table>
            </Box>
        </Box>
    );
};

export default ResNet50Metrics;
