From 76dba7d13d91c98c59342589ada653f0b55e95a5 Mon Sep 17 00:00:00 2001 From: a252jain <a252jain@uwaterloo.ca> Date: Fri, 5 Apr 2024 10:24:22 -0400 Subject: [PATCH] name changes --- Sources/SwiftNLPGenericLLMMacros/Macros.swift | 10 ++++++---- Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift | 7 +++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Sources/SwiftNLPGenericLLMMacros/Macros.swift b/Sources/SwiftNLPGenericLLMMacros/Macros.swift index 3aa78cab..0d72c05a 100644 --- a/Sources/SwiftNLPGenericLLMMacros/Macros.swift +++ b/Sources/SwiftNLPGenericLLMMacros/Macros.swift @@ -127,28 +127,30 @@ public enum LLMPredictionFunctions: DeclarationMacro { let model_type_name = v[LLMModelClassesKey.Model]!; let model_input_name = v[LLMModelClassesKey.Input]!; let model_output_name = v[LLMModelClassesKey.Output]!; + let model_url = v[LLMModelClassesKey.URL]!; + ret += """ public static func prediction(input: \(model_input_name)) throws -> \(model_output_name) { - let model = try \(model_type_name)(); + let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\"); return try model.prediction(input: input, options: MLPredictionOptions()) } public static func prediction(input: \(model_input_name), options: MLPredictionOptions) throws -> \(model_output_name) { - let model = try \(model_type_name)(); + let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\"); let outFeatures: MLFeatureProvider = try model.prediction(input: input, options:options) return \(model_output_name)(features: outFeatures) } @available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *) public static func prediction(input: \(model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(model_output_name) { - let model = try \(model_type_name)(); + let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\"); let outFeatures: MLFeatureProvider? = try await model.prediction(input: input, options:options) return \(model_output_name)(features: outFeatures!) } public static func predictions(inputs: [\(model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(model_output_name)] { - let model = try \(model_type_name)(); + let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\"); let res = try model.predictions(inputs: inputs, options: options); return res; } diff --git a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift index d165b2c3..4fc208f5 100644 --- a/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift +++ b/Sources/SwiftNLPGenericLLMMacros/ModelClasses.swift @@ -3,6 +3,7 @@ enum LLMModelClassesKey { case Output case Model case FeatureName + case URL } let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [ @@ -10,12 +11,14 @@ let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [ LLMModelClassesKey.Input: all_MiniLM_L6_v2Input.self, LLMModelClassesKey.Output: all_MiniLM_L6_v2Output.self, LLMModelClassesKey.Model: all_MiniLM_L6_v2.self, - LLMModelClassesKey.FeatureName: "embeddings" + LLMModelClassesKey.FeatureName: "embeddings", + LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc" ], "gte-small": [ LLMModelClassesKey.Input: float32_modelInput.self, LLMModelClassesKey.Output: float32_modelOutput.self, LLMModelClassesKey.Model: float32_model.self, - LLMModelClassesKey.FeatureName: "pooler_output" + LLMModelClassesKey.FeatureName: "pooler_output", + LLMModelClassesKey.URL: "float32_model.mlmodelc" ] ] -- GitLab