From c0327af8cfb4550f81ef908284b12fe9b9cf6fbe Mon Sep 17 00:00:00 2001 From: Mingchung Xia <mingchung.xia@gmail.com> Date: Wed, 7 Feb 2024 22:41:04 -0500 Subject: [PATCH] Changed initializers of HNSWCorpus to match dictionary --- .../1. Data Collection/HNSWCorpus.swift | 27 ++++++++++++++----- .../SwiftNLPTests/2. Encoding/HNSWTests.swift | 15 ++++++----- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift index 0bd97e71..eb2a2b61 100644 --- a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift +++ b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift @@ -25,19 +25,33 @@ import Foundation class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { - internal var _documentEncoder: ContextFreeEncoder<Scalar> + // internal var _documentEncoder: ContextFreeEncoder<Scalar> + internal var _documentEncoder: any SNLPEncoder var zeroes: [Scalar] var count: Int { 0 } // typicalNeighbourhoodSize = 20 is a standard benchmark var encodedDocuments: DeterministicSampleVectorIndex<[Scalar]> - - init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) { - self._documentEncoder = _documentEncoder - zeroes = Array(repeating: Scalar(0), count: 384) + + init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self, + typicalNeighborhoodSize: Int = 20) { + _documentEncoder = ContextFreeEncoder(source: encoding) + zeroes = _documentEncoder.zeroes as! [Scalar] encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) } + init(encoder: any SNLPEncoder, scalar: Scalar.Type = Double.self, typicalNeighborhoodSize: Int = 20) { + _documentEncoder = encoder + zeroes = _documentEncoder.zeroes as! [Scalar] + encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) + } + +// init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) { +// self._documentEncoder = _documentEncoder +// zeroes = Array(repeating: Scalar(0), count: 384) +// encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) +// } + // Decodable conformance // required init(from decoder: Decoder) throws { // let container = try decoder.container(keyedBy: CodingKeys.self) @@ -48,7 +62,8 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { @inlinable func addUntokenizedDocument(_ document: String) { - encodedDocuments.insert((_documentEncoder.encodeSentence(document))) /// as! [Scalar] not needed + /// forced unwrap as! [Scalar] is needed when we use SNLPEncoder and not ContextFreeEncoder + encodedDocuments.insert((_documentEncoder.encodeSentence(document)) as! [Scalar]) } // @inlinable diff --git a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift index a679db2b..5cc736b6 100644 --- a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift +++ b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift @@ -15,8 +15,9 @@ final class HNSWTests: XCTestCase { ] // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) - let encoder = ContextFreeEncoder<Double>(source: .hnswindex) - var corpus = HNSWCorpus(_documentEncoder: encoder) + // let encoder = ContextFreeEncoder<Double>(source: .hnswindex) + // var corpus = HNSWCorpus(_documentEncoder: encoder) + var corpus = HNSWCorpus(encoding: .glove6B50d) corpus.addUntokenizedDocuments(docs) XCTAssert(corpus.encodedDocuments.base.vectors.count == 3) @@ -54,8 +55,9 @@ final class HNSWTests: XCTestCase { ] // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) - let encoder = ContextFreeEncoder<Double>(source: .hnswindex) - var corpus = HNSWCorpus(_documentEncoder: encoder) + // let encoder = ContextFreeEncoder<Double>(source: .hnswindex) + // var corpus = HNSWCorpus(_documentEncoder: encoder) + var corpus = HNSWCorpus(encoding: .glove6B50d) corpus.addUntokenizedDocuments(twentyQuotes) @@ -79,8 +81,9 @@ final class HNSWTests: XCTestCase { let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d) - let encoder = ContextFreeEncoder<Double>(source: .hnswindex) - var corpus = HNSWCorpus(_documentEncoder: encoder) + // let encoder = ContextFreeEncoder<Double>(source: .hnswindex) + // var corpus = HNSWCorpus(_documentEncoder: encoder) + var corpus = HNSWCorpus(encoding: .glove6B50d) for submission in submissions { if let text = submission.selftext { -- GitLab