先贴上这两天刚出炉的C++代码。(利用 STL 偷了不少功夫,代码待优化)
Head.h
#ifndef HEAD_H
#define HEAD_H #include "D:\\LiYangGuang\\VSPRO\\MYLSH\\HashTable.h" #include <iostream>
#include <fstream>
#include <time.h>
#include <cstdlib>
#include <vector>
#include <map>
#include <set>
#include <string> using namespace std; void loadData(bool (*data)[], int n, char *filename);
void createTable(HashTable HTSet[], bool data[][], bool extDat[][n][k] );
void insert(HT HTSet[], bool (*extDat)[n][k]);
void standHash(HT HTSet[]);
void search(vector<int>& record, bool query[], HT HTSet[]);
/*int getPosition(int V[], std::string s, int N);*/ #endif
HashTable.h
#include <string>
// Shared compile-time sizes and data structures of the LSH index.
//   k : number of sampled bit positions per table (size of R and RNum, and
//       the length of every bucket key built in insertData.cpp)
//   l : number of hash tables (main.cpp declares HTSet[l])
//   n : number of data rows (main.cpp declares data[n][128])
//   M : number of slots in Hash2 and the modulus used by getPosition()
// NOTE(review): the initializers of k, l and n are missing in this listing;
// they must be filled in before this header can compile.
//
// bucket: all data rows that share one k-character key.
#include <vector> enum{ k = , l = , n = , M = n}; typedef struct
{
// Key string; insertData.cpp builds it from k bits as '0'/'1' characters.
std::string key;
std::vector<int> elem; // indices of the data rows stored in this bucket
// INT: one node of a singly linked chain hanging off a Hash2 slot.
} bucket; struct INT
{
// false until standHash() stores a value in this node
bool used;
// value stored by standHash(): an index into HashTable::BukSet
int val;
// next node in the same slot's chain, or NULL at the end
struct INT * next;
// Starts out as an empty, unlinked node.
INT() : used(false), val(), next(NULL){}
// HashTable (alias HT): one LSH table.
}; typedef struct HashTable
{
int R[k]; // k random dimensions (createTable draws each as rand() % 128)
int RNum[k]; // random numbers less than M (createTable draws each as rand() % M)
//string DC; // the contents of k dimensions
// Buckets of this table, in the order insert() created them.
std::vector<bucket> BukSet;
// Slot array indexed by getPosition(); each slot heads a chain of INT nodes.
INT Hash2[M];
} HT;
getPosition.h
#include <string>
/**
 * Maps a '0'/'1' key string to a slot index of the second-level hash array.
 *
 * The result is the sum of V[col] over every position whose character is '1',
 * reduced modulo M after each step.
 *
 * @param V  per-position multipliers; at least N entries are read
 * @param s  key string made of '0' and '1' characters
 * @param N  number of key positions to hash (the callers pass k)
 * @return   the accumulated value modulo M
 */
inline int getPosition(const int V[], const std::string& s, int N)
{
    // Honour the caller-supplied length instead of the global k, and never
    // index past the end of a key that is shorter than requested.
    const int keyLen = static_cast<int>(s.size());
    const int count = (N < keyLen) ? N : keyLen;

    int position = 0;
    for (int col = 0; col < count; ++col)
    {
        position += V[col] * (s[col] - '0');
        position %= M;
    }
    return position;
}
computeDistance.h
/**
 * Hamming distance between two bit vectors.
 *
 * @param v1  first vector, at least N entries
 * @param v2  second vector, at least N entries
 * @param N   number of entries to compare
 * @return    how many of the first N positions differ
 */
inline int distance(bool v1[], bool v2[], int N)
{
    int mismatches = 0;
    int idx = 0;
    while (idx < N)
    {
        if (v1[idx] != v2[idx])
            ++mismatches;
        ++idx;
    }
    return mismatches;
}
main.cpp
#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\computeDistance.h"
using namespace std;
// length of sub hashtable, as well the number of elements.
const int MAX_Q = 1000; HT HTSet[l]; bool data[n][128];
bool extDat[l][n][k]; bool query[MAX_Q][128]; // set the query item to 1000. int main(int argc, char *argv)
{
/************************************************************************/
/* Firstly, create the HashTables */
/************************************************************************/
char *filename = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
loadData(data, n, filename);
createTable(HTSet, data, extDat);
insert(HTSet,extDat);
standHash(HTSet); /************************************************************************/
/* Secondly, start the LSH search */
/************************************************************************/ char *queryFile = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
loadData(query, MAX_Q, queryFile);
clock_t time0 = clock();
for(int qId = 0; qId < MAX_Q; ++qId)
{
vector<int> record;
clock_t timeA = clock();
search(record, query[qId], HTSet);
set<int> Dis;
for(size_t i = 0; i < record.size(); ++i)
Dis.insert(distance(data[record[i]], query[qId]));
clock_t timeB = clock();
cout << "第 " << qId + 1 << " 次查询时间:" << timeB - timeA << endl;
}
clock_t time1 = clock();
cout << "总查询时间:" << time1 - time0 << endl; return 0; }
loadData.cpp
#include <string>
#include <fstream>

/**
 * Reads n rows of bits from a text file into data.
 *
 * Each line of the file supplies one row; character c of the line becomes
 * ((c - '0') & 1), so '1' maps to true and '0' to false.
 *
 * Positions that the file does not provide are set to false instead of being
 * read out of bounds: a line shorter than 128 characters, a file with fewer
 * than n lines, or a file that cannot be opened.
 *
 * @param data      destination, at least n rows of 128 entries
 * @param n         number of rows to fill
 * @param filename  path of the text file to read
 */
void loadData(bool (*data)[128], int n, char* filename)
{
    std::ifstream ifs(filename, std::ios::in);
    std::string line;
    for (int row = 0; row < n; ++row)
    {
        // A failed read may leave the previous content in "line"; clear it so
        // the row is zero-filled rather than duplicated.
        if (!std::getline(ifs, line))
            line.clear();

        const std::string::size_type available = line.size();
        for (int col = 0; col < 128; ++col)
        {
            const bool present = static_cast<std::string::size_type>(col) < available;
            data[row][col] = present && (((line[col] - '0') & 1) != 0);
        }
    }
}
creatTable.cpp
#include "HashTable.h"
#include <ctime> void createTable(HT HTSet[], bool data[][128], bool extDat[][n][k] )
{
srand((unsigned)time(NULL));
for(int tableNum = 0; tableNum < l; ++tableNum)
{ /* creat the ith Table;*/ for(int randNum = 0; randNum < k; ++randNum)
{
HTSet[tableNum].R[randNum] = rand() % 128;
HTSet[tableNum].RNum[randNum] = rand() % M; for(int item = 0; item < n; ++item)
{
extDat[tableNum][item][randNum] =
data[item][HTSet[tableNum].R[randNum]];
}
}
}
}
insertData.cpp
#include "HashTable.h"
#include <iostream>
#include <map>
using namespace std;

// Maps a bucket key to that bucket's index in the BukSet of the table that is
// currently being filled; insert() clears it after each table.
map<string, int> deRepeat;

/**
 * Compares the first n entries of two bit vectors.
 *
 * @return true when all n positions are equal, false at the first mismatch
 */
bool equal(bool V[], bool V2[], int n)
{
    // The index has to advance; without it the loop never terminated whenever
    // the first entries matched.
    for (int i = 0; i < n; ++i)
    {
        if (V[i] != V2[i])
            return false;
    }
    return true;
}

/**
 * Appends n bits to s as '0'/'1' characters and returns the result.
 *
 * @param v  bits to encode, at least n entries
 * @param n  number of bits to append
 * @param s  prefix the characters are appended to (taken by value)
 */
string itoa(bool *v, int n, string s)
{
    for (int i = 0; i < n; ++i)
        s.push_back(static_cast<char>(v[i] + '0'));
    return s;
}

/**
 * Groups the rows of every table into buckets keyed by their k sampled bits.
 *
 * For each table, row 0 opens the first bucket; every later row is appended
 * to the bucket that already has its key, or opens a new bucket. Progress is
 * printed to standard output for every row after the first.
 *
 * @param HTSet   the l tables whose BukSet vectors receive the buckets
 * @param extDat  per table, the k sampled bits of each of the n rows
 */
