From 9b9fb1ec9662293ede4a220e3201c304e2f8e9ab Mon Sep 17 00:00:00 2001
From: a252jain <a252jain@uwaterloo.ca>
Date: Fri, 29 Mar 2024 20:05:43 -0400
Subject: [PATCH] Update MODEL_MAKE_PREDICTION macro to take explicit input, attention-mask, and output argument names

---
 .../SwiftNLP/2. Encoding/CoreMLEncoder.swift  |  6 +-
 Sources/SwiftNLPGenericLLMMacros/Macros.swift | 60 ++++++++++++-------
 2 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
index ab8c2c1a..18e7033a 100644
--- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
+++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
@@ -39,7 +39,7 @@ import CoreML
 
 
 @freestanding(expression)
-public macro MODEL_MAKE_PREDICTION(_ model_type: Any) = #externalMacro(
+public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_mask: Any, _ output_name: Any) = #externalMacro(
     module: "SwiftNLPGenericLLMMacros",
     type: "LLMModelPredictionCases")
 
@@ -97,11 +97,9 @@ public class MiniLMEmbeddings {
     }
 
     public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? {
-        // let input_class: () = #MODEL_INPUT("input_ids: inputIds, attention_mask: attentionMask")
-        
         var output: MLMultiArray? = nil
         
-        #MODEL_MAKE_PREDICTION("input_ids: inputIds, attention_mask: attentionMask")
+        #MODEL_MAKE_PREDICTION("inputIds", "attentionMask", "output")
         
         if (output === nil) {
             return nil;
diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
index 189c32d5..14b802ec 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
@@ -5,33 +5,54 @@ import SwiftSyntaxMacros
 
 @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
 public struct LLMModelPredictionCases: ExpressionMacro {
-    //        Example expansion:
-    //        {
-    //            switch self.model {
-    //            case "all_MiniLM_L6_v2":
-    //                let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask);
-    //                output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!;
-    //            case "float32_model":
-    //                let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask);
-    //                output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!;
-    //            default:
-    //                output = nil;
-    //            }
-    //        }();
+    /**
+    Example expansion:
+     {
+         switch self.model {
+         case "all_MiniLM_L6_v2":
+             let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask);
+             output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!;
+         case "float32_model":
+             let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask);
+             output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!;
+         default:
+             output = nil;
+         }
+     }();
+     */
+    
     
     public static func expansion(
         of node:  some FreestandingMacroExpansionSyntax,
         in context: some MacroExpansionContext
     ) throws -> ExprSyntax {
-        guard let arg = node.argumentList.first?.expression,
-            let segments = arg.as(StringLiteralExprSyntax.self)?.segments,
+        guard let input_arg = node.argumentList.first?.expression,
+            let segments = input_arg.as(StringLiteralExprSyntax.self)?.segments,
             segments.count == 1,
-            case .stringSegment(let literalSegment)? = segments.first
+            case .stringSegment(let input_literal_segment)? = segments.first
         else {
             throw fatalError("Bad argument to macro.")
         }
         
-        let model_key = literalSegment.content.text
+        guard let attention_arg = node.argumentList.dropFirst().first?.expression,
+            let attn_segments = attention_arg.as(StringLiteralExprSyntax.self)?.segments,
+            attn_segments.count == 1,
+            case .stringSegment(let attn_literal_segment)? = attn_segments.first
+        else {
+            throw fatalError("Bad argument to macro.")
+        }
+        
+        guard let output_arg = node.argumentList.dropFirst().dropFirst().first?.expression,
+            let output_segments = output_arg.as(StringLiteralExprSyntax.self)?.segments,
+            output_segments.count == 1,
+            case .stringSegment(let output_literal_segment)? = output_segments.first
+        else {
+            throw fatalError("Bad argument to macro.")
+        }
+        
+        let model_input = input_literal_segment.content.text
+        let model_attn = attn_literal_segment.content.text
+        let model_output = output_literal_segment.content.text
         
         var macro = "{ switch self.model { "
         
@@ -41,8 +62,8 @@ public struct LLMModelPredictionCases: ExpressionMacro {
             macro +=
                 """
                 case \"\(k)\":
-                    let input_class = \(model_class)(\(model_key));
-                    output = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!;
+                    let input_class = \(model_class)(input_ids: \(model_input), attention_mask: \(model_attn));
+                    \(model_output) = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!;
                 """
         }
         
@@ -75,7 +96,6 @@ public enum LLMPredictionFunctions: DeclarationMacro {
         let model_input_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Input]!;
         let model_output_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Output]!;
 
-    
         return [
             """
             public static func prediction(input: \(raw: model_input_name)) throws -> \(raw: model_output_name) {
-- 
GitLab