// Macros.swift
// Authored by Abhinav Jain.
import CoreML
import SwiftSyntax
import SwiftSyntaxMacros
@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public struct LLMModelNameValidation: ExpressionMacro {
    /**
     Expands to an immediately-invoked closure that checks `self.model`
     against every model name (key) in `LLM_MODEL_CLASSES`.

     Example expansion:
     try! { let valid_models = ["all_MiniLM_L6_v2", "gte-small"];
     if !valid_models.contains(self.model) {
     throw fatalError("Model is not valid.");
     } }();

     The order of names in the generated array follows the dictionary's
     iteration order.

     NOTE(review): `fatalError` returns `Never`, so the generated code traps at
     runtime instead of throwing a catchable error (and the enclosing `try!`
     would trap on any thrown error anyway). Confirm this is intended.
     */
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Only the model names are needed; quote each as a Swift string literal.
        let quotedNames = LLM_MODEL_CLASSES.keys
            .map { "\"\($0)\"" }
            .joined(separator: ", ")
        let macro = "try! { let valid_models = [\(quotedNames)];"
        return ExprSyntax(stringLiteral: """
            \(macro)
            if !valid_models.contains(self.model) {
            throw fatalError("Model is not valid.");
            } }();
            """)
    }
}
@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public struct LLMModelPredictionCases: ExpressionMacro {
    /// Error thrown for unusable macro arguments. Throwing (rather than calling
    /// `fatalError`) lets the compiler report a diagnostic at the call site
    /// instead of crashing the macro plugin.
    private struct MacroArgumentError: Error, CustomStringConvertible {
        let description: String
    }

    /**
     Expands to an immediately-invoked closure that switches on `self.model`
     and runs the matching model's prediction for each entry in
     `LLM_MODEL_CLASSES`.

     The macro takes three plain string-literal arguments (no interpolation),
     each the *name* of something visible at the call site:
     1. the value passed as `input_ids:` to the model's input class,
     2. the value passed as `attention_mask:` to the model's input class,
     3. the variable that receives the prediction's `multiArrayValue`; it is
        assigned `nil` when `self.model` matches no known model, so it must be
        an optional.

     Example expansion for `("inputIds", "attentionMask", "output")`:
     {
     switch self.model {
     case "all_MiniLM_L6_v2":
     let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask);
     output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!;
     case "gte-small":
     let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask);
     output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!;
     default:
     output = nil;
     }
     }();

     The generated code uses `try!` and force-unwraps, so it traps if a
     prediction throws or the named feature is missing.

     - Throws: `MacroArgumentError` if any of the three arguments is missing
       or is not a single-segment string literal.
     */
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        let model_input = try stringLiteralArgument(at: 0, of: node)
        let model_attn = try stringLiteralArgument(at: 1, of: node)
        let model_output = try stringLiteralArgument(at: 2, of: node)

        var macro = "{ switch self.model { "
        for (k, v) in LLM_MODEL_CLASSES {
            let model_class = v[LLMModelClassesKey.Input]!
            let model_feature = v[LLMModelClassesKey.FeatureName]!
            macro +=
                """
                case \"\(k)\":
                let input_class = \(model_class)(input_ids: \(model_input), attention_mask: \(model_attn));
                \(model_output) = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!;
                """
        }
        // Reset the caller-named output variable, not a hard-coded `output`.
        macro += "default: \(model_output) = nil; } }();"
        return ExprSyntax(stringLiteral: macro)
    }

    /// Returns the text, as written in source, of the macro argument at
    /// `index`. The argument must be a plain string literal with exactly one
    /// segment (no interpolation).
    ///
    /// - Throws: `MacroArgumentError` if the argument is absent or not such a literal.
    private static func stringLiteralArgument(
        at index: Int,
        of node: some FreestandingMacroExpansionSyntax
    ) throws -> String {
        guard let argument = node.argumentList.dropFirst(index).first?.expression,
              let segments = argument.as(StringLiteralExprSyntax.self)?.segments,
              segments.count == 1,
              case .stringSegment(let literalSegment)? = segments.first
        else {
            throw MacroArgumentError(description: "Bad argument to macro.")
        }
        return literalSegment.content.text
    }
}
@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public enum LLMPredictionFunctions: DeclarationMacro {
    /**
     Expands to static prediction functions for every model in
     `LLM_MODEL_CLASSES`, overloaded on each model's input type.

     For each model, four declarations are produced:
     - `prediction(input:) throws`
     - `prediction(input:options:) throws`
     - `prediction(input:options:) async throws`
       (`@available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)`)
     - `predictions(inputs:options:) throws`

     Each generated function constructs the model from
     `Sources/SwiftNLP/Models/<URL entry>` before predicting.

     NOTE(review): `URL(fileURLWithPath:)` with a relative path resolves against
     the process's current working directory, and the model is reloaded on
     every call. Consider bundling/caching the model; confirm intended.
     */
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> [DeclSyntax] {
        var decls: [DeclSyntax] = []
        // Model names (keys) are not needed; each value maps class roles to type names.
        for model_classes in LLM_MODEL_CLASSES.values {
            let model_type_name = model_classes[LLMModelClassesKey.Model]!
            let model_input_name = model_classes[LLMModelClassesKey.Input]!
            let model_output_name = model_classes[LLMModelClassesKey.Output]!
            let model_url = model_classes[LLMModelClassesKey.URL]!
            // Model-loading statement shared by every generated function.
            let load_model = "let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\"));"

            // One DeclSyntax per function, so each string parses as exactly one declaration.
            decls.append(DeclSyntax(stringLiteral: """
                public static func prediction(input: \(model_input_name)) throws -> \(model_output_name) {
                    \(load_model)
                    return try model.prediction(input: input, options: MLPredictionOptions())
                }
                """))
            decls.append(DeclSyntax(stringLiteral: """
                public static func prediction(input: \(model_input_name), options: MLPredictionOptions) throws -> \(model_output_name) {
                    \(load_model)
                    let outFeatures: MLFeatureProvider = try model.prediction(input: input, options: options)
                    return \(model_output_name)(features: outFeatures)
                }
                """))
            decls.append(DeclSyntax(stringLiteral: """
                @available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)
                public static func prediction(input: \(model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(model_output_name) {
                    \(load_model)
                    let outFeatures: MLFeatureProvider = try await model.prediction(input: input, options: options)
                    return \(model_output_name)(features: outFeatures)
                }
                """))
            decls.append(DeclSyntax(stringLiteral: """
                public static func predictions(inputs: [\(model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(model_output_name)] {
                    \(load_model)
                    let res = try model.predictions(inputs: inputs, options: options);
                    return res;
                }
                """))
        }
        return decls
    }
}