import { Tensor } from "onnxruntime-web";
import * as tf from '@tensorflow/tfjs';
import npyjs from "npyjs";

import {
  clickType,
  modeDataProps,
  modelInputProps,
  queryEraseModelProps,
  queryModelReturnTensorsProps,
  setParmsandQueryEraseModelProps,
  setParmsandQueryModelProps,
} from "./Interface";

// Segmentation service endpoints. Endpoints used by earlier deployments, kept for reference:
//   embedding: 'https://model-zoo.metademolab.com/predictions/segment_everything_box_model',
//              'http://localhost:8080/api/embedding', process.env.API_ENDPOINT,
//              'http://ai.sandbox.caas.sgkincdev.com/v1/models/image-segmentation:predict',
//              'http://kubeflow-temp.caassandbox.com/v1/models/image-segmentation:predict'
//   all masks: 'http://kubeflow-temp.caassandbox.com/v1/models/image-segmentation:predict',
//              'https://model-zoo.metademolab.com/predictions/automatic_masks'

/** Embedding endpoint: receives the JPEG-encoded image; its response body is parsed as an .npy array. */
const API_ENDPOINT =
  'https://ai.caassandbox.com/v1/models/image-segmentation/embedding';

/** Automatic-segmentation endpoint: receives the same image; its JSON `results` array lists segments. */
const ALL_MASK_API_ENDPOINT =
  'https://ai.caassandbox.com/v1/models/image-segmentation/auto_segmentation';

/** Erase endpoint, taken from the environment; undefined when ERASE_API_ENDPOINT is not set. */
const ERASE_API_ENDPOINT = process.env.ERASE_API_ENDPOINT;

/** Issues the bearer token (read from `return.jwt` in its JSON body) for the embedding and all-mask requests. */
const API_TOKEN_ENDPOINT = 'https://pu.caassandbox.com/api/auth/token';
/**
 * Computes the image embedding locally with an ONNX image encoder.
 *
 * The image is resized to 1024x1024 via `Tensor.fromImage`. The resulting pixel data
 * is reshaped to (3, 1024, 1024), transposed to (1024, 1024, 3) and multiplied by 255
 * (presumably because `fromImage` normalizes pixels to [0, 1] — confirm). It is then
 * fed to the encoder as `input_image`.
 *
 * @param encoder - object exposing `run(feeds)`; presumably an onnxruntime-web
 *   `InferenceSession` for the image encoder — TODO confirm.
 * @param imgData - image source accepted by `Tensor.fromImage`.
 * @returns the encoder's `image_embeddings` output.
 * @throws rethrows (after logging) any error raised by `encoder.run`, instead of
 *   failing later with an unrelated TypeError on an undefined result.
 */
const getEmbedding = async (encoder, imgData) => {
  //@ts-ignore
  const resizedTensor = await Tensor.fromImage(imgData, { resizedWidth: 1024, resizedHeight: 1024 });
  const resizeImage = resizedTensor.toImageData();
  const imageDataTensor: any = await Tensor.fromImage(resizeImage);

  // tf.tidy disposes every intermediate tfjs tensor created inside the callback; only the
  // returned tensor survives, and it is disposed explicitly once handed to onnxruntime.
  const hwcTensor = tf.tidy(() =>
    tf
      .tensor(imageDataTensor.data, imageDataTensor.dims)
      .reshape([3, 1024, 1024])
      .transpose([1, 2, 0])
      .mul(255)
  );
  const inputImage = new Tensor(hwcTensor.dataSync(), hwcTensor.shape);
  hwcTensor.dispose();

  const feeds = { "input_image": inputImage };
  const start = Date.now();
  let results;
  try {
    results = await encoder.run(feeds);
  } catch (error) {
    console.log(`caught error: ${error}`);
    // Propagate: there is no embedding to return without a successful run.
    throw error;
  }
  const timeTaken = (Date.now() - start) / 1000;
  console.log(`Computing image embedding took ${timeTaken} seconds`);
  return results.image_embeddings;
};
/**
 * Reads a Blob and resolves with its contents as a bare base64 string, i.e. the data
 * URL produced by `FileReader.readAsDataURL` with everything up to and including the
 * first comma removed.
 *
 * @param blob - the data to encode.
 * @returns a promise for the base64 payload. It rejects if the read fails or is aborted,
 *   or if the reader does not produce a string result.
 */
