Skip to content
Snippets Groups Projects
Commit 76dba7d1 authored by Abhinav Jain's avatar Abhinav Jain
Browse files

name changes

parent 6ea88ef5
No related branches found
No related tags found
1 merge request!15Add interface for using generic CoreML LLMs
Pipeline #116025 canceled
......@@ -127,28 +127,30 @@ public enum LLMPredictionFunctions: DeclarationMacro {
let model_type_name = v[LLMModelClassesKey.Model]!;
let model_input_name = v[LLMModelClassesKey.Input]!;
let model_output_name = v[LLMModelClassesKey.Output]!;
let model_url = v[LLMModelClassesKey.URL]!;
ret +=
"""
public static func prediction(input: \(model_input_name)) throws -> \(model_output_name) {
let model = try \(model_type_name)();
let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\");
return try model.prediction(input: input, options: MLPredictionOptions())
}
public static func prediction(input: \(model_input_name), options: MLPredictionOptions) throws -> \(model_output_name) {
let model = try \(model_type_name)();
let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\");
let outFeatures: MLFeatureProvider = try model.prediction(input: input, options:options)
return \(model_output_name)(features: outFeatures)
}
@available(macOS 13.6, iOS 17.0, tvOS 17.0, watchOS 10.0, *)
public static func prediction(input: \(model_input_name), options: MLPredictionOptions = MLPredictionOptions()) async throws -> \(model_output_name) {
let model = try \(model_type_name)();
let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\");
let outFeatures: MLFeatureProvider? = try await model.prediction(input: input, options:options)
return \(model_output_name)(features: outFeatures!)
}
public static func predictions(inputs: [\(model_input_name)], options: MLPredictionOptions = MLPredictionOptions()) throws -> [\(model_output_name)] {
let model = try \(model_type_name)();
let model = try \(model_type_name)(contentsOf: URL(fileURLWithPath: \"Sources/SwiftNLP/Models/\(model_url)\");
let res = try model.predictions(inputs: inputs, options: options);
return res;
}
......
......@@ -3,6 +3,7 @@ enum LLMModelClassesKey {
case Output
case Model
case FeatureName
case URL
}
let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [
......@@ -10,12 +11,14 @@ let LLM_MODEL_CLASSES: [String: [LLMModelClassesKey: Any]] = [
LLMModelClassesKey.Input: all_MiniLM_L6_v2Input.self,
LLMModelClassesKey.Output: all_MiniLM_L6_v2Output.self,
LLMModelClassesKey.Model: all_MiniLM_L6_v2.self,
LLMModelClassesKey.FeatureName: "embeddings"
LLMModelClassesKey.FeatureName: "embeddings",
LLMModelClassesKey.URL: "all-MiniLM-L6-v2.mlmodelc"
],
"gte-small": [
LLMModelClassesKey.Input: float32_modelInput.self,
LLMModelClassesKey.Output: float32_modelOutput.self,
LLMModelClassesKey.Model: float32_model.self,
LLMModelClassesKey.FeatureName: "pooler_output"
LLMModelClassesKey.FeatureName: "pooler_output",
LLMModelClassesKey.URL: "float32_model.mlmodelc"
]
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment