Commit 2e273d3d authored by Henry Tian

Upload New File

Parent: a811f11d
Merge request: !5 Allminilm
Pipeline: #108444 failed
import XCTest
import SwiftAnnoy
import NaturalLanguage
@testable import SwiftNLP
final class AllMiniLM_pipelineTest: XCTestCase {
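    // Tests for the SwiftNLP AllMiniLM pipeline: loading the Reddit test data, embedding it
    // with MiniLM, and retrieving nearest neighbours with an Annoy index.
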
    // test fetching names of all the files
    func testFileNameFetching() throws {
        let redditCommentNames = TestUtils.getJsonFiles(prefix: "RC")
        print("reddit comment files: \(redditCommentNames)")
        let redditSubmissionNames = TestUtils.getJsonFiles(prefix: "RS")
        print("reddit submission files: \(redditSubmissionNames)")
    }

    // test reading reddit submission json files into actual objects
    func testRedditSubmissions() throws {
        let redditSubmissionJson = TestUtils.loadAllRedditSubmission()
        for jsonData in redditSubmissionJson {
            let redditSubmission = readRedditSubmissionJson(json: jsonData)
            XCTAssertNotNil(redditSubmission, "Failed to decode RedditSubmissionData")
        }
    }

    // test reading reddit comment json files into actual objects
    func testRedditComments() throws {
        let redditCommentJson = TestUtils.loadAllRedditComment()
        for jsonData in redditCommentJson {
            let redditComment = readRedditCommentJson(json: jsonData)
            XCTAssertNotNil(redditComment, "Failed to decode RedditCommentData")
        }
    }

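    // test downloading the full "StopGaming" subreddit dump from the reddit-top20k server
    // and verifying that the expected number of threads is returned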
    func test20kDownload() async throws {
        let result = try await downloadSubredditFromServer(subreddit: "StopGaming")
        print("Loaded \(result.count) threads from server.")
        if let random = result.randomElement() {
            let (key, value) = random
            print("Key: \(key), Value: \(value)")
        }
        XCTAssertEqual(result.count, 34829, "Failed to load subreddit data from https://reddit-top20k.cworld.ai")
    }

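    // end-to-end pipeline test: load the Reddit test documents, embed every comment and
    // submission body with MiniLM, index the embeddings with Annoy, and retrieve the
    // nearest neighbours for a sample query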
    func testDocumentReading() async throws {
        // load all json data for the test documents
        let redditCommentJson = TestUtils.loadAllRedditComment()
        let redditSubmissionJson = TestUtils.loadAllRedditSubmission()
        let redditComments = redditCommentJson.compactMap { readRedditCommentJson(json: $0) }
        let redditSubmissions = redditSubmissionJson.compactMap { readRedditSubmissionJson(json: $0) }

        var bodies: [String] = []

        // collect the body of every reddit comment into the document list
        for comment in redditComments {
            //debugPrint("Processing \(comment.posts.count) comments")
            for post in comment.posts {
                if let body = post.body {
                    bodies.append(body)
                }
            }
        }
        // collect the selftext of every reddit submission into the document list
        for submission in redditSubmissions {
            //debugPrint("Processing \(submission.posts.count) submissions")
            for post in submission.posts {
                if let p = post.selftext {
                    //debugPrint(p)
                    bodies.append(p)
                }
            }
        }

        // Debug code
        // bodies = Array(bodies.prefix(10))
        // print(bodies)

        // encode the database documents and the query
        var database_embedding: [[Float]] = []
        var query_embedding: [Float] = []
        let query = "stop playing video games"
        let embedding_dim: Int = 384   // MiniLM sentence embeddings are 384-dimensional
        let model = MiniLMEmbeddings()
        query_embedding = await model.encode(sentence: query)!

        var i = 1
        // append each sentence embedding to database_embedding
        for string in bodies {
            if let vector = await model.encode(sentence: string) {
                database_embedding.append(vector)
                print(i)
                i += 1
            } else {
                fatalError("Failed to encode document \(i)")
            }
        }

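        // build an Annoy index over the document embeddings and query it for the
        // 10 nearest neighbours of the query embedding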
        let index = AnnoyIndex<Float>(itemLength: embedding_dim, metric: .euclidean)
        try? index.addItems(items: &database_embedding)
        try? index.build(numTrees: 50)

        let results = index.getNNsForVector(vector: &query_embedding, neighbors: 10)
        if let finalResult = results {
            let extractedIndices = finalResult.indices
            for resultIndex in extractedIndices {
                if resultIndex < bodies.count {
                    print(bodies[resultIndex])
                } else {
                    print("Index \(resultIndex) out of range.")
                }
            }
        }
        print(results as Any)
    }
}