Skip to content
Snippets Groups Projects
Commit c0327af8 authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Changed initializers of HNSWCorpus to match dictionary

parent f95be832
No related branches found
No related tags found
1 merge request: !13 HNSW Implementation with Testcases
...@@ -25,19 +25,33 @@ import Foundation ...@@ -25,19 +25,33 @@ import Foundation
class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
internal var _documentEncoder: ContextFreeEncoder<Scalar> // internal var _documentEncoder: ContextFreeEncoder<Scalar>
internal var _documentEncoder: any SNLPEncoder
var zeroes: [Scalar] var zeroes: [Scalar]
var count: Int { 0 } var count: Int { 0 }
// typicalNeighbourhoodSize = 20 is a standard benchmark // typicalNeighbourhoodSize = 20 is a standard benchmark
var encodedDocuments: DeterministicSampleVectorIndex<[Scalar]> var encodedDocuments: DeterministicSampleVectorIndex<[Scalar]>
init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) { init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self,
self._documentEncoder = _documentEncoder typicalNeighborhoodSize: Int = 20) {
zeroes = Array(repeating: Scalar(0), count: 384) _documentEncoder = ContextFreeEncoder(source: encoding)
zeroes = _documentEncoder.zeroes as! [Scalar]
encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
} }
init(encoder: any SNLPEncoder, scalar: Scalar.Type = Double.self, typicalNeighborhoodSize: Int = 20) {
_documentEncoder = encoder
zeroes = _documentEncoder.zeroes as! [Scalar]
encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
}
// init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) {
// self._documentEncoder = _documentEncoder
// zeroes = Array(repeating: Scalar(0), count: 384)
// encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
// }
// Decodable conformance // Decodable conformance
// required init(from decoder: Decoder) throws { // required init(from decoder: Decoder) throws {
// let container = try decoder.container(keyedBy: CodingKeys.self) // let container = try decoder.container(keyedBy: CodingKeys.self)
...@@ -48,7 +62,8 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { ...@@ -48,7 +62,8 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
@inlinable @inlinable
func addUntokenizedDocument(_ document: String) { func addUntokenizedDocument(_ document: String) {
encodedDocuments.insert((_documentEncoder.encodeSentence(document))) /// as! [Scalar] not needed /// forced unwrap as! [Scalar] is needed when we use SNLPEncoder and not ContextFreeEncoder
encodedDocuments.insert((_documentEncoder.encodeSentence(document)) as! [Scalar])
} }
// @inlinable // @inlinable
......
...@@ -15,8 +15,9 @@ final class HNSWTests: XCTestCase { ...@@ -15,8 +15,9 @@ final class HNSWTests: XCTestCase {
] ]
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
let encoder = ContextFreeEncoder<Double>(source: .hnswindex) // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
var corpus = HNSWCorpus(_documentEncoder: encoder) // var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.encodedDocuments.base.vectors.count == 3) XCTAssert(corpus.encodedDocuments.base.vectors.count == 3)
...@@ -54,8 +55,9 @@ final class HNSWTests: XCTestCase { ...@@ -54,8 +55,9 @@ final class HNSWTests: XCTestCase {
] ]
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
let encoder = ContextFreeEncoder<Double>(source: .hnswindex) // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
var corpus = HNSWCorpus(_documentEncoder: encoder) // var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
...@@ -79,8 +81,9 @@ final class HNSWTests: XCTestCase { ...@@ -79,8 +81,9 @@ final class HNSWTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
let encoder = ContextFreeEncoder<Double>(source: .hnswindex) // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
var corpus = HNSWCorpus(_documentEncoder: encoder) // var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
......
0% — Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment