blob: d1edbdbda623fe19f3d9c3af3cb236b2e6e9bb18 [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
/*
* This file include code from rnnlmlib-0.4 under BSD new license.
* Copyright (c) 2010-2012 Tomas Mikolov
* Copyright (c) 2013 Cantab Research Ltd
* All rights reserved.
*/
//
// This code creates DataShard for RNNLM dataset.
// The RNNLM dataset could be downloaded at
// http://www.rnnlm.org/
//
// Usage:
// create_shard.bin -train [train_file] -valid [valid_file]
// -test [test_file] -class_size [# of classes]
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <fstream>
#include "singa/io/store.h"
#include "singa/utils/common.h"
#include "singa/proto/common.pb.h"
#include "./rnnlm.pb.h"
#define MAX_STRING 100
#define BUFFER_LEN 32
#define NL_STRING "</s>"
using std::string;
using std::max;
using std::min;
struct vocab_word {
int cn;
char word[MAX_STRING];
int class_index;
};
struct vocab_word *vocab;
int vocab_max_size;
int vocab_size;
int *vocab_hash;
int vocab_hash_size;
int debug_mode;
int old_classes;
int *class_start;
int *class_end;
int class_size;
char train_file[MAX_STRING];
char valid_file[MAX_STRING];
char test_file[MAX_STRING];
int valid_mode;
int test_mode;
unsigned int getWordHash(char *word) {
unsigned int hash, a;
hash = 0;
for (a = 0; a < strlen(word); a++) hash = hash * 237 + word[a];
hash = hash % vocab_hash_size;
return hash;
}
int searchVocab(char *word) {
int a;
unsigned int hash;
hash = getWordHash(word);
if (vocab_hash[hash] == -1) return -1;
if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash];
for (a = 0; a < vocab_size; a++) { // search in vocabulary
if (!strcmp(word, vocab[a].word)) {
vocab_hash[hash] = a;
return a;
}
}
return -1; // return OOV if not found
}
int addWordToVocab(char *word) {
unsigned int hash;
snprintf(vocab[vocab_size].word, strlen(word)+1, "%s", word);
vocab[vocab_size].cn = 0;
vocab_size++;
if (vocab_size + 2 >= vocab_max_size) { // reallocate memory if needed
vocab_max_size += 100;
vocab = (struct vocab_word *) realloc(
vocab,
vocab_max_size * sizeof(struct vocab_word));
}
hash = getWordHash(word);
vocab_hash[hash] = vocab_size - 1;
return vocab_size - 1;
}
void readWord(char *word, FILE *fin) {
int a = 0, ch;
while (!feof(fin)) {
ch = fgetc(fin);
if (ch == 13) continue;
if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
if (a > 0) {
if (ch == '\n') ungetc(ch, fin);
break;
}
if (ch == '\n') {
snprintf(word, strlen(NL_STRING) + 1,
"%s", const_cast<char *>(NL_STRING));
return;
} else {
continue;
}
}
word[a] = static_cast<char>(ch);
a++;
if (a >= MAX_STRING) {
// printf("Too long word found!\n"); //truncate too long words
a--;
}
}
word[a] = 0;
}
void sortVocab() {
int a, b, max;
vocab_word swap;
for (a = 1; a < vocab_size; a++) {
max = a;
for (b = a + 1; b < vocab_size; b++)
if (vocab[max].cn < vocab[b].cn) max = b;
swap = vocab[max];
vocab[max] = vocab[a];
vocab[a] = swap;
}
}
int learnVocabFromTrainFile() {
char word[MAX_STRING];
FILE *fin;
int a, i, train_wcn;
for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
fin = fopen(train_file, "rb");
vocab_size = 0;
addWordToVocab(const_cast<char *>(NL_STRING));
train_wcn = 0;
while (1) {
readWord(word, fin);
if (feof(fin)) break;
train_wcn++;
i = searchVocab(word);
if (i == -1) {
a = addWordToVocab(word);
vocab[a].cn = 1;
} else {
vocab[i].cn++;
}
}
sortVocab();
if (debug_mode > 0) {
printf("Vocab size: %d\n", vocab_size);
printf("Words in train file: %d\n", train_wcn);
}
fclose(fin);
return 0;
}
int splitClasses() {
double df, dd;
int i, a, b;
df = 0;
dd = 0;
a = 0;
b = 0;
class_start = reinterpret_cast<int *>(calloc(class_size, sizeof(int)));
memset(class_start, 0x7f, sizeof(int) * class_size);
class_end = reinterpret_cast<int *>(calloc(class_size, sizeof(int)));
memset(class_end, 0, sizeof(int) * class_size);
if (old_classes) { // old classes
for (i = 0; i < vocab_size; i++)
b += vocab[i].cn;
for (i = 0; i < vocab_size; i++) {
df += vocab[i].cn / static_cast<double>(b);
if (df > 1) df = 1;
if (df > (a + 1) / static_cast<double>(class_size)) {
vocab[i].class_index = a;
if (a < class_size - 1) a++;
} else {
vocab[i].class_index = a;
}
}
} else { // new classes
for (i = 0; i < vocab_size; i++)
b += vocab[i].cn;
for (i = 0; i < vocab_size; i++)
dd += sqrt(vocab[i].cn / static_cast<double>(b));
for (i = 0; i < vocab_size; i++) {
df += sqrt(vocab[i].cn / static_cast<double>(b)) / dd;
if (df > 1) df = 1;
if (df > (a + 1) / static_cast<double>(class_size)) {
vocab[i].class_index = a;
if (a < class_size - 1) a++;
} else {
vocab[i].class_index = a;
}
}
}
// after dividing classes, update class start and class end information
for (i = 0; i < vocab_size; i++) {
a = vocab[i].class_index;
class_start[a] = min(i, class_start[a]);
class_end[a] = max(i + 1, class_end[a]);
}
return 0;
}
int init_class() {
// debug_mode = 1;
debug_mode = 0;
vocab_max_size = 100; // largest length value for each word
vocab_size = 0;
vocab = (struct vocab_word *) calloc(vocab_max_size,
sizeof(struct vocab_word));
vocab_hash_size = 100000000;
vocab_hash = reinterpret_cast<int *>(calloc(vocab_hash_size, sizeof(int)));
old_classes = 1;
// read vocab
learnVocabFromTrainFile();
// split classes
splitClasses();
return 0;
}
int create_data(const char *input_file, const char *output) {
auto* store = singa::io::OpenStore("kvfile", output, singa::io::kCreate);
WordRecord wordRecord;
FILE *fin;
int a, i;
fin = fopen(input_file, "rb");
int wcnt = 0;
char key[BUFFER_LEN];
char wordstr[MAX_STRING];
string value;
while (1) {
readWord(wordstr, fin);
if (feof(fin)) break;
i = searchVocab(wordstr);
if (i == -1) {
if (debug_mode) printf("unknown word [%s] detected!", wordstr);
} else {
wordRecord.set_word(string(wordstr));
wordRecord.set_word_index(i);
int class_idx = vocab[i].class_index;
wordRecord.set_class_index(class_idx);
wordRecord.set_class_start(class_start[class_idx]);
wordRecord.set_class_end(class_end[class_idx]);
int length = snprintf(key, BUFFER_LEN, "%05d", wcnt++);
wordRecord.SerializeToString(&value);
store->Write(string(key, length), value);
}
}
fclose(fin);
store->Flush();
delete store;
return 0;
}
int argPos(char *str, int argc, char **argv) {
int a;
for (a = 1; a < argc; a++)
if (!strcmp(str, argv[a]))
return a;
return -1;
}
int main(int argc, char **argv) {
int i;
FILE *f;
// set debug mode
i = argPos(const_cast<char *>("-debug"), argc, argv);
if (i > 0) {
debug_mode = 1;
if (debug_mode > 0)
printf("debug mode: %d\n", debug_mode);
}
// search for train file
i = argPos(const_cast<char *>("-train"), argc, argv);
if (i > 0) {
if (i + 1 == argc) {
printf("ERROR: training data file not specified!\n");
return 0;
}
snprintf(train_file, strlen(argv[i + 1])+1, "%s", argv[i + 1]);
if (debug_mode > 0)
printf("train file: %s\n", train_file);
f = fopen(train_file, "rb");
if (f == NULL) {
printf("ERROR: training data file not found!\n");
return 0;
}
fclose(f);
} else {
printf("ERROR: training data must be set.\n");
}
// search for valid file
i = argPos(const_cast<char *>("-valid"), argc, argv);
if (i > 0) {
if (i + 1 == argc) {
printf("ERROR: validating data file not specified!\n");
return 0;
}
snprintf(valid_file, strlen(argv[i + 1])+1, "%s", argv[i + 1]);
if (debug_mode > 0)
printf("valid file: %s\n", valid_file);
f = fopen(valid_file, "rb");
if (f == NULL) {
printf("ERROR: validating data file not found!\n");
return 0;
}
fclose(f);
valid_mode = 1;
}
// search for test file
i = argPos(const_cast<char *>("-test"), argc, argv);
if (i > 0) {
if (i + 1 == argc) {
printf("ERROR: testing data file not specified!\n");
return 0;
}
snprintf(test_file, strlen(argv[i + 1])+1, "%s", argv[i + 1]);
if (debug_mode > 0)
printf("test file: %s\n", test_file);
f = fopen(test_file, "rb");
if (f == NULL) {
printf("ERROR: testing data file not found!\n");
return 0;
}
fclose(f);
test_mode = 1;
}
// search for class size
i = argPos(const_cast<char *>("-class_size"), argc, argv);
if (i > 0) {
if (i + 1 == argc) {
printf("ERROR: class size not specified!\n");
return 0;
}
class_size = atoi(argv[i + 1]);
if (debug_mode > 0)
printf("class size: %d\n", class_size);
}
if (class_size <= 0) {
printf("ERROR: no or invalid class size received!\n");
return 0;
}
init_class();
create_data(train_file, "train_data.bin");
if (valid_mode) create_data(valid_file, "valid_data.bin");
if (test_mode) create_data(test_file, "test_data.bin");
return 0;
}