diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 9bca63ec8fcb1b8404a9d5ba4608726d282a9814..0d4e1a66e684f59b820bc86e3f4ebf1880e87e50 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -24,6 +24,12 @@ build-macOS:
 test-macOS:
   stage: test
   script:
+    - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ Sources/SwiftNLP/Models
+    - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.mlpackage/ --language Swift Sources/SwiftNLP/Resources
+    - mv Sources/SwiftNLP/Resources/all-MiniLM-L6-v2.swift Sources/SwiftNLP/2.\ Encoding
+    - xcrun coremlcompiler compile Sources/SwiftNLP/Resources/float32_model.mlpackage/ Sources/SwiftNLP/Models
+    - xcrun coremlcompiler generate Sources/SwiftNLP/Resources/float32_model.mlpackage/ --language Swift Sources/SwiftNLP/Resources
+    - mv Sources/SwiftNLP/Resources/float32_model.swift Sources/SwiftNLP/2.\ Encoding
     - swift test -c release -Xswiftc -enable-testing
     # - swift test --sanitize=address -c release -Xswiftc -enable-testing
     # - swift test --sanitize=thread -c release -Xswiftc -enable-testing
diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
index 18e7033a801a4f6dc0123f846f981d80c1246b40..83988180fee2144af770e3511838d15ed0aee305 100644
--- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
+++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift
@@ -43,6 +43,11 @@ public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ ou
     module: "SwiftNLPGenericLLMMacros",
     type: "LLMModelPredictionCases")
 
+@freestanding(expression)
+public macro MODEL_VALIDATE_NAME() = #externalMacro(
+    module: "SwiftNLPGenericLLMMacros",
+    type: "LLMModelNameValidation")
+
 
 class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
 
@@ -77,6 +82,8 @@ public class MiniLMEmbeddings {
 
         self.model = model_type;
         self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
+
+        #MODEL_VALIDATE_NAME()
     }
 
     // MARK: - Dense Embeddings
diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
index 2d0f6988d617ed30788e1100ab85b57dc0b38498..73d19fc69016fa2585bc951803cd775426b0c41c 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
@@ -3,6 +3,50 @@ import SwiftSyntax
 import SwiftSyntaxMacros
 
 
+@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
+public struct LLMModelNameValidation: ExpressionMacro {
+    /**
+     Expands `#MODEL_VALIDATE_NAME()` into one expression that traps when
+     `self.model` is not a model name listed in `LLM_MODEL_CLASSES`.
+
+     An expression macro must expand to a single expression, so the check is
+     wrapped in an immediately-invoked closure.
+
+     Example expansion:
+     {
+         let valid_models = ["all_MiniLM_L6_v2", "gte-small"];
+         if !valid_models.contains(self.model) {
+             fatalError("Model is not valid.");
+         }
+     }()
+     */
+
+    public static func expansion(
+        of node: some FreestandingMacroExpansionSyntax,
+        in context: some MacroExpansionContext
+    ) throws -> ExprSyntax {
+
+        // Quote each model name, then sort so the generated source does not
+        // depend on the iteration order of LLM_MODEL_CLASSES.
+        let names = LLM_MODEL_CLASSES
+            .map { "\"\($0.0)\"" }
+            .sorted()
+            .joined(separator: ", ");
+
+        return ExprSyntax(stringLiteral:
+            """
+            {
+                let valid_models = [\(names)];
+                if !valid_models.contains(self.model) {
+                    fatalError("Model is not valid.");
+                }
+            }()
+            """
+        )
+    }
+}
+
+
 @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
 public struct LLMModelPredictionCases: ExpressionMacro {
     /**
diff --git a/Sources/SwiftNLPGenericLLMMacros/Main.swift b/Sources/SwiftNLPGenericLLMMacros/Main.swift
index a4618aa221b937eb1d15cdc212debe19c6a535b1..a1a9228af8c740e5644f388e9e93a9389c69b780 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Main.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Main.swift
@@ -6,6 +6,7 @@ struct SwiftNLPGenericLLMMacros: CompilerPlugin {
     init() {}
     var providingMacros: [SwiftSyntaxMacros.Macro.Type] = [
         LLMPredictionFunctions.self,
-        LLMModelPredictionCases.self
+        LLMModelPredictionCases.self,
+        LLMModelNameValidation.self
     ]
 }