#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <iostream>
#include <string>
#include <vector>

//#include "Point.h"
#include "Patch.h"

//using namespace std;

// Byte type alias; not referenced anywhere in the visible part of this file.
#define T unsigned char


// Input images loaded by main(); the array holds nimages entries.
Image *images;
// Output image: owidth * oheight 32-bit words allocated in main(). The code
// accesses it bytewise with a stride of 4 bytes per pixel and only touches
// bytes 0..2 of each pixel.
unsigned int* epitome;

// owidth/oheight: epitome dimensions, read from argv[1]/argv[2] in main().
// width/height: never assigned in the visible code (main() declares locals
// with the same names that shadow them), so they keep their zero value.
int width, height, owidth, oheight;
// Number of outer iterations, read from argv[6].
int iterations;
// Total number of patches over all patch sizes; computed by initpatches().
int npatches;
// Not referenced in the visible code.
int **distance;
// The npatches patches sampled from the input images by initpatches().
Patch *patches;
// ddata, hdata, heap, heapindex, hidata and heapsize are not referenced in
// the visible code.
int *ddata;
int *hdata;
int **heap;
int **heapindex;
int *hidata;
int heapsize;
// bx[n], by[n]: epitome position assigned to patch n by mapnearest().
int *bx; 
int *by; 
// pixelstart[p] .. pixelstart[p + 1] is the range of pixeltable entries that
// belong to epitome pixel p = x + y * owidth (owidth * oheight + 1 entries).
int *pixelstart;
// Pointers to the patch pixels that currently overlap each epitome pixel,
// grouped per epitome pixel as described by pixelstart (npixels entries).
unsigned int **pixeltable;
// temp: temperature used by iterate(); temp2: temperature used by
// mapnearest(). Both are recomputed by main() at the start of each iteration.
double temp, temp2;
// Initial value of temp; set to 15 in main().
double tempstart;
// Largest and smallest patch edge length, read from argv[3]/argv[4]. Patch
// sizes are pmax, pmax / 2, ... down to pmin.
int pmax, pmin;
// Multiplier on the number of patches per size, read from argv[5].
double pmul;
// Number of input images (argc - 7).
int nimages;
// Total number of patch pixels over all patches; computed by initpatches().
int npixels;

// Table of squares. sq points at the middle of sqtable so that it can be
// indexed with a signed difference d in [-255, 255]; main() fills sq[d] = d * d.
unsigned int sqtable[512];
unsigned int *sq = sqtable + 255;

// Not referenced in the visible code.
long long tempepitome[128 * 128];


// Number of patches of edge length psize: the epitome area divided by the
// patch area, scaled by pmul and truncated to an integer.
static int patchcount(int psize)
{
  return (int)(owidth * oheight * pmul / (psize * psize));
}

// Allocates the per-patch tables and samples every patch from the input images.
//
// Patch sizes run pmax, pmax / 2, ... while the size is still >= pmin. For
// each size, patchcount() patches are drawn: the source image is chosen with
// probability proportional to its pixel count, and the patch position is
// uniform over all positions where the patch fits inside that image.
//
// Reads:  images, nimages, owidth, oheight, pmax, pmin, pmul.
// Writes: npatches, npixels, bx, by, patches, pixelstart, pixeltable.
// Terminates the program with an error message when a patch cannot be sampled
// (no image pixels at all, or the chosen image is smaller than the patch).
static void initpatches()
{
  npatches = 0;
  npixels = 0;
  // "psize > 0" stops the halving once it reaches zero; without it a
  // non-positive pmin loops forever and divides by zero in patchcount().
  for(int psize = pmax; psize >= pmin && psize > 0; psize /= 2)
    {
      int np = patchcount(psize);
      npixels += np * psize * psize;
      npatches += np;
    }

  // distrib[j] is the running total of pixel counts of images 0..j, used to
  // pick an image with probability proportional to its area.
  int sum = 0;
  std::vector<int> distrib(nimages);
  for(int j = 0; j < nimages; j++)
    {
      sum += images[j].width * images[j].height;
      distrib[j] = sum;
    }

  if(npatches > 0 && sum <= 0)
    {
      std::cerr << "No image pixels to sample patches from.\n";
      exit(1);
    }

  bx = new int[npatches];
  by = new int[npatches];
  patches = new Patch[npatches];
  pixelstart = new int[owidth * oheight + 1];
  pixeltable = new unsigned int *[npixels];

  int patch = 0;
  for(int psize = pmax; psize >= pmin && psize > 0; psize /= 2)
    {
      int np = patchcount(psize);
      for(int i = patch; i < patch + np; i++)
        {
          // Pick the first image whose cumulative pixel count exceeds randv.
          int j;
          int randv = rand() % sum;
          for(j = 0; j < nimages; j++)
            {
              if(randv < distrib[j]) break;
            }

          // A patch larger than the image would make the modulo below
          // divide by zero or by a negative number.
          if(images[j].width < psize || images[j].height < psize)
            {
              std::cerr << "Image " << j << " is smaller than patch size "
                        << psize << ".\n";
              exit(1);
            }

          int x = rand() % (images[j].width - psize + 1);
          int y = rand() % (images[j].height - psize + 1);
          patches[i] = Patch(images[j], x, y, psize);
          std::cerr << j << " " << x << " " << y << "\n";
        }
      patch += np;
    }
}


