Commit e2bcaba0 authored by Achyudh Ram

Add IMDB dataset

parents
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import json
from sklearn.model_selection import train_test_split
import os
def get_review_text(review):
    """
    Get a string containing the title and body of the review

    Runs of whitespace (spaces, tabs, newlines) inside each field are collapsed
    to single spaces, so the result fits on one line of a TSV file. A field
    that is missing, None or blank is skipped without leaving a stray
    separator space behind.

    :param review: A IMDB review dict with optional 'title' and 'review' keys
    :return: String containing the review title and body separated by a space
    """
    words = []
    for key in ('title', 'review'):
        value = review.get(key)
        if value is not None:
            # split() without arguments drops every whitespace run, including
            # the tabs and newlines that would otherwise break a TSV row.
            words.extend(value.split())
    return ' '.join(words)
def get_binary_label(label, num_classes=10):
    """
    Encode a 1-based rating as a one-hot string of '0' and '1' characters

    :param label: Integer rating in the range 1..num_classes
    :param num_classes: Number of rating classes, i.e. the length of the result
    :return: String of length num_classes with a '1' at position label - 1
    :raises ValueError: If label is outside the range 1..num_classes
    """
    if not 1 <= label <= num_classes:
        # Without this check a rating of 0 would index from the end of the
        # list and silently be encoded as the highest rating.
        raise ValueError(
            "label must be between 1 and %d, got %r" % (num_classes, label))
    category_label = [0] * num_classes
    # Mark the slot that belongs to this rating (ratings start at 1).
    category_label[label - 1] = 1
    return ''.join(map(str, category_label))
def parse_documents():
    """
    Extract the reviews from IMDB dataset and create train/dev/test splits

    Reads data/reviews.json, turns every review into a (label, text) tuple and
    splits the resulting list twice with fixed random seeds so that the splits
    are reproducible between runs.

    :return: Three lists containing the train, dev and test splits along with the labels
    """
    # Read with an explicit encoding so the result does not depend on the
    # platform's default locale; the TSV outputs are written as utf8 as well.
    with open(os.path.join("data", "reviews.json"), 'r', encoding='utf8') as json_file:
        raw_reviews = json.load(json_file)

    # The file is no longer needed once the JSON is parsed, so the handle is
    # released before the labelling and splitting work starts.
    reviews = [
        (get_binary_label(review['rating']), get_review_text(review))
        for review in raw_reviews
    ]

    # First hold out 30% for testing, then take 28% of the remaining 70% for
    # validation (about 50% train / 20% dev / 30% test overall).
    train_documents, test_documents = train_test_split(reviews, test_size=0.3, random_state=37)
    train_documents, validation_documents = train_test_split(train_documents, test_size=0.28, random_state=53)
    return train_documents, validation_documents, test_documents
if __name__ == "__main__":
    # Build the three splits and report how many documents landed in each.
    train_documents, validation_documents, test_documents = parse_documents()
    print("Train, dev and test dataset sizes:", len(train_documents), len(validation_documents), len(test_documents))

    # Each split goes to its own TSV file with one "<label>\t<text>" row per
    # document; the files are written one after another in this order.
    outputs = (
        ("imdb_train.tsv", train_documents),
        ("imdb_validation.tsv", validation_documents),
        ("imdb_test.tsv", test_documents),
    )
    for file_name, split in outputs:
        with open(os.path.join("data", file_name), 'w', encoding='utf8') as tsv_file:
            tsv_file.writelines(
                label + "\t" + document + "\n" for label, document in split
            )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment