diff --git a/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift b/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift index c6430934c44773aa6d36e1df6e2d0a1d3164ab51..34251fe80a43e9255f3c9dc66d6c8c3ec70bcd89 100644 --- a/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift +++ b/Tests/SwiftNLPTests/AllMiniLM_sampleTest.swift @@ -30,25 +30,27 @@ final class BERT_test: XCTestCase { ] // let docs = ["cat dog", "bee fly"] - var database_embedding: [[Float]] = [] - var query_embedding: [Float] = [] - var embedding_dim: Int = 384 - - var model = MiniLMEmbeddings(model_type: "gte-small") - - query_embedding = await model.encode(sentence: query[0])! - - var i = 1 - //append sentence embedding to database_embedding - for string in docs { - if let vector = await model.encode(sentence: string) { - database_embedding.append(vector) - //print(i) - i += 1 - } else { - fatalError("Error occurred!") + for model in ["gte-small", "all_MiniLM_L6_v2"] { + var database_embedding: [[Float]] = [] + var query_embedding: [Float] = [] + var embedding_dim: Int = 384 + + var model = MiniLMEmbeddings(model_type: "gte-small") + + query_embedding = await model.encode(sentence: query[0])! + + var i = 1 + //append sentence embedding to database_embedding + for string in docs { + if let vector = await model.encode(sentence: string) { + database_embedding.append(vector) + //print(i) + i += 1 + } else { + fatalError("Error occurred!") + } + } - } // // let index = AnnoyIndex<Float>(itemLength: embedding_dim, metric: .euclidean)