Skip to content
Snippets Groups Projects
Commit a0dd924c authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Reverted ContextFreeEncoder changes

parent d660d689
No related branches found
No related tags found
1 merge request!13HNSW Implementation with Testcases
......@@ -134,95 +134,4 @@ extension ContextFreeEncoder {
return nil
}
}
static func writeHNSWToFile(url: URL, hnsw: DeterministicSampleVectorIndex<[Scalar]>) {
let fileManager = FileManager.default
if !fileManager.fileExists(atPath: url.path) {
fileManager.createFile(atPath: url.path, contents: nil, attributes: nil)
}
do {
let fileHandle = try FileHandle(forWritingTo: url)
let count = hnsw.base.vectors.count
let countData = withUnsafeBytes(of: count) { Data($0) }
fileHandle.write(countData)
for vector in hnsw.base.vectors {
let vectorData = vector.map { Float($0) }.withUnsafeBufferPointer { Data(buffer: $0) }
fileHandle.write(vectorData)
}
fileHandle.closeFile()
} catch {
print("Error writing HNSW to file: \(error)")
}
}
static func readHNSWFromFile(_ url: URL) -> DeterministicSampleVectorIndex<[Scalar]> {
do {
let data = try Data(contentsOf: url, options: .alwaysMapped)
let countData = data.prefix(MemoryLayout<Int>.size)
let count: Int = countData.withUnsafeBytes { $0.load(as: Int.self) }
var index = MemoryLayout<Int>.size
var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
// while index < data.count {
// if let stringRange = data[index...].range(of: "\0".data(using: .utf8)!) {
// let keyData = data[index..<stringRange.lowerBound]
// if let key = String(data: keyData, encoding: .utf8) {
// index = stringRange.upperBound
//
// let valuesData = data[index..<(index + 50 * MemoryLayout<Scalar>.size)]
// let values = valuesData.withUnsafeBytes { Array($0.bindMemory(to: Scalar.self)) }
// hnsw.insert(values)
// }
// } else {
// break
// }
// }
for _ in 0..<count {
let vectorData = data[index..<(index + 50 * MemoryLayout<Scalar>.size)]
let vector = vectorData.withUnsafeBytes { Array($0.bindMemory(to: Scalar.self)) }
hnsw.insert(vector)
index += 50 * MemoryLayout<Scalar>.size
}
return hnsw
} catch {
print("Error reading HNSW from file: \(error)")
}
return DeterministicSampleVectorIndex<[Scalar]>()
}
static func readHNSWFromTextFile(from url: URL) -> DeterministicSampleVectorIndex<[Scalar]>? {
do {
let content = try String(contentsOf: url, encoding: .utf8)
let lines = content.split(separator: "\n")
// var data = DeterministicSampleVectorIndex<[Scalar]>()
//
// for line in lines.dropFirst() {
// let tokens = line.split(separator: " ")
// let word = String(tokens[0])
// let vector = tokens.dropFirst().compactMap { Scalar(Double($0)!) }
// data.insert(vector)
// }
//
// return data
var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
for line in lines {
let vector = line.split(separator: " ").compactMap { Scalar(Double($0)!) }
hnsw.insert(vector)
}
return hnsw
} catch {
print("Error loading vectors from text file: \(error)")
return nil
}
}
}
......@@ -26,7 +26,6 @@ import Foundation
class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
var dictionary: [String : [Scalar]]
var hnsw: DeterministicSampleVectorIndex<[Scalar]>
let width: Int
var zeroes: [Scalar]
......@@ -35,14 +34,12 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
public enum PreComputedEmbeddings {
case glove6B50d
case glove6B100d
case hnswindex
//case NLEmbedding
}
init(source: PreComputedEmbeddings) {
dictionary = Dictionary<String,[Scalar]>()
hnsw = DeterministicSampleVectorIndex<[Scalar]>()
var dictionaryToLoad: String
switch source {
......@@ -53,10 +50,6 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
case .glove6B100d:
width = 100
dictionaryToLoad = "glove.6B.100d"
case .hnswindex:
width = 50 // double check the dimension for HNSW
dictionaryToLoad = "hnswindex"
}
zeroes = Array(repeating: Scalar(0), count: width) as! [Scalar]
......@@ -69,7 +62,6 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
return
}
dictionary = ContextFreeEncoder<Scalar>.readDictionaryFromFile(url)
// hnsw = ContextFreeEncoder<Scalar>.readHNSWFromFile(url)
}
subscript(_ token: String) -> [Scalar] {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment