Skip to content
Snippets Groups Projects
Commit 605f7730 authored by Jim Wallace's avatar Jim Wallace
Browse files

SNLPCorpus generic over SNLPEncoder

parent 8473b4d2
No related branches found
No related tags found
No related merge requests found
Pipeline #108991 passed
...@@ -26,11 +26,8 @@ import Foundation ...@@ -26,11 +26,8 @@ import Foundation
protocol SNLPCorpus: Collection { protocol SNLPCorpus: Collection {
associatedtype Scalar: BinaryFloatingPoint & Codable associatedtype Scalar: BinaryFloatingPoint & Codable
associatedtype Encoder: SNLPEncoder where Encoder.Scalar == Scalar
var _documentEncoder: ContextFreeEncoder<Scalar> { get set }
var zeroes: [Scalar] { get } var zeroes: [Scalar] { get }
var count: Int { get } var count: Int { get }
mutating func addUntokenizedDocument(_ document: String) mutating func addUntokenizedDocument(_ document: String)
...@@ -38,10 +35,6 @@ protocol SNLPCorpus: Collection { ...@@ -38,10 +35,6 @@ protocol SNLPCorpus: Collection {
} }
extension SNLPCorpus { extension SNLPCorpus {
/**
Default implementation -- just ask the DocumentEnder what a zero looks like, and return that
*/
var zeroes: [Scalar] { _documentEncoder.zeroes }
/** /**
Adds a series of untokenized documents to the corpus, using default tokenization and text processing Adds a series of untokenized documents to the corpus, using default tokenization and text processing
......
...@@ -24,12 +24,9 @@ ...@@ -24,12 +24,9 @@
import Foundation import Foundation
class DictionaryCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { class DictionaryCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
typealias Scalar = Scalar
typealias Encoder = ContextFreeEncoder<Scalar>
var _documentEncoder: ContextFreeEncoder<Scalar> private var _documentEncoder: any SNLPEncoder
var zeroes: [Scalar] { _documentEncoder.zeroes as! [Scalar] }
var encodedDocuments: [Int : [Scalar] ] = [:] var encodedDocuments: [Int : [Scalar] ] = [:]
var count: Int { encodedDocuments.count } var count: Int { encodedDocuments.count }
...@@ -39,9 +36,13 @@ class DictionaryCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { ...@@ -39,9 +36,13 @@ class DictionaryCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
_documentEncoder = ContextFreeEncoder(source: encoding) _documentEncoder = ContextFreeEncoder(source: encoding)
} }
init(encoder: any SNLPEncoder, scalar: Scalar.Type = Double.self) {
_documentEncoder = encoder
}
@inlinable @inlinable
func addUntokenizedDocument(_ document: String) { func addUntokenizedDocument(_ document: String) {
encodedDocuments[ encodedDocuments.count ] = _documentEncoder.encodeSentence(document) encodedDocuments[ encodedDocuments.count ] = (_documentEncoder.encodeSentence(document) as! [Scalar])
} }
} }
...@@ -14,7 +14,8 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -14,7 +14,8 @@ final class ContextFreeEncoderTests: XCTestCase {
"that enable us to train deep learning algorithms to learn like the human brain." "that enable us to train deep learning algorithms to learn like the human brain."
] ]
var corpus = DictionaryCorpus(encoding: .glove6B50d) var encoder = NaturalLanguageEncoder()
var corpus = DictionaryCorpus(encoder)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.encodedDocuments.count == 3) XCTAssert(corpus.encodedDocuments.count == 3)
...@@ -51,7 +52,8 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -51,7 +52,8 @@ final class ContextFreeEncoderTests: XCTestCase {
"All science is either physics or stamp collecting. - Ernest Rutherford" "All science is either physics or stamp collecting. - Ernest Rutherford"
] ]
var corpus = DictionaryCorpus(encoding: .glove6B50d) var encoder = NaturalLanguageEncoder()
var corpus = DictionaryCorpus(encoder)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
...@@ -74,7 +76,9 @@ final class ContextFreeEncoderTests: XCTestCase { ...@@ -74,7 +76,9 @@ final class ContextFreeEncoderTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
var corpus = DictionaryCorpus(encoding: .glove6B50d) var encoder = NaturalLanguageEncoder()
var corpus = DictionaryCorpus(encoder)
for submission in submissions { for submission in submissions {
if let text = submission.selftext { if let text = submission.selftext {
corpus.addUntokenizedDocument(text) corpus.addUntokenizedDocument(text)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment