From 67a7b7167245a4a28b8c97e0eef237e1433f7496 Mon Sep 17 00:00:00 2001 From: a252jain <a252jain@uwaterloo.ca> Date: Fri, 5 Apr 2024 09:28:38 -0400 Subject: [PATCH] fix build script --- .gitlab-ci.yml | 6 +++ .../SwiftNLP/2. Encoding/CoreMLEncoder.swift | 7 ++++ Sources/SwiftNLPGenericLLMMacros/Macros.swift | 39 +++++++++++++++++++ Sources/SwiftNLPGenericLLMMacros/Main.swift | 3 +- 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 9bca63ec..0d4e1a66 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -24,6 +24,12 @@ build-macOS: test-macOS: stage: test script: + - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ Sources/SwiftNLP/Models + - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ --language Swift Sources/SwiftNLP/Resources + - mv Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.swift Sources/SwiftNLP/2.\ Encoding + - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/float32_model.mlpackage/ Sources/SwiftNLP/Models + - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/float32_model.mlpackage/ --language Swift Sources/SwiftNLP/Resources + - mv Sources/SwiftNLP/Resources/float32_model.swift Sources/SwiftNLP/2.\ Encoding - swift test -c release -Xswiftc -enable-testing # - swift test --sanitize=address -c release -Xswiftc -enable-testing # - swift test --sanitize=thread -c release -Xswiftc -enable-testing diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift index 18e7033a..83988180 100644 --- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift +++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift @@ -43,6 +43,11 @@ public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ ou module: "SwiftNLPGenericLLMMacros", type: "LLMModelPredictionCases") +@freestanding(expression) +public macro MODEL_VALIDATE_NAME() = #externalMacro( + module: "SwiftNLPGenericLLMMacros", + type: "LLMModelNameValidation") + class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { @@ -77,6 +82,8 @@ public class MiniLMEmbeddings { self.model = model_type; self.tokenizer = BertTokenizer(maxLen: self.inputDimention) + + #MODEL_VALIDATE_NAME() } // MARK: - Dense Embeddings diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift index 2d0f6988..73d19fc6 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift @@ -3,6 +3,45 @@ import SwiftSyntax import SwiftSyntaxMacros +@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *) +public struct LLMModelNameValidation: ExpressionMacro { + /** + Example expansion: + let valid_models = ["all_MiniLM_L6_v2", "gte-small"]; + if !valid_models.contains(self.model) { + throw fatalError("Model is not valid."); + } + */ + + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + + var macro = "let valid_models = ["; + var index = 0; + + for (k, v) in LLM_MODEL_CLASSES { + macro += "\"\(k)\""; + index += 1; + if index < LLM_MODEL_CLASSES.count { + macro += ", "; + } + } + macro += "];"; + + return ExprSyntax(stringLiteral: + """ + \(macro) + if !valid_models.contains(self.model) { + throw fatalError("Model is not valid."); + } + """ + ) + } +} + + @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *) public struct LLMModelPredictionCases: ExpressionMacro { /** diff --git a/Sources/SwiftNLPGenericLLMMacros/Main.swift b/Sources/SwiftNLPGenericLLMMacros/Main.swift index a4618aa2..a1a9228a 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Main.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Main.swift @@ -6,6 +6,7 @@ struct SwiftNLPGenericLLMMacros: CompilerPlugin { init() {} var providingMacros: [SwiftSyntaxMacros.Macro.Type] = [ LLMPredictionFunctions.self, - LLMModelPredictionCases.self + LLMModelPredictionCases.self, + LLMModelNameValidation.self ] } -- GitLab