const blobToBase64 = (blob: Blob): Promise<string> => {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    // Listen for `load` rather than `loadend`: `loadend` also fires after an error or
    // abort, when `reader.result` is null. Parsing it there would throw inside the event
    // handler and leave this promise unsettled.
    reader.onload = () => {
      const { result } = reader;
      if (typeof result !== "string") {
        reject(new Error("FileReader did not produce a data URL string"));
        return;
      }
      resolve(result.slice(result.indexOf(",") + 1));
    };
    reader.onerror = () => reject(reader.error ?? new Error("Failed to read blob"));
    reader.onabort = () => reject(new Error("Blob read was aborted"));
    reader.readAsDataURL(blob);
  });
};
/**
 * Draws `imgData` onto an offscreen canvas of size (width, height) * uploadScale
 * (rounded), encodes the canvas as a JPEG at quality 1.0, and hands the blob to
 * `queryModelReturnTensors`.
 *
 * Does nothing if a 2D context is unavailable. No request is made if `toBlob` yields null.
 * `image_height` passed downstream is the scaled canvas height.
 *
 * `encoder` is currently unused: a local-encoder path
 * (`getEmbedding(encoder, imgData)` -> `handleSegModelResults({ tensor })` with a
 * [1, 256, 64, 64] float32 tensor) used to short-circuit here and is disabled.
 */
const setParmsandQueryModel = ({
  width,
  height,
  uploadScale,
  imgData,
  handleSegModelResults,
  handleAllModelResults,
  imgName,
  shouldDownload,
  shouldNotFetchAllModel,
  encoder,
}: setParmsandQueryModelProps) => {
  const canvas = document.createElement("canvas");
  console.log(width, height, uploadScale);
  canvas.width = Math.round(width * uploadScale);
  canvas.height = Math.round(height * uploadScale);
  console.log(canvas.width, canvas.height);

  const context = canvas.getContext("2d");
  if (context === null) return;
  context.drawImage(imgData, 0, 0, canvas.width, canvas.height);

  const onEncoded = (blob: Blob | null) => {
    if (!blob) return;
    queryModelReturnTensors({
      blob,
      handleSegModelResults,
      handleAllModelResults,
      image_height: canvas.height,
      imgName,
      shouldDownload,
      shouldNotFetchAllModel,
    });
  };
  canvas.toBlob(onEncoded, "image/jpeg", 1.0);
};

/**
 * Requests a bearer token for the segmentation service from `API_TOKEN_ENDPOINT`.
 *
 * @returns the JWT found at `return.jwt` in the token response body.
 * @throws Error if the HTTP status is not OK or the body has no string `return.jwt`.
 */
const fetchSegmentationToken = async (): Promise<string> => {
  // SECURITY NOTE(review): these credentials are hard-coded into client-side code, so
  // anyone who loads the bundle can read them. Token issuance should move behind a
  // backend (or at least into deployment config), and this password should be rotated.
  const tokenResponse = await fetch(`${API_TOKEN_ENDPOINT}`, {
    method: "POST",
    headers: {
      'Content-Type': 'application/json'
    },
    body: JSON.stringify({
      username: "mlaiuser",
      password: "qfxHdRLuwhxCTJDp"
    })
  });
  if (!tokenResponse.ok) {
    throw new Error(`Token request failed with HTTP ${tokenResponse.status}`);
  }
  const tokenBody: any = await tokenResponse.json();
  const jwt = tokenBody?.return?.jwt;
  if (typeof jwt !== "string") {
    throw new Error("Token response is missing return.jwt");
  }
  return jwt;
};

