From 67a7b7167245a4a28b8c97e0eef237e1433f7496 Mon Sep 17 00:00:00 2001
From: a252jain <a252jain@uwaterloo.ca>
Date: Fri, 5 Apr 2024 09:28:38 -0400
Subject: [PATCH] fix build script

---
 .gitlab-ci.yml                                |  6 +++
 .../SwiftNLP/2. Encoding/CoreMLEncoder.swift  |  7 ++++
 Sources/SwiftNLPGenericLLMMacros/Macros.swift | 39 +++++++++++++++++++
 Sources/SwiftNLPGenericLLMMacros/Main.swift   |  3 +-
 4 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 9bca63ec..0d4e1a66 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -24,6 +24,12 @@ build-macOS:
 test-macOS:
   stage: test
   script:
+     - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ Sources/SwiftNLP/Models
+     - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ --language Swift Sources/SwiftNLP/Resources
+     - mv Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.swift Sources/SwiftNLP/2.\ Encoding
+     - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/float32_model.mlpackage/ Sources/SwiftNLP/Models
+     - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/float32_model.mlpackage/ --language Swift Sources/SwiftNLP/Resources
+     - mv Sources/SwiftNLP/Resources/float32_model.swift Sources/SwiftNLP/2.\ Encoding
      - swift test -c release -Xswiftc -enable-testing
 #    - swift test --sanitize=address -c release -Xswiftc -enable-testing
 #    - swift test --sanitize=thread -c release -Xswiftc -enable-testing
diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
index 18e7033a..83988180 100644
--- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
+++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift	
@@ -43,6 +43,11 @@ public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ ou
     module: "SwiftNLPGenericLLMMacros",
     type: "LLMModelPredictionCases")
 
+@freestanding(expression)
+public macro MODEL_VALIDATE_NAME() = #externalMacro(
+    module: "SwiftNLPGenericLLMMacros",
+    type: "LLMModelNameValidation")
+
 
 class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
     
@@ -77,6 +82,8 @@ public class MiniLMEmbeddings {
 
         self.model = model_type;
         self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
+        
+        #MODEL_VALIDATE_NAME()
     }
 
      // MARK: - Dense Embeddings
diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
index 2d0f6988..73d19fc6 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
@@ -3,6 +3,45 @@ import SwiftSyntax
 import SwiftSyntaxMacros
 
 
+@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
+public struct LLMModelNameValidation: ExpressionMacro {
+    /**
+     Example expansion:
+     let valid_models = ["all_MiniLM_L6_v2", "gte-small"];
+     if !valid_models.contains(self.model) {
+         throw fatalError("Model is not valid.");
+     }
+     */
+    
+    public static func expansion(
+        of node:  some FreestandingMacroExpansionSyntax,
+        in context: some MacroExpansionContext
+    ) throws -> ExprSyntax {
+        
+        var macro = "let valid_models = [";
+        var index = 0;
+        
+        for (k, v) in LLM_MODEL_CLASSES {
+            macro += "\"\(k)\"";
+            index += 1;
+            if index < LLM_MODEL_CLASSES.count {
+                macro += ", ";
+            }
+        }
+        macro += "];";
+        
+        return ExprSyntax(stringLiteral:
+            """
+            \(macro)
+            if !valid_models.contains(self.model) {
+                throw fatalError("Model is not valid.");
+            }
+            """
+        )
+    }
+}
+
+
 @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
 public struct LLMModelPredictionCases: ExpressionMacro {
     /**
diff --git a/Sources/SwiftNLPGenericLLMMacros/Main.swift b/Sources/SwiftNLPGenericLLMMacros/Main.swift
index a4618aa2..a1a9228a 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Main.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Main.swift
@@ -6,6 +6,7 @@ struct SwiftNLPGenericLLMMacros: CompilerPlugin {
     init() {}
     var providingMacros: [SwiftSyntaxMacros.Macro.Type] = [
         LLMPredictionFunctions.self,
-        LLMModelPredictionCases.self
+        LLMModelPredictionCases.self,
+        LLMModelNameValidation.self
     ]
 }
-- 
GitLab