Skip to content
Snippets Groups Projects
Commit 9515e6d1 authored by Abhinav Jain's avatar Abhinav Jain
Browse files

code cleanup

parent e7ce7297
No related branches found
No related tags found
1 merge request: !15 "Add interface for using generic CoreML LLMs"
Pipeline #116066 passed with warnings
......@@ -26,18 +26,6 @@ import Foundation
import CoreML
// To compile the model for this file:
// $ cd Sources/SwiftNLP/Resources
// $ xcrun coremlcompiler generate all-MiniLM-L6-v2.mlpackage/ --language Swift .
// $ cd ../../..
// NEXT:
// get model to work by moving mlmodelc to resources
// finish tests
// make a generic interface based on autogenerated code
// get another coreml available model and test both
@freestanding(expression)
public macro MODEL_MAKE_PREDICTION(_ input_name: Any, _ attention_ids: Any, _ output_name: Any) = #externalMacro(
module: "SwiftNLPGenericLLMMacros",
......@@ -55,13 +43,13 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
var model: String
// Encodes a single token string into its numeric token-ID representation using
// the embedding model's BERT tokenizer, force-cast to the encoder's Scalar type.
// NOTE(review): the two `let tokenization` lines below are the removed/added
// halves of a diff hunk (MiniLMEmbeddings -> LLMEmbeddings rename); as plain
// source this is a duplicate declaration — only the LLMEmbeddings line should remain.
// NOTE(review): `as! [Scalar]` traps at runtime if tokenizeToIds does not return
// an array bridgeable to [Scalar] — TODO confirm tokenizeToIds' return type.
func encodeToken(_ token: String) -> [Scalar] {
let tokenization = MiniLMEmbeddings(model_type: self.model).tokenizer.tokenizeToIds(text: token) as! [Scalar]
let tokenization = LLMEmbeddings(model_type: self.model).tokenizer.tokenizeToIds(text: token) as! [Scalar]
return tokenization
}
// Encodes a full sentence into an embedding vector via the embedding model.
// NOTE(review): the two `await ...encode(sentence:)` lines below are the
// removed/added halves of a diff hunk (MiniLMEmbeddings -> LLMEmbeddings
// rename); only the LLMEmbeddings line should remain in real source.
// NOTE(review): `Task { ... } as! [Scalar]` casts the Task *handle*, not its
// awaited result — a Task value is never an [Scalar], so this force-cast will
// always trap. The synchronous signature cannot `await` the task's value; this
// method needs an async redesign (e.g. `func encodeSentence(_:) async -> [Scalar]`).
func encodeSentence(_ sentence: String) -> [Scalar] {
let encoding = Task {
await MiniLMEmbeddings(model_type: self.model).encode(sentence: sentence)
await LLMEmbeddings(model_type: self.model).encode(sentence: sentence)
} as! [Scalar]
return encoding
}
......@@ -69,11 +57,11 @@ class CoreMLEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
@available(macOS 13.0, *)
public class MiniLMEmbeddings {
public class LLMEmbeddings {
private let model: String
public var tokenizer: BertTokenizer
public var inputDimention: Int = 512
public var inputDimention: Int = 512 // 512 is a dummy value, correct value is set by the macro below
public let outputDimention: Int = 384
public init(model_type: String) {
......@@ -81,15 +69,18 @@ public class MiniLMEmbeddings {
modelConfig.computeUnits = .all
self.model = model_type;
// dummy initialization needed here to avoid compilation error
self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
// validate the model type is valid and set the correct input dimension
#MODEL_VALIDATE_NAME_AND_SET_INPUT_SIZE()
// reinitialize with correct input size
self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
}
// MARK: - Dense Embeddings
public func encode(sentence: String) async -> [Float]? {
self.tokenizer = BertTokenizer(maxLen: self.inputDimention)
// Encode input text as bert tokens
let inputTokens = tokenizer.buildModelTokens(sentence: sentence)
let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens)
......@@ -102,6 +93,7 @@ public class MiniLMEmbeddings {
public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? {
var output: MLMultiArray? = nil
// determine which model to use and generate predictions
#MODEL_MAKE_PREDICTION("inputIds", "attentionMask", "output")
if (output === nil) {
......
......@@ -8,17 +8,13 @@ public struct LLMModelNameValidation: ExpressionMacro {
/**
Example expansion:
try! {
let valid_models = ["gte-small", "all_MiniLM_L6_v2"];
if !valid_models.contains(self.model) {
throw fatalError("Model is not valid.");
}
switch self.model {
case "gte-small":
self.inputDimention = 128;
case "all_MiniLM_L6_v2":
self.inputDimention = 512;
default:
self.inputDimention = 128;
throw fatalError("Model is not valid");
}
}();
*/
......@@ -28,9 +24,11 @@ public struct LLMModelNameValidation: ExpressionMacro {
in context: some MacroExpansionContext
) throws -> ExprSyntax {
// generate code
var macro = "try! { switch self.model { "
for (k, v) in LLM_MODEL_CLASSES {
// extract values
let model_dim = v[LLMModelClassesKey.InputDimension]!
macro +=
"""
......@@ -49,7 +47,7 @@ public struct LLMModelNameValidation: ExpressionMacro {
@available(macOS 12, iOS 15.0, tvOS 17.0, watchOS 10.0, *)
public struct LLMModelPredictionCases: ExpressionMacro {
/**
Example expansion:
Example expansion:
{
switch self.model {
case "all_MiniLM_L6_v2":
......@@ -64,11 +62,12 @@ public struct LLMModelPredictionCases: ExpressionMacro {
}();
*/
public static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
// get first argument from macro invocation
guard let input_arg = node.argumentList.first?.expression,
let segments = input_arg.as(StringLiteralExprSyntax.self)?.segments,
segments.count == 1,
......@@ -77,6 +76,7 @@ public struct LLMModelPredictionCases: ExpressionMacro {
throw fatalError("Bad argument to macro.")
}
// get second argument from macro invocation
guard let attention_arg = node.argumentList.dropFirst().first?.expression,
let segments = attention_arg.as(StringLiteralExprSyntax.self)?.segments,
segments.count == 1,
......@@ -85,6 +85,7 @@ public struct LLMModelPredictionCases: ExpressionMacro {
throw fatalError("Bad argument to macro.")
}
// get third argument from macro invocation
guard let output_arg = node.argumentList.dropFirst().dropFirst().first?.expression,
let segments = output_arg.as(StringLiteralExprSyntax.self)?.segments,
segments.count == 1,
......@@ -93,10 +94,12 @@ public struct LLMModelPredictionCases: ExpressionMacro {
throw fatalError("Bad argument to macro.")
}
// extract parameter values
let model_input = input_literal_segment.content.text
let model_attn = attn_literal_segment.content.text
let model_output = output_literal_segment.content.text
// generate code
var macro = "{ switch self.model { "
for (k, v) in LLM_MODEL_CLASSES {
......@@ -127,7 +130,10 @@ public enum LLMPredictionFunctions: DeclarationMacro {
var ret: String = "";
// generate code
for (k, v) in LLM_MODEL_CLASSES {
// extract values
let model_type_name = v[LLMModelClassesKey.Model]!;
let model_input_name = v[LLMModelClassesKey.Input]!;
let model_output_name = v[LLMModelClassesKey.Output]!;
......
......@@ -28,28 +28,25 @@ final class BERT_test: XCTestCase {
let query = [
"I like to read about new technology and artificial intelligence"
]
// let docs = ["cat dog", "bee fly"]
for model in ["gte-small", "all_MiniLM_L6_v2"] {
var database_embedding: [[Float]] = []
var query_embedding: [Float] = []
var embedding_dim: Int = 384
var model = MiniLMEmbeddings(model_type: "gte-small")
var model = LLMEmbeddings(model_type: "gte-small")
query_embedding = await model.encode(sentence: query[0])!
var i = 1
//append sentence embedding to database_embedding
for string in docs {
if let vector = await model.encode(sentence: string) {
database_embedding.append(vector)
//print(i)
i += 1
} else {
fatalError("Error occurred!")
}
}
}
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment