Skip to content
Snippets Groups Projects
Commit 6cab2aee authored by Jim Wallace's avatar Jim Wallace
Browse files

Cleaning up SNLPCorpus and SNLPEncoder and implementing types

parent 5d7f1b08
No related branches found
No related tags found
1 merge request!14- Improved ergonomics for generic types: SNLPCorpus, SNLPEncoder, InMemoryCorpus
Pipeline #115734 passed with warnings
Showing
with 146 additions and 89 deletions
...@@ -25,22 +25,26 @@ import Foundation ...@@ -25,22 +25,26 @@ import Foundation
protocol SNLPCorpus<Scalar>: Collection { protocol SNLPCorpus<Scalar>: Collection {
associatedtype Scalar: BinaryFloatingPoint & Codable associatedtype Scalar: BinaryFloatingPoint
associatedtype Encoder: SNLPEncoder where Encoder.Scalar == Scalar
associatedtype Item: SNLPDataItem
var zeroes: [Scalar] { get } var zeroes: [Scalar] { get }
var count: Int { get } var count: Int { get }
mutating func addUntokenizedDocument(_ document: String) mutating func addUntokenizedDocument(_ document: Item)
mutating func addUntokenizedDocuments(_ documents: [String]) mutating func addUntokenizedDocuments(_ documents: [Item])
} }
extension SNLPCorpus { extension SNLPCorpus {
/** /**
Adds a series of untokenized documents to the corpus, using default tokenization and text processing Adds a series of untokenized documents to the corpus, using default tokenization and text processing
*/ */
@inlinable @inlinable
mutating func addUntokenizedDocuments(_ documents: [String]) { mutating func addUntokenizedDocuments(_ documents: [Item]) {
for d in documents { for d in documents {
addUntokenizedDocument(d) addUntokenizedDocument(d)
} }
......
...@@ -24,11 +24,12 @@ ...@@ -24,11 +24,12 @@
import Foundation import Foundation
protocol SNLPEncoder<Scalar>: Codable { protocol SNLPEncoder<Scalar> {
associatedtype Scalar: BinaryFloatingPoint & Codable associatedtype Scalar: BinaryFloatingPoint
var zeroes: [Scalar] { get } var zeroes: [Scalar] { get }
var dimensions: UInt { get }
@inlinable @inlinable
func encodeToken(_ token: String) -> [Scalar] func encodeToken(_ token: String) -> [Scalar]
......
// Copyright (c) 2024 Jim Wallace
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
// files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use,
// copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
import Foundation
/// A corpus that can build a search structure over its documents and answer text queries.
///
/// Refines `SNLPCorpus` with a train-then-query lifecycle: callers are expected to
/// check `isTrained()` / call `train()` before issuing queries.
/// NOTE(review): contract details below are inferred from the signatures alone — confirm
/// against conforming types.
protocol SNLPSearchableCorpus: SNLPCorpus {
/// Whether the search structure has been built and the corpus is ready to answer queries.
func isTrained() -> Bool
/// Builds (or rebuilds) the search structure over the current documents.
/// NOTE(review): declared non-mutating — conforming value types may need `mutating`,
/// or conformers are expected to be classes; verify the intended semantics.
func train()
/// Returns documents matching `query`.
/// NOTE(review): ranking order, result count, and behavior when untrained are not
/// specified here — confirm with implementations.
func searchFor(_ query: String) -> [String]
}
...@@ -23,26 +23,26 @@ ...@@ -23,26 +23,26 @@
import Foundation import Foundation
class DictionaryCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { final class DictionaryCorpus<Scalar: BinaryFloatingPoint, Encoder: SNLPEncoder, Item: SNLPDataItem>: SNLPCorpus where Encoder.Scalar == Scalar {
private var _documentEncoder: any SNLPEncoder<Scalar> internal var _documentEncoder: Encoder
var zeroes: [Scalar] { _documentEncoder.zeroes } var zeroes: [Scalar] { _documentEncoder.zeroes }
var encodedDocuments: [Int : [Scalar] ] = [:] var encodedDocuments: [Int : [Scalar] ] = [:]
var count: Int { encodedDocuments.count } var count: Int { encodedDocuments.count }
init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self) { init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings) {
_documentEncoder = ContextFreeEncoder(source: encoding) _documentEncoder = ContextFreeEncoder<Scalar>(source: encoding) as! Encoder
} }
init(encoder: any SNLPEncoder<Scalar>, scalar: Scalar.Type = Double.self) { init(encoder: Encoder) {
_documentEncoder = encoder _documentEncoder = encoder
} }
@inlinable @inlinable
func addUntokenizedDocument(_ document: String) { func addUntokenizedDocument(_ document: Item) {
encodedDocuments[ encodedDocuments.count ] = (_documentEncoder.encodeSentence(document) ) encodedDocuments[ encodedDocuments.count ] = (_documentEncoder.encodeSentence(document.fullText) )
} }
} }
...@@ -84,7 +84,7 @@ extension DurableHNSWCorpus { ...@@ -84,7 +84,7 @@ extension DurableHNSWCorpus {
} }
} }
static func readDictionaryFromDownloads(fileName: String, width: Int = 50) -> HNSWDictionary { static func readDictionaryFromDownloads(fileName: String, dimensions: Int = 50) -> HNSWDictionary {
guard let downloadsURL = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first else { guard let downloadsURL = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first else {
print("Could not find Downloads directory") print("Could not find Downloads directory")
return [:] return [:]
...@@ -92,12 +92,12 @@ extension DurableHNSWCorpus { ...@@ -92,12 +92,12 @@ extension DurableHNSWCorpus {
let fileURL = downloadsURL.appendingPathComponent(fileName) let fileURL = downloadsURL.appendingPathComponent(fileName)
return readDictionaryMemoryMap(fileURL, width: width) return readDictionaryMemoryMap(fileURL, dimensions: dimensions)
} }
/// Width is the number of dimensions of the glove encoding /// Width is the number of dimensions of the glove encoding
// TODO: Improve this to not need to take in a width, rather switch between the encoding / encoder // TODO: Improve this to not need to take in a width, rather switch between the encoding / encoder
static func readDictionaryMemoryMap(_ url: URL, width: Int = 50) -> HNSWDictionary { static func readDictionaryMemoryMap(_ url: URL, dimensions: Int = 50) -> HNSWDictionary {
var dictionary = HNSWDictionary() var dictionary = HNSWDictionary()
do { do {
......
...@@ -81,9 +81,9 @@ extension EphemeralHNSWCorpus.DocumentVectorPair: Codable where Scalar: Codable ...@@ -81,9 +81,9 @@ extension EphemeralHNSWCorpus.DocumentVectorPair: Codable where Scalar: Codable
vector = try container.decode([Scalar].self, forKey: .vector) vector = try container.decode([Scalar].self, forKey: .vector)
} }
internal func encode(to encoder: Encoder) throws { // internal func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self) // var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(untokenizedDocument, forKey: .untokenizedDocument) // try container.encode(untokenizedDocument, forKey: .untokenizedDocument)
try container.encode(vector, forKey: .vector) // try container.encode(vector, forKey: .vector)
} // }
} }
...@@ -34,11 +34,12 @@ import Foundation ...@@ -34,11 +34,12 @@ import Foundation
// MARK: Allow EphemeralHNSWCorpus to simply be used as HNSWCorpus // MARK: Allow EphemeralHNSWCorpus to simply be used as HNSWCorpus
typealias HNSWCorpus = EphemeralHNSWCorpus typealias HNSWCorpus = EphemeralHNSWCorpus
final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable, Encoder: SNLPEncoder, Item: SNLPDataItem>: SNLPCorpus where Encoder.Scalar == Scalar {
public typealias HNSWDictionary = [Int: DocumentVectorPair] public typealias HNSWDictionary = [Int: DocumentVectorPair]
internal var _documentEncoder: any SNLPEncoder<Scalar> internal var _documentEncoder: Encoder
var zeroes: [Scalar] { _documentEncoder.zeroes } var zeroes: [Scalar] { _documentEncoder.zeroes }
var encodedDocuments: DeterministicEphemeralVectorIndex<[Scalar]> var encodedDocuments: DeterministicEphemeralVectorIndex<[Scalar]>
...@@ -48,13 +49,13 @@ final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorp ...@@ -48,13 +49,13 @@ final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorp
var dictionary: HNSWDictionary = [:] var dictionary: HNSWDictionary = [:]
// typicalNeighbourhoodSize = 20 is a standard benchmark // typicalNeighbourhoodSize = 20 is a standard benchmark
init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self, init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings,
typicalNeighborhoodSize: Int = 20) { typicalNeighborhoodSize: Int = 20) {
_documentEncoder = ContextFreeEncoder(source: encoding) _documentEncoder = ContextFreeEncoder(source: encoding) as! Encoder
encodedDocuments = DeterministicEphemeralVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) encodedDocuments = DeterministicEphemeralVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
} }
init(encoder: any SNLPEncoder<Scalar>, scalar: Scalar.Type = Double.self, typicalNeighborhoodSize: Int = 20) { init(encoder: Encoder, typicalNeighborhoodSize: Int = 20) {
_documentEncoder = encoder _documentEncoder = encoder
encodedDocuments = DeterministicEphemeralVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize) encodedDocuments = DeterministicEphemeralVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
} }
...@@ -68,13 +69,13 @@ final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorp ...@@ -68,13 +69,13 @@ final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorp
// } // }
@inlinable @inlinable
func addUntokenizedDocument(_ document: String) { func addUntokenizedDocument(_ document: Item) {
/// forced unwrap as! [Scalar] is needed when we use SNLPEncoder but not ContextFreeEncoder /// forced unwrap as! [Scalar] is needed when we use SNLPEncoder but not ContextFreeEncoder
/// encodedDocuments.insert will insert and return the corresponding key (id) /// encodedDocuments.insert will insert and return the corresponding key (id)
let key = encodedDocuments.insert((_documentEncoder.encodeSentence(document)) ) let key = encodedDocuments.insert((_documentEncoder.encodeSentence(document.fullText)) )
addDocumentVectorPair( addDocumentVectorPair(
at: key, at: key,
document: document, document: document.fullText,
vector: encodedDocuments.base.vectors[key] vector: encodedDocuments.base.vectors[key]
) )
} }
......
//
// File.swift
//
//
// Created by Jim Wallace on 2024-04-03.
//
import Foundation
/// Allows a plain `String` to be used directly as a corpus document:
/// the string is simultaneously its own identifier and its own full text.
extension String: SNLPDataItem {
/// NOTE(review): computed on every access, so `createdOn` returns the *current* time
/// and is not a stable creation date — confirm this is acceptable for consumers that
/// sort or dedupe by date. (`Date.now` also requires macOS 12 / iOS 15+.)
public var createdOn: Date { Date.now }
/// The string itself serves as its identifier — identical strings share an id.
public var id: String { self }
/// The string itself is the document body.
public var fullText: String { self }
}
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
import Foundation import Foundation
class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { final class ContextFreeEncoder<Scalar: BinaryFloatingPoint>: SNLPEncoder {
var dictionary: [String : [Scalar]] var dictionary: [String : [Scalar]]
let width: Int let dimensions: UInt
var zeroes: [Scalar] var zeroes: [Scalar]
var count: Int { dictionary.count } var count: Int { dictionary.count }
...@@ -44,15 +44,15 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { ...@@ -44,15 +44,15 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
var dictionaryToLoad: String var dictionaryToLoad: String
switch source { switch source {
case .glove6B50d: case .glove6B50d:
width = 50 dimensions = 50
dictionaryToLoad = "glove.6B.50d" dictionaryToLoad = "glove.6B.50d"
case .glove6B100d: case .glove6B100d:
width = 100 dimensions = 100
dictionaryToLoad = "glove.6B.100d" dictionaryToLoad = "glove.6B.100d"
} }
zeroes = Array(repeating: Scalar(0), count: width) zeroes = Array(repeating: Scalar(0), count: Int(dimensions))
// Try to load locally first // Try to load locally first
guard let url = Bundle.module.url(forResource: dictionaryToLoad, withExtension: "mmap") else { guard let url = Bundle.module.url(forResource: dictionaryToLoad, withExtension: "mmap") else {
......
...@@ -26,18 +26,18 @@ import Foundation ...@@ -26,18 +26,18 @@ import Foundation
import CoreML import CoreML
class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { //class CoreMLEncoder<Scalar: BinaryFloatingPoint>: SNLPEncoder {
//
var zeroes: [Scalar] // var zeroes: [Scalar]
//
func encodeToken(_ token: String) -> [Scalar] { // func encodeToken(_ token: String) -> [Scalar] {
fatalError("CoreMLEncoder not implemented yet. Get on it.") // fatalError("CoreMLEncoder not implemented yet. Get on it.")
} // }
//
func encodeSentence(_ sentence: String) -> [Scalar] { // func encodeSentence(_ sentence: String) -> [Scalar] {
fatalError("CoreMLEncoder not implemented yet. Get on it.") // fatalError("CoreMLEncoder not implemented yet. Get on it.")
} // }
} //}
//@available(macOS 13.0, *) //@available(macOS 13.0, *)
//public class MiniLMEmbeddings { //public class MiniLMEmbeddings {
......
...@@ -26,8 +26,9 @@ import Foundation ...@@ -26,8 +26,9 @@ import Foundation
import NaturalLanguage import NaturalLanguage
class NaturalLanguageEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { class NaturalLanguageEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
var zeroes: [Scalar] { Array(repeating: Scalar(0), count: 512) } var dimensions: UInt = 512
var zeroes: [Scalar] { Array(repeating: Scalar(0), count: Int(dimensions)) }
@inlinable @inlinable
func encodeToken(_ token: String) -> [Scalar] { func encodeToken(_ token: String) -> [Scalar] {
......
...@@ -23,18 +23,18 @@ ...@@ -23,18 +23,18 @@
import Foundation import Foundation
class OpenAIEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPAsyncEncoder { //class OpenAIEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPAsyncEncoder {
//
//
var zeroes: [Scalar] // var zeroes: [Scalar]
//
func fetchEncodingForToken(_ token: String) async throws -> [Scalar] { // func fetchEncodingForToken(_ token: String) async throws -> [Scalar] {
fatalError("OpenAIEncoder not implemented. Get on it.") // fatalError("OpenAIEncoder not implemented. Get on it.")
} // }
//
func fetchEncodingForSentence(_ sentence: String) async throws -> [Scalar] { // func fetchEncodingForSentence(_ sentence: String) async throws -> [Scalar] {
fatalError("OpenAIEncoder not implemented. Get on it.") // fatalError("OpenAIEncoder not implemented. Get on it.")
} // }
//
//
} //}
...@@ -22,10 +22,10 @@ struct GraphView: View { ...@@ -22,10 +22,10 @@ struct GraphView: View {
for (id, point) in points { for (id, point) in points {
context.fill( context.fill(
Circle().path(in: CGRect(x: point.x - 5, y: point.y - 5, width: 10, height: 10)), Circle().path(in: CGRect(x: point.x - 5, y: point.y - 5, width: 10, height: 10)),
with: .color(.blue) with: .color(.blue)
) )
context.draw(Text("\(id)").bold().foregroundColor(.red), in: CGRect(x: point.x, y: point.y, width: 20, height: 20)) context.draw(Text("\(id)").bold().foregroundColor(.red), in: CGRect(x: point.x, y: point.y, width: 20, height: 20))
} }
} }
.frame(maxWidth: .infinity, maxHeight: .infinity) .frame(maxWidth: .infinity, maxHeight: .infinity)
...@@ -60,7 +60,7 @@ struct VisualizerView: View { ...@@ -60,7 +60,7 @@ struct VisualizerView: View {
updateCount += 1 updateCount += 1
} }
Slider(value: $angle.degrees, in: 0...89) Slider(value: $angle.degrees, in: 0...89)
.frame(width: 100) .frame(width: 100)
} }
.padding() .padding()
ScrollView { ScrollView {
...@@ -74,8 +74,8 @@ struct VisualizerView: View { ...@@ -74,8 +74,8 @@ struct VisualizerView: View {
edges: index.edges(for: level) edges: index.edges(for: level)
) )
.rotation3DEffect(angle, axis: (1, 0, 0), perspective: 0) .rotation3DEffect(angle, axis: (1, 0, 0), perspective: 0)
.frame(width: 600, height: 600, alignment: .top) .frame(width: 600, height: 600, alignment: .top)
.frame(width: 600, height: 600 * cos(angle.radians)) .frame(width: 600, height: 600 * cos(angle.radians))
Divider() Divider()
} }
} }
......
...@@ -61,7 +61,7 @@ final class DurableHNSWCorpusTests: XCTestCase { ...@@ -61,7 +61,7 @@ final class DurableHNSWCorpusTests: XCTestCase {
/// Writing to LMDB /// Writing to LMDB
let transaction = try Transaction.begin(.write, in: env) let transaction = try Transaction.begin(.write, in: env)
var corpus = try DurableHNSWCorpus( let corpus = try DurableHNSWCorpus(
encoding: .glove6B50d, encoding: .glove6B50d,
namespace: "testBasicExample", namespace: "testBasicExample",
in: transaction in: transaction
...@@ -76,7 +76,7 @@ final class DurableHNSWCorpusTests: XCTestCase { ...@@ -76,7 +76,7 @@ final class DurableHNSWCorpusTests: XCTestCase {
/// Reading from LMDB /// Reading from LMDB
let readTransaction = try Transaction.begin(.read, in: env) let readTransaction = try Transaction.begin(.read, in: env)
let readCorpus = try DurableHNSWCorpus( let _ = try DurableHNSWCorpus(
encoding: .glove6B50d, encoding: .glove6B50d,
namespace: "testBasicExample", namespace: "testBasicExample",
in: readTransaction in: readTransaction
...@@ -114,7 +114,7 @@ final class DurableHNSWCorpusTests: XCTestCase { ...@@ -114,7 +114,7 @@ final class DurableHNSWCorpusTests: XCTestCase {
let transaction = try Transaction.begin(.write, in: env) let transaction = try Transaction.begin(.write, in: env)
/// Saving the memory map to disk /// Saving the memory map to disk
var corpus = try DurableHNSWCorpus( let corpus = try DurableHNSWCorpus(
encoder: _documentEncoder, encoder: _documentEncoder,
namespace: "testBasicQueryExample", namespace: "testBasicQueryExample",
in: transaction in: transaction
...@@ -179,7 +179,7 @@ final class DurableHNSWCorpusTests: XCTestCase { ...@@ -179,7 +179,7 @@ final class DurableHNSWCorpusTests: XCTestCase {
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d) let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = try DurableHNSWCorpus( let corpus = try DurableHNSWCorpus(
encoder: _documentEncoder, encoder: _documentEncoder,
namespace: "subreddit_durable", namespace: "subreddit_durable",
in: transaction in: transaction
......
...@@ -15,7 +15,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -15,7 +15,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
"that enable us to train deep learning algorithms to learn like the human brain." "that enable us to train deep learning algorithms to learn like the human brain."
] ]
var corpus = HNSWCorpus(encoding: .glove6B50d) var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.count == 3) XCTAssert(corpus.count == 3)
...@@ -51,7 +51,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -51,7 +51,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
"All science is either physics or stamp collecting. - Ernest Rutherford" "All science is either physics or stamp collecting. - Ernest Rutherford"
] ]
var corpus = HNSWCorpus(encoding: .glove6B50d) var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
XCTAssertEqual(corpus.count, 20) XCTAssertEqual(corpus.count, 20)
...@@ -72,7 +72,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -72,7 +72,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
var corpus = HNSWCorpus(encoding: .glove6B50d) let corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
...@@ -99,7 +99,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -99,7 +99,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let query = "I like to read about new technology and artificial intelligence" let query = "I like to read about new technology and artificial intelligence"
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d) let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus(encoder: _documentEncoder) var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
do { do {
...@@ -141,7 +141,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -141,7 +141,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let query = "I love Albert Einstein!" let query = "I love Albert Einstein!"
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d) let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus(encoder: _documentEncoder) var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoder: _documentEncoder)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
do { do {
...@@ -158,16 +158,16 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -158,16 +158,16 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
func testQueryGuephSubredditCorpus() async throws { func testQueryGuephSubredditCorpus() async throws {
guard let submissionsURL = Bundle.module.url(forResource: "Guelph_submissions", withExtension: "zst") else { guard let submissionsURL = Bundle.module.url(forResource: "Guelph_submissions", withExtension: "zst") else {
fatalError("Failed to find waterloo_submissions.zst in test bundle.") fatalError("Failed to find guelph_submissions.zst in test bundle.")
} }
guard let submissionsData = try? Data(contentsOf: submissionsURL) else { guard let submissionsData = try? Data(contentsOf: submissionsURL) else {
fatalError("Failed to load waterloo_submissions.zst from test bundle.") fatalError("Failed to load guelph_submissions.zst from test bundle.")
} }
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d) let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus(encoder: _documentEncoder) let corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoder: _documentEncoder)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
......
...@@ -13,7 +13,7 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -13,7 +13,7 @@ final class ContextFreeEncoderTests: XCTestCase {
"that enable us to train deep learning algorithms to learn like the human brain." "that enable us to train deep learning algorithms to learn like the human brain."
] ]
var corpus = DictionaryCorpus(encoding: .glove6B50d) var corpus = DictionaryCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.encodedDocuments.count == 3) XCTAssert(corpus.encodedDocuments.count == 3)
...@@ -50,7 +50,7 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -50,7 +50,7 @@ final class ContextFreeEncoderTests: XCTestCase {
"All science is either physics or stamp collecting. - Ernest Rutherford" "All science is either physics or stamp collecting. - Ernest Rutherford"
] ]
var corpus = DictionaryCorpus(encoding: .glove6B50d) var corpus = DictionaryCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
...@@ -75,7 +75,7 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -75,7 +75,7 @@ final class ContextFreeEncoderTests: XCTestCase {
//print("Errors: \(errors.count)") //print("Errors: \(errors.count)")
let corpus = DictionaryCorpus(encoding: .glove6B50d) let corpus = DictionaryCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
corpus.addUntokenizedDocument(text) corpus.addUntokenizedDocument(text)
......
...@@ -15,7 +15,7 @@ final class NaturalLanguageEncoderTests: XCTestCase { ...@@ -15,7 +15,7 @@ final class NaturalLanguageEncoderTests: XCTestCase {
] ]
let encoder = NaturalLanguageEncoder<Double>() let encoder = NaturalLanguageEncoder<Double>()
var corpus = DictionaryCorpus(encoder: encoder) var corpus = DictionaryCorpus<Double,NaturalLanguageEncoder,String>(encoder: encoder)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.encodedDocuments.count == 3) XCTAssert(corpus.encodedDocuments.count == 3)
...@@ -53,7 +53,7 @@ final class NaturalLanguageEncoderTests: XCTestCase { ...@@ -53,7 +53,7 @@ final class NaturalLanguageEncoderTests: XCTestCase {
] ]
let encoder = NaturalLanguageEncoder<Double>() let encoder = NaturalLanguageEncoder<Double>()
var corpus = DictionaryCorpus(encoder: encoder) var corpus = DictionaryCorpus<Double,NaturalLanguageEncoder,String>(encoder: encoder)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
...@@ -77,7 +77,7 @@ final class NaturalLanguageEncoderTests: XCTestCase { ...@@ -77,7 +77,7 @@ final class NaturalLanguageEncoderTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
let encoder = NaturalLanguageEncoder<Double>() let encoder = NaturalLanguageEncoder<Double>()
var corpus = DictionaryCorpus(encoder: encoder) let corpus = DictionaryCorpus<Double,NaturalLanguageEncoder,String>(encoder: encoder)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment