Skip to content
Snippets Groups Projects
Commit 91a131bb authored by Henry Tian's avatar Henry Tian
Browse files

Upload New File

parent ea7bfe7d
No related branches found
No related tags found
1 merge request!5Allminilm
Pipeline #108448 failed
import Foundation
import CoreML
@available(macOS 12.0, iOS 15.0, *)
public class MiniLMEmbeddings {
public let model: all_MiniLM_L6_v2
public let tokenizer: BertTokenizer
public let inputDimention: Int = 512
public let outputDimention: Int = 384
public init() {
let modelConfig = MLModelConfiguration()
modelConfig.computeUnits = .all
do {
self.model = try all_MiniLM_L6_v2(configuration: modelConfig)
} catch {
fatalError("Failed to load the Core ML model. Error: \(error.localizedDescription)")
}
self.tokenizer = BertTokenizer()
}
// MARK: - Dense Embeddings
public func encode(sentence: String) async -> [Float]? {
// Encode input text as bert tokens
let inputTokens = tokenizer.buildModelTokens(sentence: sentence)
let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens)
// Send tokens through the MLModel
let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask)
return embeddings
}
public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? {
let inputFeatures = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask)
let output = try? model.prediction(input: inputFeatures)
guard let embeddings = output?.embeddings else {
return nil
}
var embeddingsArray = [Float]()
for index in 0..<embeddings.count {
let value = embeddings[index].floatValue
embeddingsArray.append(Float(value))
}
return embeddingsArray
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment