From 9b9fb1ec9662293ede4a220e3201c304e2f8e9ab Mon Sep 17 00:00:00 2001 From: a252jain <a252jain@uwaterloo.ca> Date: Fri, 29 Mar 2024 20:05:43 -0400 Subject: [PATCH] update macro --- .../SwiftNLP/2. Encoding/CoreMLEncoder.swift | 6 +- Sources/SwiftNLPGenericLLMMacros/Macros.swift | 60 ++++++++++++------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift index ab8c2c1a..18e7033a 100644 --- a/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift +++ b/Sources/SwiftNLP/2. Encoding/CoreMLEncoder.swift @@ -39,7 +39,7 @@ import CoreML @freestanding(expression) -public macro MODEL_MAKE_PREDICTION(_ model_type: Any) = #externalMacro( +public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ output_name: Any) = #externalMacro( module: "SwiftNLPGenericLLMMacros", type: "LLMModelPredictionCases") @@ -97,11 +97,9 @@ public class MiniLMEmbeddings { } public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? { - // let input_class: () = #MODEL_INPUT("input_ids: inputIds, attention_mask: attentionMask") - var output: MLMultiArray? = nil - #MODEL_MAKE_PREDICTION("input_ids: inputIds, attention_mask: attentionMask") + #MODEL_MAKE_PREDICTION("inputIds", "attentionMask", "output") if (output === nil) { return nil; diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift index 189c32d5..14b802ec 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift @@ -5,33 +5,54 @@ import SwiftSyntaxMacros @available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *) public struct LLMModelPredictionCases: ExpressionMacro { - // Example expansion: - // { - // switch self.model { - // case "all_MiniLM_L6_v2": - // let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask); - // output = try! 
LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!; - // case "float32_model": - // let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask); - // output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!; - // default: - // output = nil; - // } - // }(); + /** + Example expansion: + { + switch self.model { + case "all_MiniLM_L6_v2": + let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask); + output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!; + case "float32_model": + let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask); + output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!; + default: + output = nil; + } + }(); + */ + public static func expansion( of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext ) throws -> ExprSyntax { - guard let arg = node.argumentList.first?.expression, - let segments = arg.as(StringLiteralExprSyntax.self)?.segments, + guard let input_arg = node.argumentList.first?.expression, + let segments = input_arg.as(StringLiteralExprSyntax.self)?.segments, segments.count == 1, - case .stringSegment(let literalSegment)? = segments.first + case .stringSegment(let input_literal_segment)? = segments.first else { throw fatalError("Bad argument to macro.") } - let model_key = literalSegment.content.text + guard let attention_arg = node.argumentList.dropFirst().first?.expression, + let attn_segments = attention_arg.as(StringLiteralExprSyntax.self)?.segments, + attn_segments.count == 1, + case .stringSegment(let attn_literal_segment)? 
= attn_segments.first + else { + throw fatalError("Bad argument to macro.") + } + + guard let output_arg = node.argumentList.dropFirst().dropFirst().first?.expression, + let output_segments = output_arg.as(StringLiteralExprSyntax.self)?.segments, + output_segments.count == 1, + case .stringSegment(let output_literal_segment)? = output_segments.first + else { + throw fatalError("Bad argument to macro.") + } + + let model_input = input_literal_segment.content.text + let model_attn = attn_literal_segment.content.text + let model_output = output_literal_segment.content.text var macro = "{ switch self.model { " @@ -41,8 +62,8 @@ public struct LLMModelPredictionCases: ExpressionMacro { macro += """ case \"\(k)\": - let input_class = \(model_class)(\(model_key)); - output = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!; + let input_class = \(model_class)(input_ids: \(model_input), attention_mask: \(model_attn)); + \(model_output) = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!; """ } @@ -75,7 +96,6 @@ public enum LLMPredictionFunctions: DeclarationMacro { let model_input_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Input]!; let model_output_name = LLM_MODEL_CLASSES[model_key]![LLMModelClassesKey.Output]!; - return [ """ public static func prediction(input: \(raw: model_input_name)) throws -> \(raw: model_output_name) { -- GitLab