// Root-mean-square difference between patch p and the epitome region whose
// top-left corner is (u, v).
//
// The epitome is addressed with wrap-around in both directions (coordinates
// are taken modulo owidth / oheight). Both images are read as 4 bytes per
// pixel, of which bytes 0..2 are compared.
// NOTE(review): this presumes p.data holds 32-bit pixels, p.size per row —
// confirm against Patch.h.
//
// Returns sqrt(sum of squared byte differences / (p.size * p.size * 3)).
static inline float compare(Patch& p, int u, int v)
{
  // 64-bit accumulator: three channels of a 128x128 patch can sum to more
  // than INT_MAX, which overflowed the previous int sums.
  long long sumsq = 0;
  for(unsigned int y = 0; y < p.size; y++)
    {
      unsigned char *eo = ((unsigned char *)epitome) + ((v + y) % oheight) * owidth * 4;
      unsigned char *po = (unsigned char *)(p.data + y * p.size);

      for(unsigned int x = 0; x < p.size; x++)
        {
          unsigned char *ed = eo + ((u + x) % owidth) * 4;
          unsigned char *pd = po + x * 4;
          int d0 = ed[0] - pd[0];
          int d1 = ed[1] - pd[1];
          int d2 = ed[2] - pd[2];
          sumsq += d0 * d0 + d1 * d1 + d2 * d2;
        }
    }
  return sqrt(((float)sumsq) / (p.size * p.size * 3));
}

static void mapnearest()
{
  int besti;
  int fx, fy;

  std::cerr << "map2";
  float scores[owidth * oheight];
  for(int n = 0; n < npatches; n++)
    {
      std::cout << n <<"/"<<npatches <<"\n";
      double min = 1000.;
      for(int y = 0; y < oheight; y++)
	{
	  //std::cout << y <<"\n";
	  for(int x = 0; x < owidth; x++)
	    {	  
	      double score = compare(patches[n], x, y);
	      //float score = compare2(patches[n], x, y);
	      //	      if(score != score2)
	      //	std::cerr << score << " " << score2 << "\n";

	      if(score < min) 
		{
		  min = score;
		}
	      scores[x + y * owidth] = score;
	    }
	}
      //std::cerr << n << "\n";

      float sum = 0.;
      for(int i = 0; i < owidth * oheight; i++)
	{
	  float diff;
	  diff = exp((min - scores[i]) / temp2);
	  sum += diff;
	  scores[i] = sum;
	}
      float randv = (float)rand() / (float)RAND_MAX * sum;
      
      int i;
      for(i = 0; i < owidth * oheight; i++)
	{
	  //std::cerr << (distro[value] - distro[value - 1]) / sum << "\n";
	  //std::cerr << randv << " " << scores[i] << " " <<  i << "\n";
	  if(randv <= scores[i])
	    break;
	}
      
      bx[n] = i % owidth;
      by[n] = i / owidth;
      
      assert(i != owidth * oheight);
    }
  std::cerr << "map3";
  
  int pn = 0;
  pixelstart[0] = 0;
  for(int y = 0; y < oheight; y++)
    {
      for(int x = 0; x < owidth; x++)
	{
	  for(int i = 0; i < npatches; i++)
	    {
	      int px = x - bx[i];
	      if(px < 0) px += owidth;
	      if(px >= patches[i].size) continue;
	      int py = y - by[i];
	      if(py < 0) py += oheight;
	      if(py >= patches[i].size) continue;
	      pixeltable[pn] = patches[i].data + px + py * patches[i].size;
	      pn++;
	    }
	  pixelstart[x + y * owidth + 1] = pn;
	}
    }
  assert(pn == npixels);
  std::cerr << "map4";
}

// Root-mean-square difference between `value` and byte `channel` of every
// patch pixel that covers epitome pixel (x, y).
//
// The covering patch pixels are pixeltable[pixelstart[p] .. pixelstart[p + 1])
// with p = x + y * owidth. Returns 0 when no patch covers the pixel.
static inline double scorepixel(int x, int y, int channel, int value)
{
  int p = x + y * owidth;
  int count = pixelstart[p + 1] - pixelstart[p];
  if(count <= 0) return 0;

  int sumsq = 0;
  for(int i = pixelstart[p]; i < pixelstart[p + 1]; i++)
    {
      // sq[d] is d * d for a signed difference d in [-255, 255].
      sumsq += sq[((unsigned char *)pixeltable[i])[channel] - value];
    }
  // Divide in floating point: the previous integer division truncated the
  // mean before the square root was taken.
  return sqrt((double)sumsq / count);
}

// Resamples one byte (`channel` of pixel (x, y)) of the epitome.
//
// Candidate values lie within a window around the current value, clamped to
// [0, 255]; the window half-width shrinks from 128 to 5 as temp falls from
// tempstart towards 0. Each candidate is scored with scorepixel() and a new
// value is drawn with weight exp((best - score) / temp), so lower-scoring
// (better) values are more likely and a lower temp makes the choice greedier.
static void iterate(int x, int y, int channel)
{
  unsigned char *cell = ((unsigned char *)epitome) + x * 4 + channel + y * owidth * 4;
  int oldvalue = *cell;

  int range = (int)(5. + 123. * (temp / tempstart));
  int min = oldvalue - range;
  if(min < 0) min = 0;
  int max = oldvalue + range;
  if(max > 255) max = 255;

  // 256 entries: value can be 255 (the array used to hold only 255 entries,
  // so scores[255] was written out of bounds).
  double scores[256];
  double ms = 0;
  for(int value = min; value <= max; value++)
    {
      scores[value] = scorepixel(x, y, channel, value);
      if(value == min || scores[value] < ms) ms = scores[value];
    }

  // Cumulative distribution of the weights over the candidate window.
  double sum = 0;
  double distro[256];
  for(int value = min; value <= max; value++)
    {
      sum += exp((ms - scores[value]) / temp);
      distro[value] = sum;
    }
  double randv = (double)rand() / (double)RAND_MAX * sum;

  // First candidate whose cumulative weight reaches randv.
  int value;
  for(value = min; value <= max; value++)
    {
      if(randv <= distro[value])
        break;
    }
  // Guard against falling off the end (e.g. a NaN weight), which would
  // otherwise store max + 1.
  if(value > max) value = max;

  *cell = value;
}




// Ideas:
// 1. First tile the entire epitome with the best matches (maybe overlapping a tiny bit?).
// 2. Then do a second tiling with additional training patches, placed wherever they have a best match.
// 3. Decrease the randomness of patch locations.


