import XCTest
import SwiftAnnoy
import NaturalLanguage
@testable import SwiftNLP

final class AllMiniLM_pipelineTest: XCTestCase {

    /// Verifies that the fixture helpers can enumerate the bundled Reddit
    /// comment ("RC*") and submission ("RS*") JSON files.
    func testFileNameFetching() throws {
        let redditCommentNames = TestUtils.getJsonFiles(prefix: "RC")
        print("reddit comment files: \(redditCommentNames)")
        let redditSubmissionNames = TestUtils.getJsonFiles(prefix: "RS")
        print("reddit submission files: \(redditSubmissionNames)")
    }

    /// Decodes every bundled submission fixture into a RedditSubmissionData value.
    func testRedditSubmissions() throws {
        let redditSubmissionJson = TestUtils.loadAllRedditSubmission()
        for jsonData in redditSubmissionJson {
            let redditSubmission = readRedditSubmissionJson(json: jsonData)
            XCTAssertNotNil(redditSubmission, "Failed to decode RedditSubmissionData")
        }
    }

    /// Decodes every bundled comment fixture into a RedditCommentData value.
    func testRedditComments() throws {
        let redditCommentJson = TestUtils.loadAllRedditComment()
        for jsonData in redditCommentJson {
            let redditComment = readRedditCommentJson(json: jsonData)
            XCTAssertNotNil(redditComment, "Failed to decode RedditCommentData")
        }
    }

    /// Downloads the r/StopGaming archive and checks the expected thread count.
    /// NOTE(review): requires network access to https://reddit-top20k.cworld.ai.
    func test20kDownload() async throws {
        let result = try await downloadSubredditFromServer(subreddit: "StopGaming")
        print("Loaded \(result.count) threads from server.")
        if let (key, value) = result.randomElement() {
            print("Key: \(key), Value: \(value)")
        }
        XCTAssertEqual(result.count, 34829, "Failed to load subreddit data from https://reddit-top20k.cworld.ai")
    }

    /// End-to-end pipeline test: load the fixtures, embed every post body with
    /// MiniLM, index the embeddings with Annoy, then run a nearest-neighbour
    /// query for a fixed sentence and print the matching documents.
    func testDocumentReading() async throws {
        // Load and decode all fixture JSON, skipping any file that fails to parse.
        let redditCommentJson = TestUtils.loadAllRedditComment()
        let redditSubmissionJson = TestUtils.loadAllRedditSubmission()
        let redditComments = redditCommentJson.compactMap { readRedditCommentJson(json: $0) }
        let redditSubmissions = redditSubmissionJson.compactMap { readRedditSubmissionJson(json: $0) }

        // Collect the text of every comment body and submission selftext.
        var bodies: [String] = []
        for comment in redditComments {
            for post in comment.posts {
                if let body = post.body {
                    bodies.append(body)
                }
            }
        }
        for submission in redditSubmissions {
            for post in submission.posts {
                if let selftext = post.selftext {
                    bodies.append(selftext)
                }
            }
        }

        // Embed the query and every document with the MiniLM model.
        // Fail the test (instead of crashing via `!`/`fatalError`) if encoding
        // ever returns nil.
        let query = "stop playing video games"
        let embeddingDim = 384  // all-MiniLM output dimension, matches MiniLMEmbeddings
        let model = MiniLMEmbeddings()

        let encodedQuery = await model.encode(sentence: query)
        var queryEmbedding = try XCTUnwrap(encodedQuery, "Failed to encode query sentence")

        var databaseEmbedding: [[Float]] = []
        databaseEmbedding.reserveCapacity(bodies.count)
        for (documentIndex, text) in bodies.enumerated() {
            let encoded = await model.encode(sentence: text)
            let vector = try XCTUnwrap(encoded, "Failed to encode document \(documentIndex)")
            databaseEmbedding.append(vector)
            print(documentIndex + 1)  // progress indicator, 1-based like the original
        }

        // Build the Annoy index and query it. Use `try` (not `try?`) so index
        // errors surface as test failures instead of being silently discarded.
        let index = AnnoyIndex<Float>(itemLength: embeddingDim, metric: .euclidean)
        try index.addItems(items: &databaseEmbedding)
        try index.build(numTrees: 50)

        let results = try XCTUnwrap(
            index.getNNsForVector(vector: &queryEmbedding, neighbors: 10),
            "Nearest-neighbour query returned no results"
        )
        for neighbourIndex in results.indices {
            if neighbourIndex < bodies.count {
                print(bodies[neighbourIndex])
            } else {
                print("Index \(neighbourIndex) out of range.")
            }
        }
        print(results)
    }
}