//
//  Macros.swift
//  SwiftNLP
//
import CoreML
import SwiftSyntax
import SwiftSyntaxMacros


@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public struct LLMModelNameValidation: ExpressionMacro {
    /**
     Expands to an immediately-invoked closure that validates `self.model`
     against the names registered in `LLM_MODEL_CLASSES`.

     Example expansion:
     try! { let valid_models = ["all_MiniLM_L6_v2", "gte-small"];
     if !valid_models.contains(self.model) {
         throw fatalError("Model is not valid.");
     } }();

     NOTE(review): `throw fatalError(...)` in the expansion traps before the
     `throw` ever executes (and the closure is invoked with `try!` anyway), so
     an invalid model name crashes rather than throwing — expansion text kept
     as-is to preserve existing behavior; confirm this is intended.
     */
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Build the quoted, comma-separated model list with a single join.
        // The previous implementation bound the unused dictionary value and
        // tracked a manual separator index, both of which invite warnings
        // and off-by-one separator bugs.
        let modelList = LLM_MODEL_CLASSES.keys
            .map { "\"\($0)\"" }
            .joined(separator: ", ")

        return ExprSyntax(stringLiteral:
            """
            try! { let valid_models = [\(modelList)];
            if !valid_models.contains(self.model) {
                throw fatalError("Model is not valid.");
            } }();
            """
        )
    }
}


@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public struct LLMModelPredictionCases: ExpressionMacro {
    /**
     Expands to an immediately-invoked switch over `self.model` that builds the
     model-specific input class and assigns the prediction result to the caller
     supplied output variable.

     Example expansion:
     {
         switch self.model {
         case "all_MiniLM_L6_v2":
             let input_class = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask);
             output = try! LLMModel.prediction(input: input_class).featureValue(for: "embeddings")!.multiArrayValue!;
         case "gte-small":
             let input_class = float32_modelInput(input_ids: inputIds, attention_mask: attentionMask);
             output = try! LLMModel.prediction(input: input_class).featureValue(for: "pooler_output")!.multiArrayValue!;
         default:
             output = nil;
         }
     }();
     */
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Each guard binds a distinctly-named segments variable: guard-let
        // bindings land in the enclosing scope, so the previous code's three
        // `let segments` declarations collided (invalid redeclaration).
        // NOTE(review): `throw fatalError(...)` traps before throwing; kept
        // to match the file's existing error style — confirm intent.
        guard let input_arg = node.argumentList.first?.expression,
            let input_segments = input_arg.as(StringLiteralExprSyntax.self)?.segments,
            input_segments.count == 1,
            case .stringSegment(let input_literal_segment)? = input_segments.first
        else {
            throw fatalError("Bad argument to macro.")
        }

        guard let attention_arg = node.argumentList.dropFirst().first?.expression,
            let attn_segments = attention_arg.as(StringLiteralExprSyntax.self)?.segments,
            attn_segments.count == 1,
            case .stringSegment(let attn_literal_segment)? = attn_segments.first
        else {
            throw fatalError("Bad argument to macro.")
        }

        guard let output_arg = node.argumentList.dropFirst().dropFirst().first?.expression,
            let output_segments = output_arg.as(StringLiteralExprSyntax.self)?.segments,
            output_segments.count == 1,
            case .stringSegment(let output_literal_segment)? = output_segments.first
        else {
            throw fatalError("Bad argument to macro.")
        }

        // The three macro arguments are the *names* of the caller's variables,
        // spliced verbatim into the generated source.
        let model_input = input_literal_segment.content.text
        let model_attn = attn_literal_segment.content.text
        let model_output = output_literal_segment.content.text

        var macro = "{ switch self.model { "

        for (k, v) in LLM_MODEL_CLASSES {
            let model_class = v[LLMModelClassesKey.Input]!
            let model_feature = v[LLMModelClassesKey.FeatureName]!
            // Trailing newline added so consecutive cases do not run together
            // on one line in the generated source.
            macro +=
                """
                case \"\(k)\":
                    let input_class = \(model_class)(input_ids: \(model_input), attention_mask: \(model_attn));
                    \(model_output) = try! LLMModel.prediction(input: input_class).featureValue(for: \"\(model_feature)\")!.multiArrayValue!;

                """
        }

        // Use the caller-supplied output name here too: the previous code
        // hard-coded `output`, which broke expansions whose third argument
        // was anything other than the literal "output".
        macro += "default: \(model_output) = nil; } }();"

        return ExprSyntax(stringLiteral: macro)
    }
}


@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public enum LLMPredictionFunctions: DeclarationMacro {
    /// Generates the family of `prediction` / `predictions` helpers for every
    /// model registered in `LLM_MODEL_CLASSES`, loading each model from its
    /// bundled path under `Sources/SwiftNLP/Models/`.
    ///
    /// NOTE(review): the generated async overload force-unwraps the feature
    /// provider and the sync overloads use `try!`-free `try`; generated text
    /// is kept byte-identical to the original expansion.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> [DeclSyntax] {

        var ret = ""

        // Only the registry values are needed here; the previous
        // `for (k, v)` bound `k` unused, producing a compiler warning.
        for v in LLM_MODEL_CLASSES.values {
            let model_type_name = v[LLMModelClassesKey.Model]!
            let model_input_name = v[LLMModelClassesKey.Input]!
            let model_output_name = v[LLMModelClassesKey.Output]!
            let model_url = v[LLMModelClassesKey.URL]!
            // Hoist the repeated on-disk path so all four overloads stay in sync.
            let model_path = "Sources/SwiftNLP/Models/\(model_url)"

            ret +=
                """
                public static func prediction(input: \(model_input_name)) throws -> \(model_output_name) {
                    let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"\(model_path)\"));
                    return try model.prediction(input: input, options: MLPredictionOptions())
                }

                public static func prediction(input: \(model_input_name), options: MLPredictionOptions) throws -> \(model_output_name) {
                    let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"\(model_path)\"));
                    let outFeatures: MLFeatureProvider = try model.prediction(input: input, options:options)
                    return \(model_output_name)(features: outFeatures)
                }

                @available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)
                public static func prediction(input: \(model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(model_output_name) {
                    let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"\(model_path)\"));
                    let outFeatures: MLFeatureProvider? = try await model.prediction(input: input, options:options)
                    return \(model_output_name)(features: outFeatures!)
                }

                public static func predictions(inputs: [\(model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(model_output_name)] {
                    let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"\(model_path)\"));
                    let res = try model.predictions(inputs: inputs, options: options);
                    return res;
                }

                """
        }

        return [DeclSyntax(stringLiteral: ret)]
    }
}