Between Tuesday August 20th, 5:00pm and Thursday August 22nd, 8:00am git.uwaterloo.ca will be down for an upgrade to version 10.8.7.

e2bcaba0 by Achyudh Ram

Add IMDB dataset

0 parents
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
import json
from sklearn.model_selection import train_test_split
import os
def get_review_text(review):
"""
Get a string containing the title and body of the review
:param article: A IMDB review dict
:return: String containing the review title and body
"""
title, body = "", ""
if review['title'] is not None:
title = ' '.join(review['title'].split()) + " "
if review['review'] is not None:
body = ' '.join(review['review'].split())
return title + body
def get_binary_label(label):
category_label = [0 for x in range(10)]
category_label[label - 1]
return ''.join(map(str, category_label))
def parse_documents():
"""
Extract the reviews from IMDB dataset and create train/dev/test splits
:return: Three lists containing the train, dev and test splits along with the labels
"""
with open(os.path.join("data", "reviews.json"), 'r') as json_file:
reviews = list()
for review in json.load(json_file):
reviews.append((get_binary_label(review['rating']), get_review_text(review)))
train_documents, test_documents = train_test_split(reviews, test_size=0.3, random_state=37)
train_documents, validation_documents = train_test_split(train_documents, test_size=0.28, random_state=53)
return train_documents, validation_documents, test_documents
if __name__ == "__main__":
train_documents, validation_documents, test_documents = parse_documents()
print("Train, dev and test dataset sizes:", len(train_documents), len(validation_documents), len(test_documents))
with open(os.path.join("data", "imdb_train.tsv"), 'w', encoding='utf8') as tsv_file:
for label, document in train_documents:
tsv_file.write(label + "\t" + document + "\n")
with open(os.path.join("data", "imdb_validation.tsv"), 'w', encoding='utf8') as tsv_file:
for label, document in validation_documents:
tsv_file.write(label + "\t" + document + "\n")
with open(os.path.join("data", "imdb_test.tsv"), 'w', encoding='utf8') as tsv_file:
for label, document in test_documents:
tsv_file.write(label + "\t" + document + "\n")