diff --git a/Sources/SwiftNLP/2. Embeddings/tokenizers/BertTokenizer.swift b/Sources/SwiftNLP/2. Embeddings/tokenizers/BertTokenizer.swift
new file mode 100644
index 0000000000000000000000000000000000000000..0fe98e43e46ac07043b15b1ceea9d5a1eded7e06
--- /dev/null
+++ b/Sources/SwiftNLP/2. Embeddings/tokenizers/BertTokenizer.swift
@@ -0,0 +1,289 @@
+import Foundation
+import CoreML
+
+public class BertTokenizer {
+    private let basicTokenizer = BasicTokenizer()
+    private let wordpieceTokenizer: WordpieceTokenizer
+    private let maxLen = 512
+
+    private let vocab: [String: Int]
+    private let idsToTokens: [Int: String]
+
+    public init() {
+        let url = Bundle.module.url(forResource: "bert_vocab", withExtension: "txt")!
+        let vocabTxt = try! String(contentsOf: url)
+        let tokens = vocabTxt.split(separator: "\n").map { String($0) }
+        var vocab: [String: Int] = [:]
+        var idsToTokens: [Int: String] = [:]
+        for (i, token) in tokens.enumerated() {
+            vocab[token] = i
+            idsToTokens[i] = token
+        }
+        self.vocab = vocab
+        self.idsToTokens = idsToTokens
+        self.wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab)
+    }
+
+    public func buildModelTokens(sentence: String) -> [Int] {
+        // Tokenize the sentence to token IDs
+        var tokenIds = tokenizeToIds(text: sentence)
+
+        // Look up the special tokens rather than hard-coding their IDs
+        let clsTokenId = tokenToId(token: "[CLS]")
+        let sepTokenId = tokenToId(token: "[SEP]")
+        let padTokenId = tokenToId(token: "[PAD]")
+
+        // Truncate if the input exceeds the max length minus the two special tokens ([CLS] and [SEP])
+        let maxTokenLength = maxLen - 2
+        if tokenIds.count > maxTokenLength {
+            print("Input sentence is too long (\(tokenIds.count + 2) > \(maxLen)), truncating.")
+            tokenIds = Array(tokenIds.prefix(maxTokenLength))
+        }
+
+        // Calculate the required number of padding tokens, +2 for [CLS] and [SEP]
+        let totalTokenCount = tokenIds.count + 2
+        let paddingNeeded = max(0, maxLen - totalTokenCount)
+
+        // Create the final token list: [CLS] <tokens> [SEP] <padding>
+        var finalTokenIds = [Int]()
+        finalTokenIds.append(clsTokenId)
+        finalTokenIds += tokenIds
+        finalTokenIds.append(sepTokenId)
+        finalTokenIds += Array(repeating: padTokenId, count: paddingNeeded)
+
+        return finalTokenIds
+    }
+
+    public func detokenize(alteredTokens: [String]) -> String {
+        return convertWordpieceToBasicTokenList(alteredTokens)
+    }
+
+    public func buildModelInputs(from inputTokens: [Int]) -> (MLMultiArray, MLMultiArray) {
+        let inputIds = MLMultiArray.from(inputTokens, dims: 2)
+
+        // Attend to every non-padding token (assumes [PAD] has ID 0, as in the standard BERT vocab)
+        let attentionMaskValues = inputTokens.map { $0 == 0 ? 0 : 1 }
+        let attentionMask = MLMultiArray.from(attentionMaskValues, dims: 2)
+
+        return (inputIds, attentionMask)
+    }
+
+    public func tokenize(text: String) -> [String] {
+        var tokens: [String] = []
+        for token in basicTokenizer.tokenize(text: text) {
+            for subToken in wordpieceTokenizer.tokenize(word: token) {
+                tokens.append(subToken)
+            }
+        }
+        return tokens
+    }
+
+    public func convertTokensToIds(tokens: [String]) -> [Int] {
+        // Map unknown tokens to [UNK] instead of force-unwrapping, which would
+        // crash on any token outside the vocabulary. The unused `throws` is gone.
+        let unkTokenId = vocab["[UNK]"]!
+        return tokens.map { vocab[$0] ?? unkTokenId }
+    }
+
+    func tokenizeToIds(text: String) -> [Int] {
+        return convertTokensToIds(tokens: tokenize(text: text))
+    }
+
+    func tokenToId(token: String) -> Int {
+        return vocab[token]!
+    }
+
+    func unTokenize(tokens: [Int]) -> [String] {
+        return tokens.map { idsToTokens[$0]! }
+    }
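+
+    /// Rejoins WordPiece sub-tokens into whitespace-separated words by merging
+    /// any token that begins with "##" into the preceding token. Illustrative
+    /// example (hypothetical input, not from the test suite):
+    ///
+    ///     convertWordpieceToBasicTokenList(["un", "##aff", "##able", "fox"])
+    ///     // -> "unaffable fox"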
} + } + + func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String { + var tokenList: [String] = [] + var individualToken: String = "" + + for token in wordpieceTokenList { + if token.starts(with: "##") { + individualToken += String(token.suffix(token.count - 2)) + } else { + if individualToken.count > 0 { + tokenList.append(individualToken) + } + + individualToken = token + } + } + + tokenList.append(individualToken) + + return tokenList.joined(separator: " ") + } + +} + +class BasicTokenizer { + let neverSplit = [ + "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", + ] + + func tokenize(text: String) -> [String] { + let splitTokens = text.folding(options: .diacriticInsensitive, locale: nil) + .components(separatedBy: NSCharacterSet.whitespaces) + let tokens: [String] = splitTokens.flatMap({ (token: String) -> [String] in + if neverSplit.contains(token) { + return [token] + } + var tokFrag: [String] = [] + var currentFrag = "" + for char in token.lowercased() { + if char.isLetter || char.isNumber || char == "°" { + currentFrag += String(char) + } else if currentFrag.count > 0 { + tokFrag.append(currentFrag) + tokFrag.append(String(char)) + currentFrag = "" + } else { + tokFrag.append(String(char)) + } + } + if currentFrag.count > 0 { + tokFrag.append(currentFrag) + } + return tokFrag + }) + return tokens + } +} + +class WordpieceTokenizer { + private let unkToken = "[UNK]" + private let maxInputCharsPerWord = 100 + private let vocab: [String: Int] + + init(vocab: [String: Int]) { + self.vocab = vocab + } + + /// `word`: A single token. + /// Warning: this differs from the `pytorch-transformers` implementation. + /// This should have already been passed through `BasicTokenizer`. + func tokenize(word: String) -> [String] { + if word.count > maxInputCharsPerWord { + return [unkToken] + } + var outputTokens: [String] = [] + var isBad = false + var start = 0 + var subTokens: [String] = [] + while start < word.count { + var end = word.count + var CurrentSubStr: String? = nil + while start < end { + var substr = Utils.substr(word, start..<end)! + if start > 0 { + substr = "##\(substr)" + } + if vocab[substr] != nil { + CurrentSubStr = substr + break + } + end -= 1 + } + if CurrentSubStr == nil { + isBad = true + break + } + subTokens.append(CurrentSubStr!) + start = end + } + if isBad { + outputTokens.append(unkToken) + } else { + outputTokens.append(contentsOf: subTokens) + } + return outputTokens + } +} + + +extension MLMultiArray { + static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { + var shape = Array(repeating: 1, count: dims) + shape[shape.count - 1] = arr.count + let mlArray = try! 
+
+extension MLMultiArray {
+    static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
+        // Build a shape like [1, ..., 1, arr.count] with `dims` dimensions
+        var shape = Array(repeating: 1, count: dims)
+        shape[shape.count - 1] = arr.count
+        let mlArray = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
+        let dataPtr = UnsafeMutablePointer<Int32>(OpaquePointer(mlArray.dataPointer))
+        for (i, item) in arr.enumerated() {
+            dataPtr[i] = Int32(item)
+        }
+        return mlArray
+    }
+
+    /// Read an MLMultiArray's contents as [Double].
+    /// Assumes the array's dataType is `.double` (e.g. model logits); do not use
+    /// it on the `.int32` arrays produced by `from(_:dims:)` above.
+    static func toDoubleArray(_ mlArray: MLMultiArray) -> [Double] {
+        var arr: [Double] = Array(repeating: 0, count: mlArray.count)
+        let dataPtr = UnsafeMutablePointer<Double>(OpaquePointer(mlArray.dataPointer))
+        for i in 0..<mlArray.count {
+            arr[i] = Double(dataPtr[i])
+        }
+        return arr
+    }
+}
+
+struct Utils {
+    /// Time a block in ms
+    static func time<T>(label: String, _ block: () -> T) -> T {
+        let startTime = CFAbsoluteTimeGetCurrent()
+        let result = block()
+        let diff = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000
+        print("[\(label)] \(diff)ms")
+        return result
+    }
+
+    /// Time a block in seconds and return (output, time)
+    static func time<T>(_ block: () -> T) -> (T, Double) {
+        let startTime = CFAbsoluteTimeGetCurrent()
+        let result = block()
+        let diff = CFAbsoluteTimeGetCurrent() - startTime
+        return (result, diff)
+    }
+
+    /// Return unix timestamp in ms
+    static func dateNow() -> Int64 {
+        // Switch to `Int` once 32-bit devices/OSes are no longer supported;
+        // `Int` overflows (crashes) on 32-bit devices such as the iPhone 5c.
+        return Int64(Date().timeIntervalSince1970 * 1000)
+    }
+
+    /// Clamp a value to [vmin, vmax]
+    static func clamp<T: Comparable>(_ val: T, _ vmin: T, _ vmax: T) -> T {
+        return min(max(vmin, val), vmax)
+    }
+
+    /// Identity function that is declared `throws` (useful for exercising error paths)
+    static func fakeThrowable<T>(_ input: T) throws -> T {
+        return input
+    }
+
+    /// Substring of `s` over a range of character offsets; returns nil if out of bounds
+    static func substr(_ s: String, _ r: Range<Int>) -> String? {
+        let stringCount = s.count
+        if stringCount < r.upperBound || stringCount < r.lowerBound {
+            return nil
+        }
+        let startIndex = s.index(s.startIndex, offsetBy: r.lowerBound)
+        let endIndex = s.index(startIndex, offsetBy: r.upperBound - r.lowerBound)
+        return String(s[startIndex..<endIndex])
+    }
+
+    /// Invert a (k, v) dictionary; values must be unique and Hashable
+    /// (the Hashable constraints are required for this to compile)
+    static func invert<K: Hashable, V: Hashable>(_ dict: [K: V]) -> [V: K] {
+        var inverted: [V: K] = [:]
+        for (k, v) in dict {
+            inverted[v] = k
+        }
+        return inverted
+    }
+}
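+
+// Illustrative sanity checks for the helpers above (not part of the test suite):
+//
+//     Utils.substr("wordpiece", 4..<9)   // "piece"
+//     Utils.clamp(7, 0, 5)               // 5
+//     Utils.invert(["[PAD]": 0])[0]      // "[PAD]"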