void insert(HT HTSet[], bool (*extDat)[n][k])
{
    for (int t = 0; t < l; ++t) /* t: table */
    {
        int bktNum = 0;
        bucket bkt;
        bkt.key = itoa(extDat[t][0], k, string(""));
        bkt.elem.push_back(0);
        HTSet[t].BukSet.push_back(bkt);
        deRepeat.insert(make_pair(bkt.key, bktNum++)); // 0 is this bucket's position

        for (int item = 1; item < n; ++item)
        {
            cout << item << endl;
            const string key = itoa(extDat[t][item], k, string(""));

            // Look the key up once and reuse the iterator.
            map<string, int>::iterator it = deRepeat.find(key);
            if (it != deRepeat.end())
            {
                HTSet[t].BukSet[it->second].elem.push_back(item);
                cout << "exist" << endl;
            }
            else
            {
                bucket bkt2;
                bkt2.key = key;
                bkt2.elem.push_back(item);
                HTSet[t].BukSet.push_back(bkt2);
                deRepeat.insert(make_pair(bkt2.key, bktNum++));
                cout << "creat" << endl;
            }
        }
        deRepeat.clear();
    }
}
standHash.cpp
#include "HashTable.h"
#include <iostream>
#include "getPosition.h"

/**
 * Registers every bucket of every table in that table's Hash2 slot array.
 *
 * The slot of a bucket is getPosition() of its key. An empty slot head takes
 * the bucket index directly; otherwise a newly allocated node carrying the
 * index is linked after the last node of the slot's chain. A progress line
 * is printed after each table.
 *
 * @param HTSet  the l tables to process; their BukSet must already be filled
 */
void standHash(HT HTSet[])
{
    for (int t = 0; t < l; ++t)
    {
        HT &table = HTSet[t];
        const int bucketCount = table.BukSet.size();

        int b = 0;
        while (b < bucketCount)
        {
            const int slot = getPosition(table.RNum, table.BukSet[b].key, k);

            // Walk to the last used node of the chain (or stop at an unused one).
            INT *node = &table.Hash2[slot];
            for (; node->used && node->next != NULL; node = node->next)
            {
            }

            if (!node->used)
            {
                node->val = b;
                node->used = true;
            }
            else
            {
                node->next = new INT;
                INT *fresh = node->next;
                fresh->val = b;
                fresh->used = true;
            }
            ++b;
        }
        std::cout << "the " << t << "th HashTable has been finished." << std::endl;
    }
}
search.cpp
#include "HashTable.h"
#include "getPosition.h"
#include <vector>
using namespace std; void search(vector<int>& record, bool query[128], HT HTSet[])
{
for(int t = 0; t < l; ++t)
{
string temKey;
int temPos = 0;
for(int c = 0; c < k; ++c)
temKey.push_back(query[HTSet[t].R[c]] + '0');
temPos = getPosition(HTSet[t].RNum, temKey, k);
vector<int> bktId;
INT *p = &HTSet[t].Hash2[temPos];
while(p != NULL && p->used)
{
bktId.push_back(p->val);
p = p->next;
}
for(size_t i = 0; i < bktId.size(); ++i)
{
bucket temB = HTSet[t].BukSet[bktId[i]];
if(temKey == temB.key)
{
for(size_t j = 0; j < temB.elem.size(); ++j)
record.push_back(temB.elem[j]);
}
}
}
}
稍后总结。
代码调整:
main.cpp
#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\MYLSH\\computeDistance.h"
using namespace std;
#pragma warning(disable: 4996)
// length of sub hashtable, as well the number of elements.
const int MAX_Q = 1000; HT HTSet[l]; bool data[n][128];
bool extDat[l][n][k]; bool query[MAX_Q][128]; // set the query item to 1000. void getFileName(int v, char *FileName)
{
itoa(v, FileName, 10);
strcat(FileName, ".txt");
} int main(int argc, char *argv)
{
/************************************************************************/
/* Firstly, create the HashTables */
/************************************************************************/
char *filename = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
loadData(data, n, filename);
createTable(HTSet, data, extDat);
insert(HTSet,extDat);
standHash(HTSet); char *queryFile = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
loadData(query, MAX_Q, queryFile);
/************************************************************************/
/* Secondly, start the linear Search */
// /************************************************************************/
//
// vector<RECORD> record2;
// clock_t LineTime1 = clock();
// for(int qId = 0; qId < MAX_Q; ++qId)
// {
// for(int i = 0; i < n; ++i)
// {
// RECORD tem;
// tem.Id = i;
// tem.Dis = distance(data[i], query[qId]);
// record2.push_back(tem);
// }
// record2.clear();
// }
// clock_t LineTime2 = clock();
// float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
// cout << "全部线性查询时间:" << LineTime << " s," << " 合"
// << LineTime / 60 << " minutes."<< endl;
//
// /************************************************************************/
// /* Thirdly, start the LSH search */
// /************************************************************************/
//
// clock_t time0 = clock();
// ofstream ofs;
// char outFileName[10] = { '\0'};
// int K = 1; /// define KNN
// getFileName(K, outFileName);
// ofs.out(outFileName);
//
// for(int qId = 0; qId < MAX_Q; ++qId)
// {
// vector<RECORD> record;
// clock_t timeA = clock();
// search(record, query[qId], HTSet, data);
// if(getkNN(record,K))
// clock_t timeB = clock();
// record.clear();
// cout << "第 " << qId + 1 << " 次查询时间:" <<
// (float)(timeB - timeA) / CLOCKS_PER_SEC << " s" << endl;
// }
// clock_t time1 = clock();
// cout << "总查询时间:" << (float)(time1 - time0) / CLOCKS_PER_SEC
// << " s." << endl;
/************************************************************************/
/* */
/************************************************************************/
ofstream ofs;
char outFileName[10] = { '\0'};
int K = 1; /// define KNN
getFileName(K, outFileName);
ofs.open(outFileName, ios::out);
//ofs.precision(3);
float TotalLinearTime, TotalLSHTime;
TotalLinearTime = TotalLSHTime = 0; float TotalError = 0;
int TotalMiss = 0; vector<RECORD> record2;
for(int qId = 0; qId < MAX_Q; ++qId)
{
cout << "第 " << qId << " 次查询" << endl;
clock_t LineTime1 = clock();
for(int i = 0; i < n; ++i)
{
RECORD tem;
tem.Id = i;
tem.Dis = computeDistance(data[i], query[qId], 128);
record2.push_back(tem);
}
getkNN(record2); // 利用其对距离排序
clock_t LineTime2 = clock();
float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
TotalLinearTime += LineTime; /************************************************************************/
/* Thirdly, start the LSH search */
/************************************************************************/ vector<RECORD> record;
clock_t timeA = clock();
search(record, query[qId], HTSet, data);
if(!getkNN(record, K))
{
float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
TotalLSHTime += queryTime;
ofs << "Miss\t" << "LSH Time: " << queryTime
<< "s\tLinear time: " << LineTime << 's' << endl;
TotalMiss += 1;
}
else{
float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
TotalLSHTime += queryTime;
float error = 0;
if(record[K-1].Dis == 0)
error = 1;
else
error = (float)record2[K-1].Dis / record[K-1].Dis;
ofs << "Error: " << error << "\tLSH Time: "
<< queryTime << "s\tLinear time: " << LineTime << 's' << endl;
TotalError += error; }
record.clear();
record2.clear();
}
ofs << "Average errror: " << TotalError / 817 << endl;//recitfy
ofs << "Miss ratio: " << TotalMiss / MAX_Q << endl;
ofs << "Total query time: " << "LSH, " << TotalLSHTime / 3600 << " h; "
<< "Linear, " << TotalLinearTime / 3600 << " h." << endl;
ofs.close(); return 0; }
computeDistance.h
/**
 * Counts the positions at which two bit vectors differ (Hamming distance).
 *
 * @param v1  first vector, at least N entries
 * @param v2  second vector, at least N entries
 * @param N   number of leading entries to compare
 * @return    number of differing positions among the first N
 */