// Entry point.
//
// Usage: <prog> width height pmax pmin pcov iterations file.bmp [file.bmp ...]
//   width, height  epitome size in pixels
//   pmax, pmin     largest / smallest patch edge length
//   pcov           multiplier on the number of patches per size (pmul)
//   iterations     number of annealing iterations
//   files          input images; each is read as a 54-byte header followed by
//                  bottom-up rows of 3 bytes per pixel padded to 4 bytes
//
// After every iteration the epitome is written next to the first input file:
// the last three characters of its name are replaced by "NN.bmp", where NN is
// the iteration number. The 54-byte header of the last input file is reused
// for the output with its size fields updated.
// NOTE(review): header fields are copied as raw ints, which presumably
// assumes a little-endian host — confirm.
//
// Returns 0 on success, 1 on a usage error or an unreadable input file.
int main(int argc, char **argv)
{
  unsigned char header[54];

  if(argc < 8)
    {
      std::cerr << "Params: width height pmax pmin pcov iterations files\n";
      return 1;
    }
  owidth = atoi(argv[1]);
  oheight = atoi(argv[2]);
  pmax = atoi(argv[3]);
  pmin = atoi(argv[4]);
  pmul = atof(argv[5]);
  iterations = atoi(argv[6]);
  if(owidth <= 0 || oheight <= 0 || pmin <= 0)
    {
      std::cerr << "width, height and pmin must be positive.\n";
      return 1;
    }

  int rseed = time(NULL);
  srand(rseed);

  nimages = argc - 7;
  images = new Image[nimages];

  // Table of squares for signed byte differences.
  for(int i = -255; i <= 255; i++)
    {
      sq[i] = i * i;
    }

  for(int i = 0; i < nimages; i++)
    {
      FILE *mf = fopen(argv[7 + i], "rb");
      if(!mf)
        {
          std::cerr << "File " << argv[7 + i] << " not found.\n";
          return 1;
        }
      if(fread(header, 54, 1, mf) != 1)
        {
          std::cerr << "File " << argv[7 + i] << " is too short to be a BMP.\n";
          fclose(mf);
          return 1;
        }

      // memcpy instead of a pointer cast: header + 18 is not int-aligned.
      int imgw, imgh;
      memcpy(&imgw, header + 18, sizeof(int));
      memcpy(&imgh, header + 22, sizeof(int));
      if(imgw <= 0 || imgh <= 0)
        {
          std::cerr << "File " << argv[7 + i] << " has unsupported dimensions.\n";
          fclose(mf);
          return 1;
        }

      // Rows in the file are padded to a multiple of 4 bytes.
      int extra = 4 - ((imgw * 3) % 4);
      if(extra == 4) extra = 0;
      const size_t rowbytes = (size_t)imgw * 3 + extra;

      // Zero-filled so a short file yields black pixels, not garbage.
      unsigned char *raw = new unsigned char[rowbytes * imgh]();
      if(fread(raw, rowbytes * imgh, 1, mf) != 1)
        {
          std::cerr << "Warning: " << argv[7 + i]
                    << " has less pixel data than its header declares.\n";
        }
      fclose(mf);

      unsigned char *image = new unsigned char[imgw * imgh * 3];
      images[i] = Image(image, imgw * 3, imgw, imgh);

      // Flip vertically while dropping the row padding.
      for(int y = 0; y < imgh; y++)
        {
          memcpy(images[i].data + y * imgw * 3, raw + (imgh - y - 1) * rowbytes, imgw * 3);
        }
      delete[] raw;
    }

  // Output rows are padded to a multiple of 4 bytes as well.
  int opad = 4 - ((owidth * 3) % 4);
  if(opad == 4) opad = 0;
  const int ostride = owidth * 3 + opad;
  unsigned char *data = new unsigned char[ostride * oheight]();

  // Value-initialised, i.e. the epitome starts out all zero. (The previous
  // bzero() sized itself from the never-assigned globals width/height and
  // therefore cleared nothing.)
  epitome = new unsigned int[owidth * oheight]();

  initpatches();

  tempstart = 15.;
  double tspeed = 10.;
  double t2speed = 12.;
  for(int iter = 0; iter < iterations; iter++)
    {
      // Both temperatures decay exponentially with the iteration number.
      temp = tempstart * exp(- iter / tspeed);
      temp2 = 40. * exp(- iter / t2speed);
      std::cerr << iter <<"\n";
      mapnearest();

      // One random single-channel update per epitome byte on average.
      const int updates = owidth * oheight * 3;
      for(int k = 0; k < updates; k++)
        {
          int px = rand() % owidth;
          int py = rand() % oheight;
          int c = rand() % 3;
          iterate(px, py, c);
        }

      // Output name: first input name with its last three characters
      // replaced by "NN.bmp".
      std::string on(argv[7]);
      if(on.size() >= 3) on.resize(on.size() - 3);
      char suffix[32];
      snprintf(suffix, sizeof(suffix), "%02d.bmp", iter);
      on += suffix;

      // Convert 4-byte epitome pixels to bottom-up 3-byte rows.
      for(int y = 0; y < oheight; y++)
        {
          const unsigned char *src = ((unsigned char *)epitome) + (oheight - y - 1) * owidth * 4;
          unsigned char *dst = data + y * ostride;
          for(int x = 0; x < owidth; x++)
            {
              dst[x * 3] = src[x * 4];
              dst[x * 3 + 1] = src[x * 4 + 1];
              dst[x * 3 + 2] = src[x * 4 + 2];
            }
        }

      // Patch the reused header: file size (offset 2), pixel data offset
      // (10), width (18), height (22) and pixel data size (34).
      int datasize = ostride * oheight;
      int filesize = 54 + datasize;
      int offbits = 54;
      memcpy(header + 2, &filesize, sizeof(int));
      memcpy(header + 10, &offbits, sizeof(int));
      memcpy(header + 18, &owidth, sizeof(int));
      memcpy(header + 22, &oheight, sizeof(int));
      memcpy(header + 34, &datasize, sizeof(int));

      FILE *of = fopen(on.c_str(), "wb");
      if(!of)
        {
          std::cerr << "Cannot write " << on << ".\n";
          continue;
        }
      fwrite(header, 54, 1, of);
      fwrite(data, ostride, oheight, of);
      fclose(of);
    }

  delete[] patches;
  delete[] bx;
  delete[] by;
  delete[] pixelstart;
  delete[] pixeltable;
  delete[] data;
  delete[] epitome;
  delete[] images;
  return 0;
}

