From 256d583ffa2b16c2bb3bad06e61957f3b9a2bc17 Mon Sep 17 00:00:00 2001
From: a252jain <a252jain@uwaterloo.ca>
Date: Fri, 5 Apr 2024 11:36:37 -0400
Subject: [PATCH] add ability to change input dimension

---
 .../SwiftNLP/2. Encoding/CoreMLEncoder.swift  | 24 +++++-----
 Sources/SwiftNLPGenericLLMMacros/Macros.swift | 44 ++++++++++---------
 .../ModelClasses.swift                        |  7 ++-
 3 files changed, 39 insertions(+), 36 deletions(-)

diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
index bcedf2f7..c696aeda 100644
--- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
+++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
@@ -44,7 +44,7 @@ public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ ou
     type: "LLMModelPredictionCases")
 
 @freestanding(expression)
-public macro MODEL_VALIDATE_NAME() = #externalMacro(
+public macro MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE() = #externalMacro(
     module: "SwiftNLPGenericLLMMacros",
     type: "LLMModelNameValidation")
 
@@ -72,8 +72,8 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
 public class MiniLMEmbeddings {
     
     private let model: String
-    public let tokenizer: BertTokenizer
-    public let inputDimention: Int = 512
+    public var tokenizer: BertTokenizer
+    public var inputDimention: Int = 512
     public let outputDimention: Int = 384
 
     public init(model_type: String) {
@@ -83,24 +83,20 @@ public class MiniLMEmbeddings {
         self.model = model_type;
         self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
         
-        #MODEL_VALIDATE_NAME()
+        #MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE()
     }
 
      // MARK: - Dense Embeddings
 
     public func encode(sentence: String) async -> [Float]? {
-         // Encode input text as bert tokens
-         let inputTokens = tokenizer.buildModelTokens(sentence: sentence)
-         let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens)
-
-         print(inputIds.count, attentionMask.count)
-
-         // Send tokens through the MLModel
-         let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask)
+        self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
+        // Encode input text as bert tokens
+        let inputTokens = tokenizer.buildModelTokens(sentence: sentence)
+        let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens)
 
-         print(inputIds.count, attentionMask.count)
+        let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask)
 
-         return embeddings
+        return embeddings
     }
 
     public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? {
diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
index dd7e5fb7..a5360e08 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
@@ -7,10 +7,20 @@ import SwiftSyntaxMacros
 public struct LLMModelNameValidation: ExpressionMacro {
     /**
      Example expansion:
-     let valid_models = ["all_MiniLM_L6_v2", "gte-small"];
-     if !valid_models.contains(self.model) {
-         throw fatalError("Model is not valid.");
-     }
+     try! {
+         let valid_models = ["gte-small", "all_MiniLM_L6_v2"];
+         if !valid_models.contains(self.model) {
+             throw fatalError("Model is not valid.");
+         }
+         switch self.model {
+         case "gte-small":
+             self.inputDimention = 128;
+         case "all_MiniLM_L6_v2":
+             self.inputDimention = 512;
+         default:
+             self.inputDimention = 128;
+         }
+     }();
      */
     
     public static func expansion(
@@ -18,26 +28,20 @@ public struct LLMModelNameValidation: ExpressionMacro {
         in context: some MacroExpansionContext
     ) throws -> ExprSyntax {
         
-        var macro = "try! { let valid_models = [";
-        var index = 0;
+        var macro = "try! { switch self.model { "
         
         for (k, v) in LLM_MODEL_CLASSES {
-            macro += "\"\(k)\"";
-            index += 1;
-            if index < LLM_MODEL_CLASSES.count {
-                macro += ", ";
-            }
+            let model_dim = v[LLMModelClassesKey.InputDimension]!
+            macro +=
+                """
+                case \"\(k)\":
+                    self.inputDimention = \(model_dim);
+                """
         }
-        macro += "];";
         
-        return ExprSyntax(stringLiteral:
-            """
-            \(macro)
-            if !valid_models.contains(self.model) {
-                throw fatalError("Model is not valid.");
-            } }();
-            """
-        )
+        macro += "default: throw fatalError(\"Model is not valid\"); } }();"
+        
+        return ExprSyntax(stringLiteral: macro)
     }
 }
 
diff --git a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
index 4fc208f5..2eb0950c 100644
--- a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
@@ -4,6 +4,7 @@ enum LLMModelClassesKey {
     case Model
     case FeatureName
     case URL
+    case InputDimension
 }
 
 let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [
@@ -12,13 +13,15 @@ let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [
         LLMModelClassesKey.Output: all_MiniLM_L6_v2Output.self,
         LLMModelClassesKey.Model: all_MiniLM_L6_v2.self,
         LLMModelClassesKey.FeatureName: "embeddings",
-        LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc"
+        LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc",
+        LLMModelClassesKey.InputDimension: 512
     ],
     "gte-small": [
         LLMModelClassesKey.Input: float32_modelInput.self,
         LLMModelClassesKey.Output: float32_modelOutput.self,
         LLMModelClassesKey.Model: float32_model.self,
         LLMModelClassesKey.FeatureName: "pooler_output",
-        LLMModelClassesKey.URL: "float32_model.mlmodelc"
+        LLMModelClassesKey.URL: "float32_model.mlmodelc",
+        LLMModelClassesKey.InputDimension: 128
     ]
 ]
-- 
GitLab