From 91a131bbca9ea35ed6d8c6e1bad6e4c27352bdad Mon Sep 17 00:00:00 2001 From: Henry Tian <yuyang.tian@uwaterloo.ca> Date: Fri, 22 Dec 2023 14:32:29 -0500 Subject: [PATCH] Upload New File --- .../2. Embeddings/MiniLMAllEmbeddings.swift | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift diff --git a/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift b/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift new file mode 100644 index 00000000..4a98e47e --- /dev/null +++ b/Sources/SwiftNLP/2. Embeddings/MiniLMAllEmbeddings.swift @@ -0,0 +1,55 @@ +import Foundation +import CoreML + + +@available(macOS 12.0, iOS 15.0, *) +public class MiniLMEmbeddings { + public let model: all_MiniLM_L6_v2 + public let tokenizer: BertTokenizer + public let inputDimention: Int = 512 + public let outputDimention: Int = 384 + + public init() { + let modelConfig = MLModelConfiguration() + modelConfig.computeUnits = .all + + do { + self.model = try all_MiniLM_L6_v2(configuration: modelConfig) + } catch { + fatalError("Failed to load the Core ML model. Error: \(error.localizedDescription)") + } + + self.tokenizer = BertTokenizer() + } + + // MARK: - Dense Embeddings + + public func encode(sentence: String) async -> [Float]? { + // Encode input text as bert tokens + let inputTokens = tokenizer.buildModelTokens(sentence: sentence) + let (inputIds, attentionMask) = tokenizer.buildModelInputs(from: inputTokens) + + // Send tokens through the MLModel + let embeddings = generateEmbeddings(inputIds: inputIds, attentionMask: attentionMask) + + return embeddings + } + +public func generateEmbeddings(inputIds: MLMultiArray, attentionMask: MLMultiArray) -> [Float]? { + let inputFeatures = all_MiniLM_L6_v2Input(input_ids: inputIds, attention_mask: attentionMask) + + let output = try? model.prediction(input: inputFeatures) + guard let embeddings = output?.embeddings else { + return nil + } + + var embeddingsArray = [Float]() + for index in 0..<embeddings.count { + let value = embeddings[index].floatValue + embeddingsArray.append(Float(value)) + } + + return embeddingsArray +} + +} -- GitLab