Skip to content
Snippets Groups Projects
Commit 2330a8cd authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Added cosine similarity

parent 7d54bb16
No related branches found
No related tags found
1 merge request!13HNSW Implementation with Testcases
Pipeline #114101 failed
//
// CosineSimilarityMetric.swift
//
//
// Created by Mingchung Xia on 2024-03-14.
//
import Foundation
import Accelerate
import SimilarityMetric
// MARK: May be improved on using Surge/Nifty
// See https://developer.apple.com/documentation/accelerate/vdsp-snv
public struct CosineSimilarityMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint {
public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
/// Convert vectors to Double for Accelerate functions
let someItemDoubles = someItem.map { Double($0) }
let otherItemDoubles = otherItem.map { Double($0) }
/// Calculate dot product
var dotProduct: Double = 0.0
vDSP_dotprD(someItemDoubles, 1, otherItemDoubles, 1, &dotProduct, vDSP_Length(someItemDoubles.count))
/// Calculate magnitude of vectors
var someItemMagnitudeSquared: Double = 0.0
var otherItemMagnitudeSquared: Double = 0.0
vDSP_svesqD(someItemDoubles, 1, &someItemMagnitudeSquared, vDSP_Length(someItemDoubles.count))
vDSP_svesqD(otherItemDoubles, 1, &otherItemMagnitudeSquared, vDSP_Length(otherItemDoubles.count))
let someItemMagnitude = sqrt(someItemMagnitudeSquared)
let otherItemMagnitude = sqrt(otherItemMagnitudeSquared)
/// Calculate the cosine similarity
let cosineSimilarity = dotProduct / (someItemMagnitude * otherItemMagnitude)
/// Convert back to type Vector.Element
return Vector.Element(cosineSimilarity)
}
}
...@@ -41,7 +41,8 @@ extension DurableVectorIndex { ...@@ -41,7 +41,8 @@ extension DurableVectorIndex {
public struct DeterministicDurableVectorIndex/*<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint*/ { public struct DeterministicDurableVectorIndex/*<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint*/ {
public typealias Vector = [Double] public typealias Vector = [Double]
public typealias Index = DurableVectorIndex<CartesianDistanceMetric<Vector>, Vector.Element> // public typealias Index = DurableVectorIndex<CartesianDistanceMetric<Vector>, Vector.Element>
public typealias Index = DurableVectorIndex<CosineSimilarityMetric<Vector>, Vector.Element>
public var base: Index public var base: Index
public var typicalNeighborhoodSize: Int public var typicalNeighborhoodSize: Int
public var size: Int = 0 public var size: Int = 0
...@@ -50,7 +51,8 @@ public struct DeterministicDurableVectorIndex/*<Vector: Collection & Codable> wh ...@@ -50,7 +51,8 @@ public struct DeterministicDurableVectorIndex/*<Vector: Collection & Codable> wh
// private var drng = DeterministicRandomNumberGenerator(seed: 1) // private var drng = DeterministicRandomNumberGenerator(seed: 1)
public init(namespace: String, typicalNeighborhoodSize: Int = 20, in transaction: Transaction) throws { public init(namespace: String, typicalNeighborhoodSize: Int = 20, in transaction: Transaction) throws {
let metric = CartesianDistanceMetric<Vector>() // let metric = CartesianDistanceMetric<Vector>()
let metric = CosineSimilarityMetric<Vector>()
let config = Config.unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize) let config = Config.unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize)
self.base = try Index( self.base = try Index(
namespace: namespace, namespace: namespace,
......
...@@ -283,17 +283,17 @@ final class HNSWTests: XCTestCase { ...@@ -283,17 +283,17 @@ final class HNSWTests: XCTestCase {
let transaction = try Transaction.begin(.write, in: env) let transaction = try Transaction.begin(.write, in: env)
/// Saving the memory map to disk /// Saving the memory map to disk
// var corpus = try DurableHNSWCorpus( var corpus = try DurableHNSWCorpus(
// encoder: _documentEncoder, encoder: _documentEncoder,
// namespace: "testbasicqueryexampledurable", namespace: "testbasicqueryexampledurable",
// in: transaction in: transaction
// ) )
//
// for doc in docs { for doc in docs {
// try corpus.addUntokenizedDocument(doc, in: transaction) try corpus.addUntokenizedDocument(doc, in: transaction)
// } }
//
// corpus.saveDictionaryToDownloads(fileName: "dictionary.mmap") corpus.saveDictionaryToDownloads(fileName: "dictionary.mmap")
try transaction.commit() try transaction.commit()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment