﻿#include "cLearn.h"
#include <algorithm>
#include <random>

#define TRAININGSETIDX		0
#define VALIDATIONSETIDX	1
#define TESTSETIDX			2
#define FULLSETIDX			3

// Opens the binary block file and prepares a reader positioned at its first entry.
// The file starts with an INT64 entry count followed by INT64 entries.
// If the file cannot be opened, the total entry count stays 0, so Read() yields nothing.
BlockSpliter::BlockSpliter(STRING strBlockFile)
{
	// All four counters (training, validation, test, full) start at zero;
	// Save()/Load() treat Size as four INT64 values.
	memset(Size, 0, sizeof(INT64) * 4);
	fi.open(strBlockFile, ios::in | ios::binary);
	if (fi.is_open())
		Size[FULLSETIDX] = ReadINT64(&fi);
	buffer = new INT64[0x1000000];
	// Give bufSize a defined value. SetBufferSize() caps it at 0x1000000,
	// which is exactly the capacity allocated for `buffer` above.
	// NOTE(review): if the header already gives bufSize an in-class default, reconcile the two.
	SetBufferSize(0x1000000);
	ResetPointer(FULLSET);
}

// Releases the entry buffer and closes the block file.
// NOTE(review): `buffer` is a raw owning pointer. Unless the header deletes or
// defines the copy operations, copying a BlockSpliter would free it twice.
BlockSpliter::~BlockSpliter()
{
	delete[] buffer;
	fi.close();
}

// Maximum number of block-file entries consumed by a single Read() call.
INT32 BlockSpliter::GetBufferSize()
{
	return this->bufSize;
}

// Sets how many entries Read() pulls per call, clamped to [0x8000, 0x1000000].
// The upper bound equals the number of INT64 slots allocated for `buffer`.
VOID BlockSpliter::SetBufferSize(INT64 bufSize)
{
	INT64 clamped = bufSize;
	if (clamped < 0x8000)
		clamped = 0x8000;
	else if (clamped > 0x1000000)
		clamped = 0x1000000;
	this->bufSize = (INT32)clamped;
}

// Rewinds the reader to the first entry and selects which subsets (mode flags)
// subsequent Read()/Contains()/GetEntryCount() calls refer to.
VOID BlockSpliter::ResetPointer(INT32 mode)
{
	// A previous short read leaves failbit/eofbit set, and seekg() does nothing on
	// a failed stream. Clear the state so the rewind actually takes effect.
	fi.clear();
	fi.seekg(sizeof(INT64), fi.beg); // skip the leading INT64 entry count
	Mode = mode;
	readSize = 0;
}

// Number of block-file entries covered by the subsets selected by the current Mode.
// FULLSET reports the file's total entry count directly.
INT64 BlockSpliter::GetEntryCount()
{
	if (Mode == FULLSET)
		return Size[FULLSETIDX];

	INT64 total = 0;
	total += (Mode & TRAININGSET) ? Size[TRAININGSETIDX] : 0;
	total += (Mode & VALIDATIONSET) ? Size[VALIDATIONSETIDX] : 0;
	total += (Mode & TESTSET) ? Size[TESTSETIDX] : 0;
	return total;
}

// Returns 1 if srcID belongs to any subset selected by Mode, otherwise 0.
// Mode == -1 means "use the reader's current Mode"; FULLSET accepts every source.
INT32 BlockSpliter::Contains(INT32 srcID, INT32 Mode)
{
	const INT32 mode = (Mode == -1) ? this->Mode : Mode;
	if (mode == FULLSET)
		return 1;

	// Flag for each subset, indexed by TRAININGSETIDX / VALIDATIONSETIDX / TESTSETIDX.
	const INT32 flags[3] = { TRAININGSET, VALIDATIONSET, TESTSET };
	for (INT32 idx = 0; idx < 3; idx++)
	{
		if (!(mode & flags[idx]))
			continue;
		const SET<INT32> &members = splitedSourceInstances[idx];
		if (members.find(srcID) != members.end())
			return 1;
	}
	return 0;
}

// Index (TRAININGSETIDX, VALIDATIONSETIDX or TESTSETIDX) of the first subset holding srcID,
// or -1 if the source is in none of them.
INT32 BlockSpliter::ContainedBy(INT32 srcID)
{
	INT32 idx = 0;
	while (idx < 3)
	{
		const SET<INT32> &members = splitedSourceInstances[idx];
		if (members.find(srcID) != members.end())
			return idx;
		++idx;
	}
	return -1;
}

// Reads the next chunk of at most bufSize entries from the block file and returns
// those whose source (low 32 bits) is in the subsets selected by the current Mode.
// Each result pair is (low 32 bits, high 32 bits) of an entry.
// An empty result does not necessarily mean end of file: every entry in the
// chunk may have been filtered out.
VECTOR<pair<INT32, INT32>> BlockSpliter::Read()
{
	VECTOR<pair<INT32, INT32>> res;
	res.reserve(bufSize);
	INT32 toRead = (INT32)MIN(Size[FULLSETIDX] - readSize, bufSize);
	readSize += toRead;
	fi.read((INT8*)buffer, toRead * sizeof(INT64));
	// Only trust what the stream actually delivered. After a short read (truncated
	// file), the tail of `buffer` still holds data from an earlier chunk.
	INT32 got = (INT32)(fi.gcount() / sizeof(INT64));
	for (INT32 i = 0; i < got; i++)
	{
		INT32 src = (INT32)(buffer[i] & 0xFFFFFFFF);
		if (Contains(src))
			res.push_back(pair<INT32, INT32>(src, (INT32)(buffer[i] >> 32)));
	}
	return res;
}

// Writes the split to a binary file:
//   4 x INT64 Size counters, then for each of the three subsets
//   an INT32 member count followed by that many INT32 source IDs.
VOID BlockSpliter::Save(STRING strFile)
{
	ofstream out(strFile, ios::out | ios::binary);
	out.write((INT8*)Size, 4 * sizeof(INT64));
	for (INT32 setIdx = 0; setIdx < 3; setIdx++)
	{
		const SET<INT32> &members = splitedSourceInstances[setIdx];
		INT32 count = members.size();
		out.write((INT8*)&count, sizeof(INT32));
		for (INT32 id : members)
			out.write((INT8*)&id, sizeof(INT32));
	}
	out.close();
}

// Restores a split written by Save(): the four Size counters, then the
// three subsets (INT32 count + INT32 IDs each).
// NOTE(review): relies on ReadINT64/ReadINT32 advancing `cursor` past the bytes
// they consume, and on ReadAllBytes returning a new[]-allocated buffer — confirm
// against their definitions.
VOID BlockSpliter::Load(STRING strFile)
{
	INT8 *raw = ReadAllBytes(strFile);
	INT8 *cursor = raw;
	for (INT32 k = 0; k < 4; k++)
		Size[k] = ReadINT64(cursor);
	for (INT32 k = 0; k < 3; k++)
	{
		INT32 count = ReadINT32(cursor);
		INT32 *first = (INT32*)cursor;
		splitedSourceInstances[k] = SET<INT32>(first, first + count);
		cursor += count * sizeof(INT32);
	}
	delete[] raw;
}

// Prepares splitting.
// - Sources appearing in refLinks (low 32 bits) are "positive".
// - Every other source seen in the block file is "negative".
// - Per-source entry counts are tallied from the block file into entryCount.
// The split fractions and fold count are stored for Split().
// Folds == -1 selects random percentage splits; otherwise the instance lists are
// shuffled once here for k-fold use.
// Leaves the reader at the end of the file; call ResetPointer() before Read().
VOID BlockSpliter::BeginSplit(SET<INT64> refLinks, FLOAT trainingSplit, FLOAT validationSplit, FLOAT testSplit, INT32 Folds)
{
	//from refLinks
	SET<INT32> positive, negative;
	for (auto ref : refLinks)
		positive.insert((INT32)(ref & 0xFFFFFFFF));

	// Rebuild counts from scratch. Without this, calling BeginSplit again
	// would add onto the previous tallies and double every Size.
	entryCount.clear();

	//from block file
	ResetPointer(FULLSET);
	while (readSize < Size[FULLSETIDX])
	{
		INT32 toRead = (INT32)MIN(Size[FULLSETIDX] - readSize, bufSize);
		readSize += toRead;
		fi.read((INT8*)buffer, toRead * sizeof(INT64));
		// Count only entries actually delivered. On a short read, the rest of
		// `buffer` would be stale data from the previous chunk.
		INT32 got = (INT32)(fi.gcount() / sizeof(INT64));
		for (INT32 i = 0; i < got; i++)
		{
			INT32 src = (INT32)(buffer[i] & 0xFFFFFFFF);
			if (positive.find(src) == positive.end())
				negative.insert(src);
			entryCount[src]++;
		}
		if (got < toRead)
			break; // truncated file: nothing more can be read
	}
	positiveSourceInstances.assign(positive.begin(), positive.end());
	negativeSourceInstances.assign(negative.begin(), negative.end());
	this->trainingSplit = trainingSplit;
	this->validationSplit = validationSplit;
	this->testSplit = testSplit;
	this->Folds = Folds;
	currentFold = 0;
	srand(time(0) + TickCount());
	if (Folds != -1)
	{
		Shuffle(positiveSourceInstances);
		Shuffle(negativeSourceInstances);
	}
}


// Appends src[nBegin, nEnd) to des, preserving order. at() gives the same
// std::out_of_range behaviour for bad bounds as BlockSpliter::Distribute.
static VOID AppendSourceRange(VECTOR<INT32> &des, VECTOR<INT32> &src, INT32 nBegin, INT32 nEnd)
{
	for (INT32 i = nBegin; i < nEnd; i++)
		des.push_back(src.at(i));
}

