Skip to content
Snippets Groups Projects

Allminilm

Merged Henry Tian requested to merge AllMiniLM into main
1 file
+ 123
0
Compare changes
  • Side-by-side
  • Inline
 
import XCTest
 
import SwiftAnnoy
 
import NaturalLanguage
 
@testable import SwiftNLP
 
 
/// Tests covering the Reddit JSON fixture helpers, the subreddit download, and
/// the MiniLM sentence-embedding + Annoy nearest-neighbour search pipeline.
final class AllMiniLM_pipelineTest: XCTestCase {

    /// Fetches the names of the JSON fixture files with the "RC" and "RS"
    /// prefixes and prints them. Makes no assertions; useful as a smoke check
    /// that the fixture lookup does not crash.
    func testFileNameFetching() throws {
        let redditCommentNames = TestUtils.getJsonFiles(prefix: "RC")
        print("reddit comment files: \(redditCommentNames)")

        let redditSubmissionNames = TestUtils.getJsonFiles(prefix: "RS")
        print("reddit submission files: \(redditSubmissionNames)")
    }

    /// Checks that every loaded Reddit submission JSON payload decodes into a
    /// non-nil object.
    func testRedditSubmissions() throws {
        let redditSubmissionJson = TestUtils.loadAllRedditSubmission()

        for jsonData in redditSubmissionJson {
            let redditSubmission = readRedditSubmissionJson(json: jsonData)
            XCTAssertNotNil(redditSubmission, "Failed to decode RedditSubmissionData")
        }
    }

    /// Checks that every loaded Reddit comment JSON payload decodes into a
    /// non-nil object.
    func testRedditComments() throws {
        let redditCommentJson = TestUtils.loadAllRedditComment()

        for jsonData in redditCommentJson {
            let redditComment = readRedditCommentJson(json: jsonData)
            XCTAssertNotNil(redditComment, "Failed to decode RedditCommentData")
        }
    }

    /// Downloads the "StopGaming" subreddit data and checks the number of
    /// loaded threads against a pinned count.
    ///
    /// NOTE(review): this hits a live server and pins an exact count, so it
    /// will fail if the remote data set changes or the network is unavailable.
    func test20kDownload() async throws {
        let result = try await downloadSubredditFromServer(subreddit: "StopGaming")
        print("Loaded \(result.count) threads from server.")

        // Print one arbitrary entry as a quick visual sanity check.
        if let random = result.randomElement() {
            let (key, value) = random
            print("Key: \(key), Value: \(value)")
        }

        XCTAssertEqual(result.count, 34829, "Failed to load subreddit data from https://reddit-top20k.cworld.ai")
    }

    /// Runs the full pipeline: loads all Reddit comment bodies and submission
    /// self-texts, embeds them with `MiniLMEmbeddings`, indexes the vectors in
    /// an `AnnoyIndex`, and prints the 10 nearest documents to a fixed query.
    ///
    /// Any failure along the way (an encoding returning nil, an index error,
    /// or a query returning no result) fails the test instead of crashing the
    /// process or being silently ignored.
    func testDocumentReading() async throws {
        // Load and decode all JSON data for the test documents; payloads that
        // fail to decode are dropped by `compactMap`.
        let redditComments = TestUtils.loadAllRedditComment()
            .compactMap { readRedditCommentJson(json: $0) }
        let redditSubmissions = TestUtils.loadAllRedditSubmission()
            .compactMap { readRedditSubmissionJson(json: $0) }

        // Collect the documents: comment bodies first, then submission
        // self-texts. Posts without text are skipped.
        let commentBodies: [String] = redditComments.flatMap { comment in
            comment.posts.compactMap { $0.body }
        }
        let submissionTexts: [String] = redditSubmissions.flatMap { submission in
            submission.posts.compactMap { $0.selftext }
        }
        let bodies = commentBodies + submissionTexts

        // Debug code
        // bodies = Array(bodies.prefix(10))
        // print(bodies)

        // Start encoding the database and the query.
        let query = "stop playing video games"
        let embeddingDimension = 384
        // Kept as `var` as in the original, in case `encode` is a mutating
        // member — TODO confirm against MiniLMEmbeddings and switch to `let`.
        var model = MiniLMEmbeddings()

        // `XCTUnwrap` takes a non-async autoclosure, so await first.
        let encodedQuery = await model.encode(sentence: query)
        var queryEmbedding: [Float] = try XCTUnwrap(encodedQuery, "Failed to encode the query sentence")
        XCTAssertEqual(
            queryEmbedding.count,
            embeddingDimension,
            "Query embedding length does not match the index item length"
        )

        // Embed every document. Position i in `databaseEmbedding` must match
        // position i in `bodies`, so a failed encoding fails the test rather
        // than being skipped.
        var databaseEmbedding: [[Float]] = []
        databaseEmbedding.reserveCapacity(bodies.count)
        for (position, text) in bodies.enumerated() {
            let encoded = await model.encode(sentence: text)
            let vector = try XCTUnwrap(encoded, "Failed to encode document at position \(position)")
            databaseEmbedding.append(vector)
            print(position + 1) // progress counter
        }

        let annoyIndex = AnnoyIndex<Float>(itemLength: embeddingDimension, metric: .euclidean)

        // Propagate index errors so the test fails instead of continuing with
        // an empty or unbuilt index.
        try annoyIndex.addItems(items: &databaseEmbedding)
        try annoyIndex.build(numTrees: 50)

        let results = annoyIndex.getNNsForVector(vector: &queryEmbedding, neighbors: 10)
        let nearest = try XCTUnwrap(results, "Nearest-neighbour query returned no result")

        for neighborIndex in nearest.indices {
            if bodies.indices.contains(neighborIndex) {
                print(bodies[neighborIndex])
            } else {
                print("Index \(neighborIndex) out of range.")
            }
        }

        print(nearest)
    }
}
Loading