diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift index bcedf2f7be65aeb5f4121d31f50d8a1b280d8702..c696aedaeb885e2601851e17e63aebae5952df94 100644 --- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift +++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift @@ -44,7 +44,7 @@ public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ ou type: "LLMModelPredictionCases") @freestanding(expression) -public macro MODEL_VALIDATE_NAME() = #externalMacro( +public macro MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE() = #externalMacro( module: "SwiftNLPGenericLLMMacros", type: "LLMModelNameValidation") @@ -72,8 +72,8 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { public class MiniLMEmbeddings { private let model: String - public let tokenizer: BertTokenizer - public let inputDimention: Int = 512 + public var tokenizer: BertTokenizer + public var inputDimention: Int = 512 public let outputDimention: Int = 384 public init(model_type: String) { @@ -83,24 +83,20 @@ public class MiniLMEmbeddings { self.model = model_type; self.tokenizer = BertTokenizer(maxLen: self.inputDimention) - #MODEL_VALIDATE_NAME() + #MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE() } // MARK: - Dense Embeddings public func encode(sentence: String) async -> [Float]? { - // Encode input text as bert tokens - let inputTokens = tokenizer.buildModelTokens(sentence: sentence) - let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens) - - print(inputIds.count, attentionMask.count) - - // Send tokens through the MLModel - let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask) + self.tokenizer = BertTokenizer(maxLen: self.inputDimention) + // Encode input text as bert tokens + let inputTokens = tokenizer.buildModelTokens(sentence: sentence) + let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens) - print(inputIds.count, attentionMask.count) + let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask) - return embeddings + return embeddings } public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? { diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift index dd7e5fb71daa316981f50b9a5037ab4074364b49..a5360e08abb3bf001336664f364896e2500a297f 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift @@ -7,10 +7,20 @@ import SwiftSyntaxMacros public struct LLMModelNameValidation: ExpressionMacro { /** Example expansion: - let valid_models = ["all_MiniLM_L6_v2", "gte-small"]; - if !valid_models.contains(self.model) { - throw fatalError("Model is not valid."); - } + try! { + let valid_models = ["gte-small", "all_MiniLM_L6_v2"]; + if !valid_models.contains(self.model) { + throw fatalError("Model is not valid."); + } + switch self.model { + case "gte-small": + self.inputDimention = 128; + case "all_MiniLM_L6_v2": + self.inputDimention = 512; + default: + self.inputDimention = 128; + } + }(); */ public static func expansion( @@ -18,26 +28,20 @@ public struct LLMModelNameValidation: ExpressionMacro { in context: some MacroExpansionContext ) throws -> ExprSyntax { - var macro = "try! { let valid_models = ["; - var index = 0; + var macro = "try! { switch self.model { " for (k, v) in LLM_MODEL_CLASSES { - macro += "\"\(k)\""; - index += 1; - if index < LLM_MODEL_CLASSES.count { - macro += ", "; - } + let model_dim = v[LLMModelClassesKey.InputDimension]! + macro += + """ + case \"\(k)\": + self.inputDimention = \(model_dim); + """ } - macro += "];"; - return ExprSyntax(stringLiteral: - """ - \(macro) - if !valid_models.contains(self.model) { - throw fatalError("Model is not valid."); - } }(); - """ - ) + macro += "default: throw fatalError(\"Model is not valid\"); } }();" + + return ExprSyntax(stringLiteral: macro) } } diff --git a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift index 4fc208f533a9cc30a454e14227bd13d2c73ce7b0..2eb0950cb8d12335d0fd30b06375cff9e76c5723 100644 --- a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift +++ b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift @@ -4,6 +4,7 @@ enum LLMModelClassesKey { case Model case FeatureName case URL + case InputDimension } let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [ @@ -12,13 +13,15 @@ let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [ LLMModelClassesKey.Output: all_MiniLM_L6_v2Output.self, LLMModelClassesKey.Model: all_MiniLM_L6_v2.self, LLMModelClassesKey.FeatureName: "embeddings", - LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc" + LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc", + LLMModelClassesKey.InputDimension: 512 ], "gte-small": [ LLMModelClassesKey.Input: float32_modelInput.self, LLMModelClassesKey.Output: float32_modelOutput.self, LLMModelClassesKey.Model: float32_model.self, LLMModelClassesKey.FeatureName: "pooler_output", - LLMModelClassesKey.URL: "float32_model.mlmodelc" + LLMModelClassesKey.URL: "float32_model.mlmodelc", + LLMModelClassesKey.InputDimension: 128 ] ]