inline int computeDistance(bool v1[], bool v2[], int N)
{
    int total = 0;
    // v1 and v2 are local pointer copies, so advancing them is safe.
    for (int remaining = N; remaining > 0; --remaining, ++v1, ++v2)
        total += (*v1 == *v2) ? 0 : 1;
    return total;
}
Search.cpp
#include "HashTable.h"
#include "getPosition.h"
#include "computeDistance.h"
#include <vector>
using namespace std; /*** 加入 data 项是为了计算距离 ***/
void search(vector<RECORD>& record, bool query[128], HT HTSet[], bool data[][128])
{
for(int t = 0; t < l; ++t)
{
string temKey;
int temPos = 0;
for(int c = 0; c < k; ++c)
temKey.push_back(query[HTSet[t].R[c]] + '0');
temPos = getPosition(HTSet[t].RNum, temKey, k);
vector<int> bktId;
INT *p = &HTSet[t].Hash2[temPos];
while(p != NULL && p->used)
{
bktId.push_back(p->val);
p = p->next;
}
for(size_t i = 0; i < bktId.size(); ++i)
{
bucket temB = HTSet[t].BukSet[bktId[i]];
if(temKey == temB.key)
{
for(size_t j = 0; j < temB.elem.size(); ++j)
{
RECORD temp;
temp.Id = temB.elem[j];
temp.Dis = computeDistance(data[temp.Id], query, 128);
record.push_back(temp);
} }
}
}
}
相关截图: