From 91a131bbca9ea35ed6d8c6e1bad6e4c27352bdad Mon Sep 17 00:00:00 2001
From: Henry Tian <yuyang.tian@uwaterloo.ca>
Date: Fri, 22 Dec 2023 14:32:29 -0500
Subject: [PATCH] Upload New File

---
 .../2. Embeddings/MiniLMAllEmbeddings.swift   | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift

diff --git a/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift b/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift
new file mode 100644
index 00000000..4a98e47e
--- /dev/null
+++ b/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift	
@@ -0,0 +1,55 @@
+import Foundation
+import CoreML
+
+
+@available(macOS 12.0, iOS 15.0, *)
+public class MiniLMEmbeddings {
+    public let model: all_MiniLM_L6_v2
+    public let tokenizer: BertTokenizer
+    public let inputDimention: Int = 512
+    public let outputDimention: Int = 384
+
+    public init() {
+        let modelConfig = MLModelConfiguration()
+        modelConfig.computeUnits = .all
+
+        do {
+            self.model = try all_MiniLM_L6_v2(configuration: modelConfig)
+        } catch {
+            fatalError("Failed to load the Core ML model. Error: \(error.localizedDescription)")
+        }
+
+        self.tokenizer = BertTokenizer()
+    }
+
+    // MARK: - Dense Embeddings
+
+    public func encode(sentence: String) async -> [Float]? {
+        // Encode input text as bert tokens
+        let inputTokens = tokenizer.buildModelTokens(sentence: sentence)
+        let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens)
+
+        // Send tokens through the MLModel
+        let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask)
+
+        return embeddings
+    }
+
+public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? {
+    let inputFeatures = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask)
+    
+    let output = try? model.prediction(input: inputFeatures)
+    guard let embeddings = output?.embeddings else {
+        return nil
+    }
+    
+    var embeddingsArray = [Float]()
+    for index in 0..<embeddings.count {
+        let value = embeddings[index].floatValue
+        embeddingsArray.append(Float(value))
+    }
+    
+    return embeddingsArray
+}
+
+}
-- 
GitLab