import React, { useState, useEffect, useCallback, useRef } from 'react';
import { Slider, RadioGroup, FormControlLabel, Radio, ToggleButtonGroup, ToggleButton, Tooltip } from '@mui/material';
import { createTheme, ThemeProvider } from '@mui/material/styles';
import GrayscaleImageProcessor from './GrayscaleImageProcessor'; // Import the new component

// Custom Button component
const Button = ({ children, className = '', ...props }) => (
  <button
    className={`px-4 py-2 bg-white text-black border border-white rounded transition-colors duration-300 hover:bg-black hover:text-white focus:outline-none focus:ring-2 focus:ring-white focus:ring-opacity-50 ${className}`}
    {...props}
  >
    {children}
  </button>
);

/**
 * MUI theme for the sliders, radios and toggle buttons: white primary and
 * black secondary, matching the monochrome Tailwind styling of the page.
 */
const theme = createTheme({
  palette: {
    primary: { main: '#ffffff' },
    secondary: { main: '#000000' },
  },
});

function PythonExecutor() {
  const [pyodide, setPyodide] = useState(null);
  const [image, setImage] = useState(null);
  const [output, setOutput] = useState('');
  const [loading, setLoading] = useState(true);
  const [processedImageUrl, setProcessedImageUrl] = useState(null);
  const [params, setParams] = useState({
    minRadius: 5,
    maxRadius: 50,
    minDist: 20,
    cannyThreshold1: 50,
    cannyThreshold2: 200,
    contrast: 1,
    brightness: 0,
    method: 'w', // Set Watershed Transform as default
    threshold: 128, // Add a new threshold parameter
  });
  const [tool, setTool] = useState('none');
  const canvasRef = useRef(null);
  const [manualColonies, setManualColonies] = useState([]);
  const [excludedColonies, setExcludedColonies] = useState([]);
  const [detectedColonies, setDetectedColonies] = useState([]);
  const [isManualMode, setIsManualMode] = useState(false);
  const [manualMode, setManualMode] = useState('none'); // 'none', 'add', or 'delete'
  const [localImageUrl, setLocalImageUrl] = useState(null);
  const fileInputRef = useRef(null);
  const [showGrayscaleProcessor, setShowGrayscaleProcessor] = useState(false); // State to toggle grayscale processor

  useEffect(() => {
    const script = document.createElement('script');
    script.src = 'https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js';
    script.async = true;
    script.onload = async () => {
      const pyodideInstance = await window.loadPyodide({
        indexURL: 'https://cdn.jsdelivr.net/pyodide/v0.24.1/full/'
      });
      
      await pyodideInstance.loadPackage(['numpy', 'opencv-python']);
      
      const counterScript = `
        import numpy as np
        import cv2
        import base64

        def preprocess_image(image, contrast, brightness):
            adjusted = cv2.convertScaleAbs(image, alpha=contrast, beta=brightness)
            return adjusted

        def count_colonies_hough(image, min_radius, max_radius, min_dist, canny_threshold1, canny_threshold2):
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)
            edges = cv2.Canny(blurred, canny_threshold1, canny_threshold2)
            circles = cv2.HoughCircles(edges, cv2.HOUGH_GRADIENT, dp=1, minDist=min_dist,
                                         param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius)
            
            if circles is not None:
                circles = np.round(circles[0, :]).astype("int")
                for (x, y, r) in circles:
                    cv2.circle(image, (x, y), r, (0, 255, 0), 2)
                return len(circles), image, circles.tolist()
            return 0, image, []

        def count_colonies_watershed(image):
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            
            kernel = np.ones((3,3), np.uint8)
            opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
            
            sure_bg = cv2.dilate(opening, kernel, iterations=3)
            dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
            _, sure_fg = cv2.threshold(dist_transform, 0.7*dist_transform.max(), 255, 0)
            
            sure_fg = np.uint8(sure_fg)
            unknown = cv2.subtract(sure_bg, sure_fg)
            
            _, markers = cv2.connectedComponents(sure_fg)
            markers = markers + 1
            markers[unknown == 255] = 0
            
            markers = cv2.watershed(image, markers)
            image[markers == -1] = [255, 0, 0]
            
            colony_count = len(np.unique(markers)) - 2
            
            colonies = []
            for label in np.unique(markers):
                if label > 0:
                    mask = np.zeros(gray.shape, dtype="uint8")
                    mask[markers == label] = 255
                    contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
                    M = cv2.moments(contours[0])
                    cX = int(M["m10"] / M["m00"])
                    cY = int(M["m01"] / M["m00"])
                    colonies.append([cX, cY, 10])
            
            return colony_count, image, colonies

        def count_colonies_otsu(image):
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Otsu's thresholding
            _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            
            # Find contours
            contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                cv2.drawContours(image, [contour], -1, (0, 255, 0), 2)
            
            return len(contours), image, [(int(cv2.moments(c)["m10"] / cv2.moments(c)["m00"]), int(cv2.moments(c)["m01"] / cv2.moments(c)["m00"]), 10) for c in contours if cv2.moments(c)["m00"] != 0]

        def count_colonies_contour(image):
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            
            contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                cv2.drawContours(image, [contour], -1, (0, 255, 0), 2)
            
            return len(contours), image, [(int(cv2.moments(c)["m10"] / cv2.moments(c)["m00"]), int(cv2.moments(c)["m01"] / cv2.moments(c)["m00"]), 10) for c in contours if cv2.moments(c)["m00"] != 0]

        def process_image(image_path, params):
            image = cv2.imread(image_path)
            # Scale down the image if it exceeds 960x960 pixels
            if image.shape[1] > 960 or image.shape[0] > 960:
                scale_factor = min(960 / image.shape[1], 960 / image.shape[0])
                new_size = (int(image.shape[1] * scale_factor), int(image.shape[0] * scale_factor))
                image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
                
            image = preprocess_image(image, params['contrast'], params['brightness']);
            
            # Apply thresholding correctly
            threshold_value = int(params['threshold'])
            _, thresh = cv2.threshold(image, threshold_value, 255, cv2.THRESH_BINARY)
            
            if params['method'] == 'h':
                count, processed_image, colonies = count_colonies_hough(
                    image, params['minRadius'], params['maxRadius'], params['minDist'],
                    params['cannyThreshold1'], params['cannyThreshold2']
                )
            elif params['method'] == 'o':
                count, processed_image, colonies = count_colonies_otsu(image)
            elif params['method'] == 'c':  # Added condition for Contour Detection
                count, processed_image, colonies = count_colonies_contour(image)
            else:
                count, processed_image, colonies = count_colonies_watershed(image);
            
            _, buffer = cv2.imencode('.png', processed_image)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            
            return count, img_base64, colonies
      `;
      
      pyodideInstance.runPython(counterScript);
      
      setPyodide(pyodideInstance);
      setLoading(false);
    };
    document.body.appendChild(script);

    return () => {
      document.body.removeChild(script);
    };
  }, []);

  const handleImageUpload = (event) => {
    const file = event.target.files[0];
    setImage(file);
    setManualColonies([]);
    setExcludedColonies([]);
    setLocalImageUrl(URL.createObjectURL(file));
  };

  const processImage = useCallback(async () => {
    if (pyodide && image) {
      try {
        const reader = new FileReader();
        reader.onload = async (e) => {
          const imageData = new Uint8Array(e.target.result);
          pyodide.FS.writeFile(image.name, imageData);
          
          const result = await pyodide.runPythonAsync(`
            params = ${JSON.stringify(params)}
            _, img_base64, colonies = process_image("${image.name}", params)
            [img_base64, colonies]
          `);
          
          setProcessedImageUrl(`data:image/png;base64,${result[0]}`);
          setDetectedColonies(result[1]);
          redrawColonies();
        };
        reader.readAsArrayBuffer(image);
      } catch (error) {
        console.error('Error processing image:', error);
      }
    }
  }, [pyodide, image, params]);

  useEffect(() => {
    processImage();
  }, [processImage]);

  const runColonyCounter = async () => {
    if (pyodide && image) {
      try {
        setOutput('Counting colonies...');
        
        const result = await pyodide.runPythonAsync(`
          params = ${JSON.stringify(params)}
          count, _, _ = process_image("${image.name}", params)
          count
        `);
        
        const manualAddCount = manualColonies.filter(colony => colony.type === 'add').length;
        const manualDeleteCount = manualColonies.filter(colony => colony.type === 'delete').length;
        const totalCount = result + manualAddCount - manualDeleteCount;
        setOutput(`Number of colonies: ${totalCount} (Automatic: ${result}, Manual Add: ${manualAddCount}, Manual Delete: ${manualDeleteCount})`);
      } catch (error) {
        setOutput(`Error: ${error.message}`);
      }
    }
  };

  const handleParamChange = (name, value) => {
    setParams(prevParams => ({
      ...prevParams,
      [name]: value
    }));
    
    if (name === 'threshold') {
        processImage(); // Reprocess the image when the threshold changes
    }
  };

  const drawColony = (x, y, color) => {
    const canvas = canvasRef.current;
    const ctx = canvas.getContext('2d');
    ctx.beginPath();
    ctx.arc(x, y, 5, 0, 2 * Math.PI);
    ctx.fillStyle = color;
    ctx.fill();
  };

  const redrawColonies = useCallback(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;
    const ctx = canvas.getContext('2d');
    const img = new Image();
    img.onload = () => {
      canvas.width = img.width;
      canvas.height = img.height;
      ctx.drawImage(img, 0, 0);
      
      detectedColonies.forEach(([x, y, r]) => {
        if (!excludedColonies.some(([ex, ey]) => Math.abs(ex - x) < 5 && Math.abs(ey - y) < 5)) {
          ctx.beginPath();
          ctx.arc(x, y, r, 0, 2 * Math.PI);
          ctx.strokeStyle = 'green';
          ctx.lineWidth = 2;
          ctx.stroke();
        }
      });
      
      manualColonies.forEach(({ x, y, type }) => {
        ctx.beginPath();
        ctx.arc(x, y, 5, 0, 2 * Math.PI);
        ctx.fillStyle = type === 'add' ? 'blue' : 'red';
        ctx.fill();
      });
    };
    img.src = processedImageUrl;
  }, [processedImageUrl, detectedColonies, excludedColonies, manualColonies]);

  useEffect(() => {
    if (processedImageUrl) {
      redrawColonies();
    }
  }, [processedImageUrl, detectedColonies, manualColonies, excludedColonies, redrawColonies]);

  const toggleManualMode = () => {
    setIsManualMode(!isManualMode);
    setManualMode('none');
  };

  const handleCanvasClick = (event) => {
    if (!isManualMode || manualMode === 'none') return;

    const canvas = canvasRef.current;
    const rect = canvas.getBoundingClientRect();
    const x = event.clientX - rect.left;
    const y = event.clientY - rect.top;

    if (manualMode === 'add') {
      setManualColonies([...manualColonies, { x, y, type: 'add' }]);
      drawColony(x, y, 'blue');
    } else if (manualMode === 'delete') {
      // Check if clicking on an existing manual colony
      const existingIndex = manualColonies.findIndex(colony => 
        Math.sqrt((colony.x - x) ** 2 + (colony.y - y) ** 2) <= 10
      );

      if (existingIndex !== -1) {
        // Remove the existing colony
        setManualColonies(manualColonies.filter((_, index) => index !== existingIndex));
        redrawColonies();
      } else {
        // Add a negative colony
        setManualColonies([...manualColonies, { x, y, type: 'delete' }]);
        drawColony(x, y, 'red');
      }
    }
  };

  const handleFileButtonClick = () => {
    fileInputRef.current.click();
  };

  const toggleGrayscaleProcessor = () => {
    setShowGrayscaleProcessor(!showGrayscaleProcessor);
  };

  if (loading) {
    return <div>Loading Pyodide and dependencies...</div>;
  }

  return (
    <div className="flex flex-col items-center w-full max-w-6xl mx-auto">
      <h2 className="text-3xl font-bold mb-6">Colony Counter</h2>
      <input
        type="file"
        accept="image/*"
        onChange={handleImageUpload}
        className="hidden"
        ref={fileInputRef}
      />
      {/* Flex container for buttons centered with minimal spacing */}
      <div className="flex justify-center w-full mb-4 relative"> {/* Added relative positioning */}
        <Tooltip title="Applies a threshold filter to the image, greatly increases colony detection sensitivity" arrow>
          <Button
            onClick={toggleGrayscaleProcessor}
            className="mr-2" // Add a small margin to the right for spacing
          >
            Image Preprocessing
          </Button>
        </Tooltip>
        <Button
          onClick={handleFileButtonClick}
        >
          Choose File
        </Button>
      </div>

      {showGrayscaleProcessor && <GrayscaleImageProcessor />} {/* Render the new component, ensure it does not display grayscale image */}
      
      <div className="flex w-full justify-center items-start">
        <div className="w-1/2 pr-4">
          <div className="w-full mb-4 text-left">
            <ThemeProvider theme={theme}>
              <label>Contrast</label>
              <Slider
                value={params.contrast}
                onChange={(_, value) => handleParamChange('contrast', value)}
                min={0.5}
                max={3}
                step={0.1}
                valueLabelDisplay="auto"
                aria-labelledby="contrast-slider"
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: 'white', // Always white
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: 'white', // Always white
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <label>Brightness</label>
              <Slider
                value={params.brightness}
                onChange={(_, value) => handleParamChange('brightness', value)}
                min={-100}
                max={100}
                valueLabelDisplay="auto"
                aria-labelledby="brightness-slider"
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: 'white', // Always white
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: 'white', // Always white
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />

              <label>Min Distance</label>
              <Slider
                value={params.minDist}
                onChange={(_, value) => handleParamChange('minDist', value)}
                min={1}
                max={100}
                valueLabelDisplay="auto"
                aria-labelledby="min-dist-slider"
                disabled={params.method !== 'h'} // Disable if method is not Hough Circle Transform
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <label>Min Radius</label>
              <Slider
                value={params.minRadius}
                onChange={(_, value) => handleParamChange('minRadius', value)}
                min={1}
                max={50}
                valueLabelDisplay="auto"
                aria-labelledby="min-radius-slider"
                disabled={params.method !== 'h'} // Disable if method is not Hough Circle Transform
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <label>Max Radius</label>
              <Slider
                value={params.maxRadius}
                onChange={(_, value) => handleParamChange('maxRadius', value)}
                min={10}
                max={200}
                valueLabelDisplay="auto"
                aria-labelledby="max-radius-slider"
                disabled={params.method !== 'h'} // Disable if method is not Hough Circle Transform
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <label>Canny Threshold 1</label>
              <Slider
                value={params.cannyThreshold1}
                onChange={(_, value) => handleParamChange('cannyThreshold1', value)}
                min={0}
                max={255}
                valueLabelDisplay="auto"
                aria-labelledby="canny-threshold1-slider"
                disabled={params.method !== 'h'} // Disable if method is not Hough Circle Transform
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <label>Canny Threshold 2</label>
              <Slider
                value={params.cannyThreshold2}
                onChange={(_, value) => handleParamChange('cannyThreshold2', value)}
                min={0}
                max={255}
                valueLabelDisplay="auto"
                aria-labelledby="canny-threshold2-slider"
                disabled={params.method !== 'h'} // Disable if method is not Hough Circle Transform
                sx={{
                  '& .MuiSlider-thumb': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-track': {
                    backgroundColor: params.method === 'h' ? 'white' : 'rgba(255, 255, 255, 0.5)',
                  },
                  '& .MuiSlider-rail': {
                    backgroundColor: 'rgba(255, 255, 255, 0.5)',
                  },
                }}
              />
              
              <RadioGroup
                value={params.method}
                onChange={(e) => handleParamChange('method', e.target.value)}
              >
                <FormControlLabel value="w" control={<Radio />} label="Watershed Transform (Recommended)" />
                <FormControlLabel value="h" control={<Radio />} label="Hough Circle Transform" />
                <FormControlLabel value="o" control={<Radio />} label="Otsu's Thresholding" />
              </RadioGroup>
              
              <ToggleButtonGroup
                value={tool}
                exclusive
                onChange={(_, newTool) => setTool(newTool)}
                aria-label="colony manipulation tool"
              >
                <ToggleButton value="none" aria-label="no tool">
                  None
                </ToggleButton>
                <ToggleButton value="add" aria-label="add colony">
                  Add Colony
                </ToggleButton>
                <ToggleButton value="exclude" aria-label="exclude colony">
                  Exclude Colony
                </ToggleButton>
              </ToggleButtonGroup>
            </ThemeProvider>
          </div>
          
          <Button
            onClick={toggleManualMode}
            className="mb-4 mr-4"
          >
            {isManualMode ? "Exit Manual Mode" : "Manual Mode"}
          </Button>

          {isManualMode && (
            <div className="mb-4">
              <Button
                onClick={() => setManualMode('add')}
                className={`mr-2 ${manualMode === 'add' ? 'bg-black text-white' : ''}`}
              >
                Add Colony
              </Button>
              <Button
                onClick={() => setManualMode('delete')}
                className={manualMode === 'delete' ? 'bg-black text-white' : ''}
              >
                Delete/Negate Colony
              </Button>
            </div>
          )}

          <Button
            onClick={runColonyCounter}
            disabled={!image}
            className={`${!image ? 'opacity-50 cursor-not-allowed' : ''}`}
          >
            Count Colonies
          </Button>
        </div>

        {processedImageUrl && (
          <div className="w-1/2 pl-4">
            <h3 className="text-xl font-bold mb-2">Processed Image:</h3>
            <canvas
              ref={canvasRef}
              onClick={handleCanvasClick}
              style={{ cursor: isManualMode ? 'crosshair' : 'default' }}
            />
          </div>
        )}
      </div>

      <div className="w-full pl-4">
        <div className="w-full p-4 bg-gray-800 text-white rounded-md mb-4">
          <h3 className="text-xl font-bold mb-2">Output:</h3>
          <pre className="whitespace-pre-wrap">{output}</pre>
        </div>
      </div>
    </div>
  );
}

export default PythonExecutor;