Commit e2bcaba0 authored by Achyudh Ram

Add IMDB dataset

parents
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
import json
from sklearn.model_selection import train_test_split
import os
def get_review_text(review):
    """
    Get a string containing the title and body of the review
    :param review: An IMDB review dict with 'title' and 'review' entries; either value may be None
    :return: String containing the whitespace-normalized title and body, separated by a single space
    :raises KeyError: If the 'title' or 'review' key is absent from the dict
    """
    parts = []
    for field in ('title', 'review'):
        value = review[field]
        if value is None:
            continue
        # split()/join collapses runs of whitespace (including tabs and newlines)
        # so the text is safe to store as a single TSV column
        normalized = ' '.join(value.split())
        # Skip empty fields so no stray leading or trailing separator is emitted
        if normalized:
            parts.append(normalized)
    return ' '.join(parts)
def get_binary_label(label, num_classes=10):
    """
    Encode a 1-based rating as a one-hot string of '0'/'1' characters
    :param label: Integer rating in the range 1..num_classes
    :param num_classes: Number of rating classes, i.e. the length of the returned string
    :return: String of length num_classes with '1' at position label - 1 and '0' everywhere else
    :raises ValueError: If label lies outside 1..num_classes
    """
    # Reject 0 and negatives explicitly: a negative index would silently wrap
    # around to the end of the list and mark the wrong class
    if not 1 <= label <= num_classes:
        raise ValueError("label must be between 1 and %d, got %r" % (num_classes, label))
    category_label = [0] * num_classes
    category_label[label - 1] = 1
    return ''.join(map(str, category_label))
def parse_documents(reviews_path=None):
    """
    Extract the reviews from IMDB dataset and create train/dev/test splits
    :param reviews_path: Path of the JSON file holding the reviews; defaults to data/reviews.json
    :return: Three lists containing the train, dev and test splits; every element is a
             (label, text) tuple built by get_binary_label and get_review_text
    """
    if reviews_path is None:
        reviews_path = os.path.join("data", "reviews.json")

    # Read with an explicit encoding so the result does not depend on the
    # platform's locale default; the output files are written as utf8 as well
    with open(reviews_path, 'r', encoding='utf8') as json_file:
        raw_reviews = json.load(json_file)

    reviews = [(get_binary_label(review['rating']), get_review_text(review)) for review in raw_reviews]

    # Hold out 30% for test, then 28% of the remaining 70% for validation;
    # the fixed random_state values keep the splits reproducible
    train_documents, test_documents = train_test_split(reviews, test_size=0.3, random_state=37)
    train_documents, validation_documents = train_test_split(train_documents, test_size=0.28, random_state=53)
    return train_documents, validation_documents, test_documents
if __name__ == "__main__":
    # Build the three splits and report how many documents landed in each
    train_documents, validation_documents, test_documents = parse_documents()
    print("Train, dev and test dataset sizes:", len(train_documents), len(validation_documents), len(test_documents))

    # Every split goes to its own TSV file, one "<label>\t<text>" row per document
    outputs = (
        ("imdb_train.tsv", train_documents),
        ("imdb_validation.tsv", validation_documents),
        ("imdb_test.tsv", test_documents),
    )
    for file_name, split in outputs:
        with open(os.path.join("data", file_name), 'w', encoding='utf8') as tsv_file:
            tsv_file.writelines(label + "\t" + document + "\n" for label, document in split)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment