diff --git a/Package.swift b/Package.swift
index 4d5544271158be117ae139df1eb23f2602f58521..22c20410d598d896db7a5771d83e22564be7dd69 100644
--- a/Package.swift
+++ b/Package.swift
@@ -46,7 +46,7 @@ let package = Package(
         ),
         .testTarget(
             name: "SwiftNLPTests",
-            dependencies: ["SwiftNLP", "SwiftNLPGenericLLMMacros"],
+            dependencies: ["SwiftNLPGenericLLMMacros", "SwiftNLP"],
             resources: [
                 .process("Resources"),
             ]),
diff --git a/Sources/SwiftNLP/2. Encoding/GenericModel.swift b/Sources/SwiftNLP/2. Encoding/GenericModel.swift
index 601ded3519ac48c80a3de13530bae5698fe3c80a..65ded46bfabfa29aabbe6d6f9b783f0fbd1c3d37 100644
--- a/Sources/SwiftNLP/2. Encoding/GenericModel.swift
+++ b/Sources/SwiftNLP/2. Encoding/GenericModel.swift
@@ -2,12 +2,11 @@ import CoreML
 
 @freestanding(declaration, names: arbitrary)
-public macro MODEL_PREDICTION_FUNCTIONS(_ model_type: Any) = #externalMacro(
+public macro MODEL_PREDICTION_FUNCTIONS() = #externalMacro(
     module: "SwiftNLPGenericLLMMacros",
     type: "LLMPredictionFunctions")
 
 struct LLMModel {
-    #MODEL_PREDICTION_FUNCTIONS("all_MiniLM_L6_v2")
-    #MODEL_PREDICTION_FUNCTIONS("float32_model")
+    #MODEL_PREDICTION_FUNCTIONS()
 }
diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
index 14b802ecd179f98b66a527f48bcf5a8d5c87e1f2..2d0f6988d617ed30788e1100ab85b57dc0b38498 100644
--- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift
@@ -12,7 +12,7 @@ public struct LLMModelPredictionCases: ExpressionMacro {
         case "all_MiniLM_L6_v2":
             let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask);
             output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!;
-        case "float32_model":
+        case "gte-small":
             let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask);
             output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!;
         default:
@@ -81,47 +81,42 @@ public enum LLMPredictionFunctions: DeclarationMacro {
         of node: some FreestandingMacroExpansionSyntax,
         in context: some MacroExpansionContext
     ) throws -> [DeclSyntax] {
-
-        guard let arg = node.argumentList.first?.expression,
-            let segments = arg.as(StringLiteralExprSyntax.self)?.segments,
-            segments.count == 1,
-            case .stringSegment(let literalSegment)? = segments.first
-        else {
-            throw fatalError("Bad argument to macro.")
-        }
-        let model_key = literalSegment.content.text
-
-        let model_type_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Model]!;
-        let model_input_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Input]!;
-        let model_output_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Output]!;
+        var ret: String = "";
+
+        for (k, v) in LLM_MODEL_CLASSES {
+            let model_type_name = v[LLMModelClassesKey.Model]!;
+            let model_input_name = v[LLMModelClassesKey.Input]!;
+            let model_output_name = v[LLMModelClassesKey.Output]!;
+            ret +=
+            """
+            public static func prediction(input: \(model_input_name)) throws -> \(model_output_name) {
+                let model = try \(model_type_name)();
+                return try model.prediction(input: input, options: MLPredictionOptions())
+            }
 
-        return [
-            """
-            public static func prediction(input: \(raw: model_input_name)) throws -> \(raw: model_output_name) {
-                let model = try \(raw: model_type_name)();
-                return try model.prediction(input: input, options: MLPredictionOptions())
-            }
+            public static func prediction(input: \(model_input_name), options: MLPredictionOptions) throws -> \(model_output_name) {
+                let model = try \(model_type_name)();
+                let outFeatures: MLFeatureProvider = try model.prediction(input: input, options:options)
+                return \(model_output_name)(features: outFeatures)
+            }
 
-            public static func prediction(input: \(raw: model_input_name), options: MLPredictionOptions) throws -> \(raw: model_output_name) {
-                let model = try \(raw: model_type_name)();
-                let outFeatures: MLFeatureProvider = try model.prediction(input: input, options:options)
-                return \(raw: model_output_name)(features: outFeatures)
-            }
+            @available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)
+            public static func prediction(input: \(model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(model_output_name) {
+                let model = try \(model_type_name)();
+                let outFeatures: MLFeatureProvider? = try await model.prediction(input: input, options:options)
+                return \(model_output_name)(features: outFeatures!)
+            }
 
-            @available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)
-            public static func prediction(input: \(raw: model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(raw: model_output_name) {
-                let model = try \(raw: model_type_name)();
-                let outFeatures: MLFeatureProvider? = try await model.prediction(input: input, options:options)
-                return \(raw: model_output_name)(features: outFeatures!)
-            }
+            public static func predictions(inputs: [\(model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(model_output_name)] {
+                let model = try \(model_type_name)();
+                let res = try model.predictions(inputs: inputs, options: options);
+                return res;
+            }
+
+            """;
+        }
 
-            public static func predictions(inputs: [\(raw: model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(raw: model_output_name)] {
-                let model = try \(raw: model_type_name)();
-                let res = try model.predictions(inputs: inputs, options: options);
-                return res;
-            }
-            """
-        ]
+        return [DeclSyntax(stringLiteral: ret)];
     }
 }
diff --git a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
index 94839a63b7f2ec171865b2c3593915f697fe209b..d165b2c3a82b5bc61c7626513b61c0cf08a6c084 100644
--- a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
+++ b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift
@@ -12,7 +12,7 @@ let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [
         LLMModelClassesKey.Model: all_MiniLM_L6_v2.self,
         LLMModelClassesKey.FeatureName: "embeddings"
     ],
-    "float32_model": [
+    "gte-small": [
         LLMModelClassesKey.Input: float32_modelInput.self,
         LLMModelClassesKey.Output: float32_modelOutput.self,
         LLMModelClassesKey.Model: float32_model.self,
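Note (not part of the patch): a rough sketch of the declarations #MODEL_PREDICTION_FUNCTIONS() is now expected to splice into LLMModel, assuming the two entries registered in LLM_MODEL_CLASSES ("all_MiniLM_L6_v2" and "gte-small", the latter backed by the float32_model Core ML classes). The all_MiniLM_L6_v2Output type name is assumed by analogy with float32_modelOutput; the real source is produced at compile time by LLMPredictionFunctions from the template above.

// Illustrative expansion only; the actual code is generated by the macro at build time.
import CoreML

struct LLMModel {
    // Generated for the "all_MiniLM_L6_v2" entry (Output type name assumed).
    public static func prediction(input: all_MiniLM_L6_v2Input) throws -> all_MiniLM_L6_v2Output {
        let model = try all_MiniLM_L6_v2();
        return try model.prediction(input: input, options: MLPredictionOptions())
    }

    // Generated for the "gte-small" entry, which maps to the float32_model classes.
    public static func prediction(input: float32_modelInput) throws -> float32_modelOutput {
        let model = try float32_model();
        return try model.prediction(input: input, options: MLPredictionOptions())
    }

    // ...plus the options:, async, and predictions(inputs:) overloads for each model,
    // following the same template.
}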