diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift index c696aedaeb885e2601851e17e63aebae5952df94..751bc956dc9d111390675a77dd671657b72df596 100644 --- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift +++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift @@ -26,18 +26,6 @@ import Foundation import CoreML -// To compile the model for this file: -// $ cd Sources/SwiftNLP/Resources -// $ xcrun coremlcompiler generate all-MiniLM-L6-v2.mlpackage/ --language Swift . -// $ cd ../../.. - -// NEXT: -// get model to work by moving mlmodelc to resources -// finish tests -// make a generic interface based on autogenerated code -// get another coreml available model and test both - - @freestanding(expression) public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ output_name: Any) = #externalMacro( module: "SwiftNLPGenericLLMMacros", @@ -55,13 +43,13 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { var model: String func encodeToken(_ token: String) -> [Scalar] { - let tokenization = MiniLMEmbeddings(model_type: self.model).tokenizer.tokenizeToIds(text: token) as! [Scalar] + let tokenization = LLMEmbeddings(model_type: self.model).tokenizer.tokenizeToIds(text: token) as! [Scalar] return tokenization } func encodeSentence(_ sentence: String) -> [Scalar] { let encoding = Task { - await MiniLMEmbeddings(model_type: self.model).encode(sentence: sentence) + await LLMEmbeddings(model_type: self.model).encode(sentence: sentence) } as! [Scalar] return encoding } @@ -69,11 +57,11 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder { @available(macOS 13.0, *) -public class MiniLMEmbeddings { +public class LLMEmbeddings { private let model: String public var tokenizer: BertTokenizer - public var inputDimention: Int = 512 + public var inputDimention: Int = 512 // 512 is a dummy value, correct value is set by the macro below public let outputDimention: Int = 384 public init(model_type: String) { @@ -81,15 +69,18 @@ public class MiniLMEmbeddings { modelConfig.computeUnits = .all self.model = model_type; + + // dummy initialization needed here to avoid compilation error self.tokenizer = BertTokenizer(maxLen: self.inputDimention) + // validate the model type is valid and set the correct input dimension #MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE() + + // reinitialize with correct input size + self.tokenizer = BertTokenizer(maxLen: self.inputDimention) } - // MARK: - Dense Embeddings - public func encode(sentence: String) async -> [Float]? { - self.tokenizer = BertTokenizer(maxLen: self.inputDimention) // Encode input text as bert tokens let inputTokens = tokenizer.buildModelTokens(sentence: sentence) let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens) @@ -102,6 +93,7 @@ public class MiniLMEmbeddings { public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? { var output: MLMultiArray? = nil + // determine which model to use and generate predictions #MODEL_MAKE_PREDICTION("inputIds", "attentionMask", "output") if (output === nil) { diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift index a5360e08abb3bf001336664f364896e2500a297f..b756250edf1afa6c71a8e6f3cdcf8a22505c9179 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift @@ -8,17 +8,13 @@ public struct LLMModelNameValidation: ExpressionMacro { /** Example expansion: try! { - let valid_models = ["gte-small", "all_MiniLM_L6_v2"]; - if !valid_models.contains(self.model) { - throw fatalError("Model is not valid."); - } switch self.model { case "gte-small": self.inputDimention = 128; case "all_MiniLM_L6_v2": self.inputDimention = 512; default: - self.inputDimention = 128; + throw fatalError("Model is not valid"); } }(); */ @@ -28,9 +24,11 @@ public struct LLMModelNameValidation: ExpressionMacro { in context: some MacroExpansionContext ) throws -> ExprSyntax { + // generate code var macro = "try! { switch self.model { " for (k, v) in LLM_MODEL_CLASSES { + // extract values let model_dim = v[LLMModelClassesKey.InputDimension]! macro += """ @@ -49,7 +47,7 @@ public struct LLMModelNameValidation: ExpressionMacro { @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *) public struct LLMModelPredictionCases: ExpressionMacro { /** - Example expansion: + Example expansion: { switch self.model { case "all_MiniLM_L6_v2": @@ -64,11 +62,12 @@ public struct LLMModelPredictionCases: ExpressionMacro { }(); */ - public static func expansion( of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext ) throws -> ExprSyntax { + + // get first argument from macro invocation guard let input_arg = node.argumentList.first?.expression, let segments = input_arg.as(StringLiteralExprSyntax.self)?.segments, segments.count == 1, @@ -77,6 +76,7 @@ public struct LLMModelPredictionCases: ExpressionMacro { throw fatalError("Bad argument to macro.") } + // get second argument from macro invocation guard let attention_arg = node.argumentList.dropFirst().first?.expression, let segments = attention_arg.as(StringLiteralExprSyntax.self)?.segments, segments.count == 1, @@ -85,6 +85,7 @@ public struct LLMModelPredictionCases: ExpressionMacro { throw fatalError("Bad argument to macro.") } + // get third argument from macro invocation guard let output_arg = node.argumentList.dropFirst().dropFirst().first?.expression, let segments = output_arg.as(StringLiteralExprSyntax.self)?.segments, segments.count == 1, @@ -93,10 +94,12 @@ public struct LLMModelPredictionCases: ExpressionMacro { throw fatalError("Bad argument to macro.") } + // extract parameter values let model_input = input_literal_segment.content.text let model_attn = attn_literal_segment.content.text let model_output = output_literal_segment.content.text + // generate code var macro = "{ switch self.model { " for (k, v) in LLM_MODEL_CLASSES { @@ -127,7 +130,10 @@ public enum LLMPredictionFunctions: DeclarationMacro { var ret: String = ""; + // generate code for (k, v) in LLM_MODEL_CLASSES { + + // extract values let model_type_name = v[LLMModelClassesKey.Model]!; let model_input_name = v[LLMModelClassesKey.Input]!; let model_output_name = v[LLMModelClassesKey.Output]!; diff --git a/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift b/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift index 34251fe80a43e9255f3c9dc66d6c8c3ec70bcd89..4c6be656c5bc3893176c81e54a12ee296eea1e88 100644 --- a/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift +++ b/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift @@ -28,28 +28,25 @@ final class BERT_test: XCTestCase { let query = [ "I like to read about new technology and artificial intelligence" ] - // let docs = ["cat dog", "bee fly"] for model in ["gte-small", "all_MiniLM_L6_v2"] { var database_embedding: [[Float]] = [] var query_embedding: [Float] = [] var embedding_dim: Int = 384 - var model = MiniLMEmbeddings(model_type: "gte-small") + var model = LLMEmbeddings(model_type: "gte-small") query_embedding = await model.encode(sentence: query[0])! var i = 1 - //append sentence embedding to database_embedding + for string in docs { if let vector = await model.encode(sentence: string) { database_embedding.append(vector) - //print(i) i += 1 } else { fatalError("Error occurred!") } - } }