From 37a7354e9c1fb1a95f3ee273c38586ebf3412ce6 Mon Sep 17 00:00:00 2001
From: Mingchung Xia <mingchung.xia@gmail.com>
Date: Thu, 14 Mar 2024 17:05:48 -0400
Subject: [PATCH] Metric modifications with Surge

---
 Package.resolved                              |  9 ++++
 Package.swift                                 |  4 +-
 .../HNSW/CartesianDistanceMetric.swift        |  3 ++
 .../HNSW/CosineSimilarityMetric.swift         | 44 +++++++++++++------
 4 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/Package.resolved b/Package.resolved
index 6f78ef27..5f2e68ff 100644
--- a/Package.resolved
+++ b/Package.resolved
@@ -36,6 +36,15 @@
         "version" : "0.1.14"
       }
     },
+    {
+      "identity" : "surge",
+      "kind" : "remoteSourceControl",
+      "location" : "https://github.com/Jounce/Surge.git",
+      "state" : {
+        "revision" : "6e4a47e63da8801afe6188cf039e9f04eb577721",
+        "version" : "2.3.2"
+      }
+    },
     {
       "identity" : "swift-numerics",
       "kind" : "remoteSourceControl",
diff --git a/Package.swift b/Package.swift
index a41369c8..98c5dd1c 100644
--- a/Package.swift
+++ b/Package.swift
@@ -16,7 +16,8 @@ let package = Package(
     dependencies: [
         //.package(url: "https://github.com/jbadger3/SwiftAnnoy", .upToNextMajor(from: "1.0.0")),
         .package(url: "https://github.com/L1MeN9Yu/Elva", .upToNextMajor(from: "2.1.3")),
-        .package(url: "https://github.com/JadenGeller/similarity-topology", .upToNextMajor(from: "0.1.14"))
+        .package(url: "https://github.com/JadenGeller/similarity-topology", .upToNextMajor(from: "0.1.14")),
+        .package(url: "https://github.com/Jounce/Surge.git", .upToNextMajor(from: "2.0.0")),
     ],
     targets: [
         .target(
@@ -27,6 +28,7 @@ let package = Package(
                 .product(name: "HNSWEphemeral", package: "similarity-topology"),
                 .product(name: "HNSWDurable", package: "similarity-topology"),
                 .product(name: "ZSTD", package: "Elva"),
+                .product(name: "Surge", package: "Surge"),
             ],
             resources: [.process("Resources")]
         ),
diff --git a/Sources/SwiftNLP/1. Data Collection/HNSW/CartesianDistanceMetric.swift b/Sources/SwiftNLP/1. Data Collection/HNSW/CartesianDistanceMetric.swift
index 9f1bdf43..79d1c64e 100644
--- a/Sources/SwiftNLP/1. Data Collection/HNSW/CartesianDistanceMetric.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/HNSW/CartesianDistanceMetric.swift	
@@ -26,6 +26,7 @@
 
 import Foundation
 import SimilarityMetric
+import Surge
 
 public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint {
     public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
@@ -34,5 +35,7 @@ public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityM
         let squaredSum = squaredDifferences.reduce(0, +)
         
         return sqrt(squaredSum)
+        
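+//        NOTE: not equivalent as written — distSq is the squared distance (needs sqrt) and the as! casts trap unless Vector is [Double]: return Vector.Element(Surge.distSq(someItem as! [Double], otherItem as! [Double]))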
     }
 }
diff --git a/Sources/SwiftNLP/1. Data Collection/HNSW/CosineSimilarityMetric.swift b/Sources/SwiftNLP/1. Data Collection/HNSW/CosineSimilarityMetric.swift
index 23af62e3..31b7b4d0 100644
--- a/Sources/SwiftNLP/1. Data Collection/HNSW/CosineSimilarityMetric.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/HNSW/CosineSimilarityMetric.swift	
@@ -8,32 +8,50 @@
 import Foundation
 import Accelerate
 import SimilarityMetric
+import Surge
 
 // MARK: May be improved on using Surge/Nifty
 // See https://developer.apple.com/documentation/accelerate/vdsp-snv
 
 public struct CosineSimilarityMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint {
     public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
-        /// Convert vectors to Double for Accelerate functions
+//        /// Convert vectors to Double for Accelerate functions
+//        let someItemDoubles = someItem.map { Double($0) }
+//        let otherItemDoubles = otherItem.map { Double($0) }
+//        
+//        /// Calculate dot product
+//        var dotProduct: Double = 0.0
+//        vDSP_dotprD(someItemDoubles, 1, otherItemDoubles, 1, &dotProduct, vDSP_Length(someItemDoubles.count))
+//        
+//        /// Calculate magnitude of vectors
+//        var someItemMagnitudeSquared: Double = 0.0
+//        var otherItemMagnitudeSquared: Double = 0.0
+//        vDSP_svesqD(someItemDoubles, 1, &someItemMagnitudeSquared, vDSP_Length(someItemDoubles.count))
+//        vDSP_svesqD(otherItemDoubles, 1, &otherItemMagnitudeSquared, vDSP_Length(otherItemDoubles.count))
+//        let someItemMagnitude = sqrt(someItemMagnitudeSquared)
+//        let otherItemMagnitude = sqrt(otherItemMagnitudeSquared)
+//        
+//        /// Calculate the cosine similarity
+//        let cosineSimilarity = dotProduct / (someItemMagnitude * otherItemMagnitude)
+//        
+//        /// Convert back to type Vector.Element
+//        return Vector.Element(cosineSimilarity)
+        
+        // Convert vectors to arrays of Double
         let someItemDoubles = someItem.map { Double($0) }
         let otherItemDoubles = otherItem.map { Double($0) }
         
-        /// Calculate dot product
-        var dotProduct: Double = 0.0
-        vDSP_dotprD(someItemDoubles, 1, otherItemDoubles, 1, &dotProduct, vDSP_Length(someItemDoubles.count))
+        // Calculate dot product using Surge for cosine similarity numerator
+        let dotProduct = Surge.dot(someItemDoubles, otherItemDoubles)
         
-        /// Calculate magnitude of vectors
-        var someItemMagnitudeSquared: Double = 0.0
-        var otherItemMagnitudeSquared: Double = 0.0
-        vDSP_svesqD(someItemDoubles, 1, &someItemMagnitudeSquared, vDSP_Length(someItemDoubles.count))
-        vDSP_svesqD(otherItemDoubles, 1, &otherItemMagnitudeSquared, vDSP_Length(otherItemDoubles.count))
-        let someItemMagnitude = sqrt(someItemMagnitudeSquared)
-        let otherItemMagnitude = sqrt(otherItemMagnitudeSquared)
+        // Compute each vector's Euclidean norm as sqrt(v · v) for the denominator
+        let someItemMagnitude = sqrt(Surge.dot(someItemDoubles, someItemDoubles))
+        let otherItemMagnitude = sqrt(Surge.dot(otherItemDoubles, otherItemDoubles))
         
-        /// Calculate the cosine similarity
+        // Calculate cosine similarity (NaN if either vector has zero magnitude)
         let cosineSimilarity = dotProduct / (someItemMagnitude * otherItemMagnitude)
         
-        /// Convert back to type Vector.Element
+        // Convert back to type Vector.Element
         return Vector.Element(cosineSimilarity)
     }
 }
-- 
GitLab