/**
 * POSTs the image blob (as the raw request body) to a segmentation endpoint, sending
 * `jwt` as a bearer token.
 *
 * @throws Error if the response status is not OK; network errors from `fetch` propagate.
 */
const postImageForInference = async (
  endpoint: string,
  jwt: string,
  blob: Blob
): Promise<Response> => {
  const response = await fetch(endpoint, {
    method: "POST",
    headers: {
      'Authorization': `Bearer ${jwt}`
    },
    body: blob
  });
  if (!response.ok) {
    throw new Error(`POST ${endpoint} failed with HTTP ${response.status}`);
  }
  return response;
};

/**
 * Sends an encoded image to the segmentation service and forwards the results.
 *
 * One auth token is fetched and shared by both requests, which run concurrently:
 *  - `API_ENDPOINT`: the response body is parsed as an .npy array and passed to
 *    `handleSegModelResults` as a float32 tensor of shape [1, 256, 64, 64].
 *  - `ALL_MASK_API_ENDPOINT` (skipped when `shouldNotFetchAllModel` is truthy): the JSON
 *    `results` array is passed to `handleAllModelResults` together with `image_height`.
 *    Each entry's `point_coords` is replaced by `point_coord` (its first element).
 *
 * When `shouldDownload` is truthy, each raw response body is also saved via `downloadBlob`.
 * The embedding response uses name `imgName`; the all-mask response uses `imgName + ".all"`.
 *
 * A failure in either request (network error, non-OK status, malformed payload, or an
 * exception from a handler) is logged. That request's handler is then not called, and the
 * other request is unaffected. The returned promise never rejects; it resolves once all
 * started requests have settled.
 */
const queryModelReturnTensors = async ({
  blob,
  handleSegModelResults,
  handleAllModelResults,
  image_height, // height forwarded to handleAllModelResults (callers pass the scaled canvas height)
  imgName,
  shouldDownload,
  shouldNotFetchAllModel,
}: queryModelReturnTensorsProps) => {
  if (!API_ENDPOINT) return;

  const tokenPromise = fetchSegmentationToken();

  const fetchEmbedding = async () => {
    const jwt = await tokenPromise;
    const segResponse = await postImageForInference(`${API_ENDPOINT}`, jwt, blob);
    if (shouldDownload) {
      // Clone so the body can be consumed twice: once for the download, once for parsing.
      const segResponseBlob = await segResponse.clone().blob();
      downloadBlob(segResponseBlob, imgName);
    }
    // The response body is a NumPy .npy file; npyjs parses it into a typed array.
    const npy = new npyjs();
    const parsedArray = npy.parse(await segResponse.arrayBuffer());
    const float32Arr = new Float32Array(parsedArray.data);
    const lowResTensor = new Tensor("float32", float32Arr, [1, 256, 64, 64]);
    handleSegModelResults({
      tensor: lowResTensor,
    });
  };

  const fetchAllMasks = async () => {
    const allImgName = imgName + ".all";
    const jwt = await tokenPromise;
    const allResponse = await postImageForInference(`${ALL_MASK_API_ENDPOINT}`, jwt, blob);
    if (shouldDownload) {
      const allResponseBlob = await allResponse.clone().blob();
      downloadBlob(allResponseBlob, allImgName);
    }
    const allJSON: any = (await allResponse.json()).results;
    if (!Array.isArray(allJSON)) {
      throw new Error("All-mask response has no `results` array");
    }
    // Flatten each entry's `point_coords` to its first element, stored as `point_coord`.
    allJSON.forEach((seg: any) => {
      seg.point_coord = seg.point_coords[0];
      delete seg.point_coords;
    });
    handleAllModelResults({
      allJSON,
      image_height,
    });
  };

  // Each task logs its own failure so one failing request never hides the other and
  // nothing escapes as an unhandled rejection.
  const tasks: Promise<void>[] = [fetchEmbedding().catch((e) => console.log(e))];
  if (!shouldNotFetchAllModel) {
    tasks.push(fetchAllMasks().catch((e) => console.log(e)));
  }
  await Promise.all(tasks);
};

/**
 * Sends an image and mask to the erase service and passes the result on as a PNG data URL.
 *
 * @param image - base64 image payload without a data-URL prefix (the caller in this file
 *   passes the output of `getBase64StringFromDataURL`).
 * @param mask - array-like mask, serialized with `Array.from` into the JSON body.
 * @param handlePredictedImage - receives `"data:image/png;base64,"` + the response text.
 *
 * Returns without a request (logging a message) when `ERASE_API_ENDPOINT` is not set.
 * Previously this called `fetch("undefined")`.
 *
 * @throws Error when the service responds with a non-OK status; network errors from
 *   `fetch` propagate unchanged. `handlePredictedImage` is not called in either case.
 */
const queryEraseModel = async ({
  image,
  mask,
  handlePredictedImage,
}: queryEraseModelProps) => {
  if (!ERASE_API_ENDPOINT) {
    console.log("ERASE_API_ENDPOINT is not set; skipping erase request");
    return;
  }
  const eraseResponse = await fetch(ERASE_API_ENDPOINT, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      image: image,
      // @ts-ignore
      mask: Array.from(mask),
      // Service-side parameter; presumably the mask dilation kernel size — confirm with the API.
      dilate_kernel_size: 24,
    }),
  });
  if (!eraseResponse.ok) {
    // Do not render an error body as if it were image data.
    throw new Error(`Erase request failed with HTTP ${eraseResponse.status}`);
  }
  const base64Png = await eraseResponse.text();
  // The payload starts immediately after the comma (no separating whitespace).
  handlePredictedImage("data:image/png;base64," + base64Png);
};

/**
 * Returns the payload of a data URL, i.e. everything after the FIRST comma.
 *
 * Examples:
 *  - `"data:image/png;base64,iVBOR..."` -> `"iVBOR..."`
 *  - `"data:,"` -> `""`. This is what `canvas.toDataURL()` returns for a canvas with a
 *    zero width or height; the old greedy-regex version returned `","` here.
 *  - `"data:text/plain,a,b"` -> `"a,b"`. The old version stripped up to the last comma.
 *
 * Input without any comma is returned with its first `"data:"` occurrence removed.
 */
const getBase64StringFromDataURL = (dataURL: string) => {
  const commaIndex = dataURL.indexOf(",");
  return commaIndex === -1
    ? dataURL.replace("data:", "")
    : dataURL.slice(commaIndex + 1);
};

/**
 * Prepares the image for the erase service and starts the erase request.
 *
 * Steps: render `imgData` onto an offscreen canvas of size (width, height) * uploadScale
 * (rounded), encode it with `canvas.toDataURL()` (PNG by default), and send the base64
 * payload plus `mask` via `queryEraseModel`.
 *
 * If `imgData` is missing, the request is sent with a blank canvas of the scaled size.
 * Does nothing if a 2D context is unavailable. The request is fire-and-forget: failures
 * are logged here rather than left as unhandled promise rejections.
 */
const setParmsandQueryEraseModel = ({
  width,
  height,
  uploadScale,
  imgData,
  mask,
  handlePredictedImage,
}: setParmsandQueryEraseModelProps) => {
  console.log("Querying erase model");
  const canvas = document.createElement("canvas");
  canvas.width = Math.round(width * uploadScale);
  canvas.height = Math.round(height * uploadScale);
  const ctx = canvas.getContext("2d");
  if (!ctx) return;
  // Draw only when there is an image. The old fallback drew a source-less `new Image()`,
  // which has no pixels to contribute. Drawing an image in that state may throw an
  // InvalidStateError in some browsers.
  if (imgData) {
    ctx.drawImage(imgData, 0, 0, canvas.width, canvas.height);
  }
  const b64im = getBase64StringFromDataURL(canvas.toDataURL());
  queryEraseModel({
    image: b64im,
    mask,
    handlePredictedImage,
  }).catch((error) => {
    console.log(`Erase request failed: ${error}`);
  });
};

/**
 * Triggers a browser download of `data`, saved as `<name>.txt`.
 *
 * @param data - any single `BlobPart` (e.g. a Blob read from a fetch response); it is
 *   wrapped in a new Blob.
 * @param name - file name without extension; ".txt" is always appended.
 */
const downloadBlob = (data: any, name: string) => {
  const blob = new Blob([data]);
  const url = URL.createObjectURL(blob);
  const link = document.createElement("a");
  link.download = name + ".txt";
  link.href = url;
  document.body.appendChild(link);
  link.click();
  document.body.removeChild(link);
  // An object URL keeps its blob alive until revoked or the page unloads, so release it.
  // Revocation is deferred (generously) because some browsers start reading the URL
  // asynchronously after click(); revoking immediately can abort the download.
  setTimeout(() => URL.revokeObjectURL(url), 40000);
};

/**
 * Splits a box click into its two corner points.
 *
 * `x`/`y` give the upper-left corner. `width`/`height` are used directly as the
 * bottom-right corner's x/y coordinates, not as a size.
 *
 * @returns `{ upperLeft, bottomRight }`, or `undefined` when `width` or `height` is `null`.
 */
const getPointsFromBox = (box: modelInputProps) => {
  const { x, y, width, height } = box;
  if (width === null || height === null) return undefined;
  return {
    upperLeft: { x, y },
    bottomRight: { x: width, y: height },
  };
};

/**
 * Returns true for one of two click sequences, and false for any other sequence:
 *  - a single POSITIVE or NEGATIVE point click, or
 *  - exactly two clicks that are both box corners (UPPER_LEFT or BOTTOM_RIGHT).
 *
 * `modelData` uses this to decide whether a previous mask should be fed back.
 */
const isFirstClick = (clicks: Array<modelInputProps>) => {
  const isPointClick = (click: modelInputProps) =>
    click.clickType === clickType.POSITIVE ||
    click.clickType === clickType.NEGATIVE;
  const isBoxCorner = (click: modelInputProps) =>
    click.clickType === clickType.UPPER_LEFT ||
    click.clickType === clickType.BOTTOM_RIGHT;

  switch (clicks.length) {
    case 1:
      return isPointClick(clicks[0]);
    case 2:
      return clicks.every(isBoxCorner);
    default:
      return false;
  }
};

/**
 * Builds the feed tensors for the prompted mask-decoder ONNX model.
 *
 * Prompt points come from one of two sources. If both are given, `clicks` wins.
 *  - `point_coords` / `point_labels` (text-model path): exactly two points are copied
 *    unchanged into a [1, 2, 2] coords tensor and a [1, 2] labels tensor.
 *  - `clicks` (non-empty): coordinates are divided by `modelScale.onnxScale`.
 *    - Box input (first click has clickType 2): the box's two corners come first, with
 *      labels 2 and 3 (see getPointsFromBox). Every click follows, including the box
 *      click itself. Any previous mask is ignored.
 *    - Click-only input: every click is followed by a padding point (0, 0) with label -1.
 *
 * `last_pred_mask` is fed back only when there are clicks that do not form a "first
 * click" (per isFirstClick) and the input is not a box. Otherwise an all-zero
 * [1, 1, 256, 256] mask is used and `has_last_pred` is 0.
 *
 * `image_size` is [modelScale.maskHeight, modelScale.maskWidth].
 *
 * @returns the feed object, or `undefined` when there are no prompt points (no text
 *   points and no non-empty `clicks`). Previously an empty `clicks` array threw a
 *   TypeError on `clicks[0]`.
 */
const modelData = ({
  clicks,
  tensor,
  modelScale,
  point_coords,
  point_labels,
  last_pred_mask,
}: modeDataProps) => {
  const lowResTensor = tensor;
  // Local copy so the box branch can drop the previous mask without reassigning a parameter.
  let previousMask = last_pred_mask;
  let pointCoordsTensor;
  let pointLabelsTensor;

  // point_coords, point_labels params below are only truthy in text model
  if (point_coords && point_labels) {
    const pointCoords = new Float32Array(4);
    const pointLabels = new Float32Array(2);
    pointCoords[0] = point_coords[0][0];
    pointCoords[1] = point_coords[0][1];
    pointLabels[0] = point_labels[0]; // UPPER_LEFT
    pointCoords[2] = point_coords[1][0];
    pointCoords[3] = point_coords[1][1];
    pointLabels[1] = point_labels[1]; // BOTTOM_RIGHT
    pointCoordsTensor = new Tensor("float32", pointCoords, [1, 2, 2]);
    pointLabelsTensor = new Tensor("float32", pointLabels, [1, 2]);
  }

  // Point click model. An empty clicks array is treated like "no clicks".
  if (clicks && clicks.length > 0) {
    const n = clicks.length;
    // clickType 2 on the first click marks box input (written below with label 2.0 = UPPER_LEFT).
    const clicksFromBox = clicks[0].clickType === 2 ? 2 : 0;
    // Box input: 2 corner points + n clicks. Click-only input: n clicks + 1 padding point.
    const numPoints = clicksFromBox ? n + clicksFromBox : n + 1;
    const pointCoords = new Float32Array(2 * numPoints);
    const pointLabels = new Float32Array(numPoints);

    if (clicksFromBox) {
      const {
        upperLeft,
        bottomRight,
      }: {
        upperLeft: { x: number; y: number };
        bottomRight: { x: number; y: number };
      } = getPointsFromBox(clicks[0])!;
      pointCoords[0] = upperLeft.x / modelScale.onnxScale;
      pointCoords[1] = upperLeft.y / modelScale.onnxScale;
      pointLabels[0] = 2.0; // UPPER_LEFT
      pointCoords[2] = bottomRight.x / modelScale.onnxScale;
      pointCoords[3] = bottomRight.y / modelScale.onnxScale;
      pointLabels[1] = 3.0; // BOTTOM_RIGHT

      // Box input never reuses the previous mask.
      previousMask = null;
    }

    // Regular clicks, placed after the box corners when present.
    for (let i = 0; i < n; i++) {
      const slot = i + clicksFromBox;
      pointCoords[2 * slot] = clicks[i].x / modelScale.onnxScale;
      pointCoords[2 * slot + 1] = clicks[i].y / modelScale.onnxScale;
      pointLabels[slot] = clicks[i].clickType;
    }

    // Click-only input gets one extra padding point at (0, 0) with label -1.
    if (!clicksFromBox) {
      pointCoords[2 * n] = 0.0;
      pointCoords[2 * n + 1] = 0.0;
      pointLabels[n] = -1.0;
    }

    pointCoordsTensor = new Tensor("float32", pointCoords, [1, numPoints, 2]);
    pointLabelsTensor = new Tensor("float32", pointLabels, [1, numPoints]);
  }

  if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
    return;

  const imageSizeTensor = new Tensor("float32", [
    modelScale.maskHeight,
    modelScale.maskWidth,
  ]);

  // Feed the previous mask back only for follow-up clicks; otherwise use an empty mask.
  const reusePreviousMask = Boolean(
    clicks && clicks.length > 0 && !isFirstClick(clicks)
  );
  const lastPredMaskTensor =
    reusePreviousMask && previousMask
      ? previousMask
      : new Tensor("float32", new Float32Array(256 * 256), [1, 1, 256, 256]);
  const hasLastPredTensor = new Tensor("float32", [
    reusePreviousMask && previousMask ? 1 : 0,
  ]);

  return {
    low_res_embedding: lowResTensor,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    image_size: imageSizeTensor,
    last_pred_mask: lastPredMaskTensor,
    has_last_pred: hasLastPredTensor,
  };
};

// Public entry points of this module; every other definition in this file is internal.
export { modelData, setParmsandQueryEraseModel, setParmsandQueryModel };