// Assigns source instances to the training / validation / test subsets.
// Recomputes Size[TRAININGSETIDX..TESTSETIDX] as the number of block-file entries
// each subset covers; Size[FULLSETIDX] is not modified.
//
// Folds == -1 : both instance lists are reshuffled, and the first testSplit
//               fraction of each goes to the test subset.
// Folds  >  0 : k-fold mode. Fold (currentFold % Folds) is the test subset, and
//               currentFold advances by one per call. The last fold also absorbs
//               the remainder of the integer division.
// In both modes the remaining (non-test) instances are cut at the trainingSplit
// fraction: the first part goes to training, the rest to validation.
// Positive and negative instances are cut separately, so each subset keeps
// their proportions.
// NOTE(review): validationSplit (stored by BeginSplit) is not consulted here;
// validation receives everything outside training and test. Confirm intended.
// Folds == 0 is not supported (division/modulo by zero).
VOID BlockSpliter::Split()
{
	for (INT32 i = 0; i < 3; i++)
	{
		splitedSourceInstances[i].clear();
		Size[i] = 0;
	}

	INT32 posSize = (INT32)positiveSourceInstances.size();
	INT32 negSize = (INT32)negativeSourceInstances.size();
	// Non-test instances, kept in shuffled order. Collecting them in a SET re-orders
	// them (by ID for an ordered set, by bucket for a hashed one). The
	// training/validation cut below would then ignore the shuffle entirely.
	VECTOR<INT32> pos, neg;
	pos.reserve(posSize);
	neg.reserve(negSize);
	if (Folds == -1)
	{
		Shuffle(positiveSourceInstances);
		Shuffle(negativeSourceInstances);
		INT32 posSeg = (INT32)(testSplit * posSize);
		INT32 negSeg = (INT32)(testSplit * negSize);
		Size[TESTSETIDX] += Distribute(&splitedSourceInstances[TESTSETIDX], &positiveSourceInstances, 0, posSeg);
		Size[TESTSETIDX] += Distribute(&splitedSourceInstances[TESTSETIDX], &negativeSourceInstances, 0, negSeg);
		AppendSourceRange(pos, positiveSourceInstances, posSeg, posSize);
		AppendSourceRange(neg, negativeSourceInstances, negSeg, negSize);
	}
	else
	{
		// Wrap around so calls beyond Folds cycle through the folds again, instead
		// of indexing past the end of the instance lists.
		INT32 fold = currentFold % Folds;
		INT32 posSeg = posSize / Folds;
		INT32 negSeg = negSize / Folds;
		INT32 posBeg = fold * posSeg;
		INT32 negBeg = fold * negSeg;
		if (fold == Folds - 1)
		{
			posSeg = posSize - posBeg;
			negSeg = negSize - negBeg;
		}
		Size[TESTSETIDX] += Distribute(&splitedSourceInstances[TESTSETIDX], &positiveSourceInstances, posBeg, posBeg + posSeg);
		Size[TESTSETIDX] += Distribute(&splitedSourceInstances[TESTSETIDX], &negativeSourceInstances, negBeg, negBeg + negSeg);
		AppendSourceRange(pos, positiveSourceInstances, 0, posBeg);
		AppendSourceRange(pos, positiveSourceInstances, posBeg + posSeg, posSize);
		AppendSourceRange(neg, negativeSourceInstances, 0, negBeg);
		AppendSourceRange(neg, negativeSourceInstances, negBeg + negSeg, negSize);
		currentFold++;
	}

	posSize = (INT32)pos.size();
	negSize = (INT32)neg.size();
	INT32 posSeg = (INT32)(trainingSplit * posSize);
	INT32 negSeg = (INT32)(trainingSplit * negSize);
	Size[TRAININGSETIDX] += Distribute(&splitedSourceInstances[TRAININGSETIDX], &pos, 0, posSeg);
	Size[TRAININGSETIDX] += Distribute(&splitedSourceInstances[TRAININGSETIDX], &neg, 0, negSeg);
	Size[VALIDATIONSETIDX] += Distribute(&splitedSourceInstances[VALIDATIONSETIDX], &pos, posSeg, posSize);
	Size[VALIDATIONSETIDX] += Distribute(&splitedSourceInstances[VALIDATIONSETIDX], &neg, negSeg, negSize);
}

// Uniformly shuffles `value` in place.
// A hand-rolled Fisher–Yates driven by `rand() % (i + 1)` is capped by RAND_MAX,
// which is only 32767 on MSVC. For larger vectors every swap partner was drawn
// from the first RAND_MAX + 1 slots, giving a heavily biased permutation.
// The engine is seeded from rand(), so the srand() call in BeginSplit still
// determines the outcome.
VOID BlockSpliter::Shuffle(VECTOR<INT32> &value)
{
	std::seed_seq seeds{ rand(), rand(), rand(), rand() };
	std::mt19937 rng(seeds);
	std::shuffle(value.begin(), value.end(), rng);
}

// Inserts the source IDs src[nBegin, nEnd) into *des.
// Returns the sum of their entryCount values, i.e. how many block-file entries
// were added. Out-of-range indices throw std::out_of_range via at().
INT64 BlockSpliter::Distribute(SET<INT32> *des, VECTOR<INT32> *src, INT32 nBegin, INT32 nEnd)
{
	INT64 total = 0;
	INT32 i = nBegin;
	while (i < nEnd)
	{
		const INT32 id = src->at(i);
		des->insert(id);
		total += entryCount[id];
		++i;
	}
	return total;
}
