From bd9296d749ebc0fcb5ed5e28eef4dce990b949e1 Mon Sep 17 00:00:00 2001
From: Jim Wallace <james.wallace@uwaterloo.ca>
Date: Sat, 16 Dec 2023 20:54:08 -0500
Subject: [PATCH] Added subreddit search endpoint

---
 .../1. Data Collection/Reddit/Listing.swift   | 22 ++++++
 .../Reddit/Reddit Content Types.swift         | 17 -----
 .../Reddit/Reddit Enumerations.swift          | 32 ++++++++
 .../RedditClient + Subreddit Search.swift     | 74 +++++++++++++++++++
 .../Reddit/RedditClient.swift                 |  3 +-
 .../Reddit API/RedditClient.swift             | 20 +++++
 6 files changed, 150 insertions(+), 18 deletions(-)
 create mode 100644 Sources/SwiftNLP/1. Data Collection/Reddit/Listing.swift
 delete mode 100644 Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Content Types.swift
 create mode 100644 Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Enumerations.swift
 create mode 100644 Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient + Subreddit Search.swift

diff --git a/Sources/SwiftNLP/1. Data Collection/Reddit/Listing.swift b/Sources/SwiftNLP/1. Data Collection/Reddit/Listing.swift
new file mode 100644
index 00000000..9f58dc4d
--- /dev/null
+++ b/Sources/SwiftNLP/1. Data Collection/Reddit/Listing.swift	
@@ -0,0 +1,22 @@
+/// Top-level listing envelope; `after`, `before` and `children` forward to the nested `data`.
+struct RedditListing<T: RedditDataItem>: Codable {
+    let kind: String
+    let data: RedditListingData<T>
+    var after: String? { return data.after }
+    var before: String? { return data.before }
+    var children: [RedditListingDataItem<T>] { return data.children }
+}
+
+struct RedditListingData<T: RedditDataItem>: Codable { // Decoded `data` payload of a `RedditListing`.
+    let modhash: String?
+    let dist: Int // NOTE(review): non-optional, so decoding throws if `dist` is missing or null — confirm the API always sends it.
+    let before: String? // Surfaced as `RedditListing.before`.
+    let after: String? // Surfaced as `RedditListing.after`.
+    let geo_filter: String? // Property name doubles as the coding key (no CodingKeys declared in this struct).
+    let children: [RedditListingDataItem<T>] // Surfaced as `RedditListing.children`.
+}
+
+struct RedditListingDataItem<T: RedditDataItem>: Codable { // One element of `RedditListingData.children`.
+    let kind: String // NOTE(review): presumably a type prefix such as "t3" (see RedditContentType) — confirm.
+    let data: T // The item payload, decoded as `T`.
+}
diff --git a/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Content Types.swift b/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Content Types.swift
deleted file mode 100644
index 27361108..00000000
--- a/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Content Types.swift	
+++ /dev/null
@@ -1,17 +0,0 @@
-//t1_    Comment
-//t2_    Account
-//t3_    Link
-//t4_    Message
-//t5_    Subreddit
-//t6_    Award
-
-enum RedditContentType: String, CustomStringConvertible {
-    case comment    = "t1_"
-    case account    = "t2_"
-    case link       = "t3_"
-    case message    = "t4_"
-    case subreddit  = "t5_"
-    case award      = "t6_"
-        
-    var description: String { rawValue }
-}
diff --git a/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Enumerations.swift b/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Enumerations.swift
new file mode 100644
index 00000000..0577a80c
--- /dev/null
+++ b/Sources/SwiftNLP/1. Data Collection/Reddit/Reddit Enumerations.swift	
@@ -0,0 +1,32 @@
+/// Reddit "kind" prefixes: t1 comment, t2 account, t3 link, t4 message, t5 subreddit, t6 award.
+enum RedditContentType: String, CustomStringConvertible {
+    case comment = "t1"
+    case account = "t2"
+    case link = "t3"
+    case message = "t4"
+    case subreddit = "t5"
+    case award = "t6"
+    var description: String { return rawValue }
+}
+
+/// Sort orders sent as the `sort` request parameter; raw values are inferred from the case names.
+enum ListingSortOrder: String, CustomStringConvertible {
+    case relevance
+    case hot
+    case top
+    case new
+    case comments
+
+    var description: String { return rawValue }
+}
+
+/// Time window sent as the `t` request parameter.
+///
+/// Raw values are inferred from the case names ("hour", "day", ...).
+enum ListingTime: String, CustomStringConvertible {
+    case hour, day, week
+    case month, year, all
+
+    /// The raw string value.
+    var description: String { return rawValue }
+}
diff --git a/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient + Subreddit Search.swift b/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient + Subreddit Search.swift
new file mode 100644
index 00000000..99926b01
--- /dev/null
+++ b/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient + Subreddit Search.swift	
@@ -0,0 +1,74 @@
+import Foundation
+
+extension RedditClient {
+    // TODO: This is a stop-gap solution, think through how to make generic over RedditContentTypes
+    // the Reddit API returns listings that can contain mixed results, but for now at least
+    // I don't think we ever need that ... so we just tell the method what types we want, and we get those back
+    /// Searches `r/<subreddit>/search` and decodes the response as a `RedditListing<T>`.
+    ///
+    /// - Parameters:
+    ///   - subreddit: Subreddit name, interpolated into the endpoint path.
+    ///   - q: Search phrase, sent wrapped in double quotes; must be shorter than 512 characters.
+    ///   - after: Sent as `after` when non-nil.
+    ///   - before: Sent as `before` when non-nil.
+    ///   - count: Sent as `count` when non-nil.
+    ///   - limit: Sent as `limit` when non-nil.
+    ///   - searchQuery: Not sent yet. NOTE(review): confirm the query key this endpoint expects before wiring it up.
+    ///   - show: Sent as `show` when non-nil.
+    ///   - sort: Sent as `sort`.
+    ///   - expandSubreddits: Sent as `sr_detail` when non-nil.
+    ///   - time: Sent as `t`.
+    ///   - restrictSubreddit: Sent as `restrict_sr`.
+    /// - Returns: The decoded listing.
+    /// - Throws: `RedditClientError` if `q` is too long or decoding fails; rethrows errors from `_GET`.
+    func searchSubreddit<T: RedditDataItem>(
+        subreddit: String,
+        q: String,
+        after: String? = nil,
+        before: String? = nil,
+        count: UInt? = nil,
+        limit: UInt? = nil,
+        searchQuery: UUID? = nil,
+        show: String? = "all",
+        sort: ListingSortOrder = .comments,
+        expandSubreddits: Bool? = nil,
+        time: ListingTime = .all,
+        restrictSubreddit: Bool = true
+        //type: String = "link" // TODO: comma-delimited list of result types (sr, link, user)
+    ) async throws -> RedditListing<T> {
+        guard q.count < 512 else {
+            throw RedditClientError(message: "Query length must be less than 512 characters.")
+        }
+
+        var parameters: [String: String] = [
+            "q": "\"\(q)\"",
+            "sort": sort.rawValue,
+            "t": time.rawValue,
+            "restrict_sr": String(restrictSubreddit),
+        ]
+
+        // TODO: We can expand this to include user, subreddit types... is that useful?
+        if T.self == RedditComment.self {
+            parameters["type"] = "comment"
+        }
+        if T.self == RedditSubmission.self {
+            parameters["type"] = "link"
+        }
+
+        // Assigning nil through the subscript leaves the key absent, so unset options are not sent.
+        parameters["after"] = after
+        parameters["before"] = before
+        parameters["show"] = show
+        parameters["limit"] = limit.map { String($0) }
+        parameters["count"] = count.map { String($0) }
+        parameters["sr_detail"] = expandSubreddits.map { String($0) }
+
+        let (data, _) = try await _GET(endpoint: "r/\(subreddit)/search", parameters: parameters)
+
+        do {
+            return try JSONDecoder().decode(RedditListing<T>.self, from: data)
+        } catch {
+            throw RedditClientError(message: "Unable to decode server response: \(error)")
+        }
+    }
+}
diff --git a/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient.swift b/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient.swift
index efd81c93..13f9cdd6 100644
--- a/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/Reddit/RedditClient.swift	
@@ -116,6 +116,7 @@ extension RedditClient {
     //  UTILITY Method
     //  Perform a basic GET given an endpoint and parameters
     //
+    @inlinable
     internal func _GET(endpoint: String, parameters: [String : String]) async throws -> (Data, HTTPURLResponse) {
         guard isAuthenticated else {
             throw RedditClientError(message: "Client not authenticated.")
@@ -144,7 +145,7 @@ extension RedditClient {
         let (data, response) = try await session.data(for: request)
         
         guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else {
-                throw RedditClientError(message: "Bad server response")
+            throw RedditClientError(message: "Bad server response" + response.description)
             }
         
         // Monitor rate limits
diff --git a/Tests/SwiftNLPTests/Reddit API/RedditClient.swift b/Tests/SwiftNLPTests/Reddit API/RedditClient.swift
index 5d0b64f9..65f03c10 100644
--- a/Tests/SwiftNLPTests/Reddit API/RedditClient.swift	
+++ b/Tests/SwiftNLPTests/Reddit API/RedditClient.swift	
@@ -71,4 +71,24 @@ final class RedditClientTest: XCTestCase {
         XCTAssertEqual(response.statusCode, 200)
         //XCTAssertNotNil(client.authResponse)
     }
+    
+    
+    /// Live integration test: authenticates, then searches r/uwaterloo for "goose".
+    /// Skipped (rather than crashing the test process) when credentials are not in the environment.
+    func testSubredditSearch() async throws {
+        let environment = ProcessInfo.processInfo.environment
+        guard let id = environment["REDDIT_CLIENT_ID"],
+              let secret = environment["REDDIT_CLIENT_SECRET"] else {
+            throw XCTSkip("Unable to fetch REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET from ProcessInfo.")
+        }
+        
+        let client = RedditClient(id: id, secret: secret)
+        guard let _ = try? await client.authenticate() else {
+            throw RedditClientError(message: "Error authenticating client.")
+        }
+        
+        let result: RedditListing<RedditSubmission> = try await client.searchSubreddit(subreddit: "uwaterloo", q: "goose", limit: 10)
+        
+        XCTAssertFalse(result.children.isEmpty, "Expected at least one search result.")
+    }
 }
-- 
GitLab