Skip to main content

Scrape Reddit With Selenium

How to Scrape Reddit With Selenium

Reddit is a platform where people can share news, opinions, ideas and all sorts of other things. In fact, Reddit has a reputation as the "Front Page of the Internet".

Today, we're going to learn how to scrape data from Reddit:


TLDR: How to Scrape Reddit

  • We can actually fetch a batch of posts from Reddit by adding .json to the end of our url.
  • In the example below, we've got a production ready Reddit scraper:
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchData:
    """One Reddit post as captured by the crawler.

    String fields are whitespace-stripped on construction; any string field
    that ends up empty is replaced with the placeholder ``"No <field name>"``.
    """

    name: str = ""
    author: str = ""
    permalink: str = ""
    upvote_ratio: float = 0.0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Normalise every ``str`` attribute in place."""
        # The loop variable is deliberately not called ``field`` so it does
        # not shadow the ``dataclasses.field`` helper imported by the module.
        for spec in fields(self):
            value = getattr(self, spec.name)
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also gets the placeholder
            # instead of being stored as an empty string.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

@dataclass
class CommentData:
    """One Reddit comment: author name, body text and upvote count.

    String fields are whitespace-stripped on construction; any string field
    that ends up empty is replaced with the placeholder ``"No <field name>"``.
    """

    name: str = ""
    body: str = ""
    upvotes: int = 0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Normalise every ``str`` attribute in place."""
        # The loop variable is deliberately not called ``field`` so it does
        # not shadow the ``dataclasses.field`` helper imported by the module.
        for spec in fields(self):
            value = getattr(self, spec.name)
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also gets the placeholder
            # instead of being stored as an empty string.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

class DataPipeline:
    """Buffer scraped dataclass items, drop duplicates and append them to a CSV.

    Items are de-duplicated on their ``name`` attribute and flushed to
    ``csv_filename`` whenever the buffer reaches ``storage_queue_limit``.
    Call ``close_pipeline()`` when finished to flush whatever is left.
    """

    def __init__(self, csv_filename='', storage_queue_limit=50):
        self.names_seen = []
        self.storage_queue = []
        self.storage_queue_limit = storage_queue_limit
        self.csv_filename = csv_filename
        self.csv_file_open = False

    def save_to_csv(self):
        """Append every queued item to the CSV file and empty the queue.

        The header row is written only when the file is missing or empty.
        """
        data_to_save = list(self.storage_queue)
        self.storage_queue.clear()
        if not data_to_save:
            # Nothing to write: return before touching csv_file_open so the
            # flag cannot get stuck on True.
            return

        self.csv_file_open = True
        try:
            keys = [spec.name for spec in fields(data_to_save[0])]
            file_exists = (
                os.path.isfile(self.csv_filename)
                and os.path.getsize(self.csv_filename) > 0
            )
            with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
                writer = csv.DictWriter(output_file, fieldnames=keys)
                if not file_exists:
                    writer.writeheader()
                writer.writerows(asdict(item) for item in data_to_save)
        finally:
            # Reset the flag even when the write fails; otherwise add_data()
            # would never trigger another flush.
            self.csv_file_open = False

    def is_duplicate(self, input_data):
        """Return True if ``input_data.name`` was seen before, else record it."""
        if input_data.name in self.names_seen:
            logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
            return True
        self.names_seen.append(input_data.name)
        return False

    def add_data(self, scraped_data):
        """Queue a non-duplicate item and flush once the queue limit is reached."""
        if self.is_duplicate(scraped_data):
            return
        self.storage_queue.append(scraped_data)
        if len(self.storage_queue) >= self.storage_queue_limit and not self.csv_file_open:
            self.save_to_csv()

    def close_pipeline(self):
        """Flush any remaining queued items to the CSV file."""
        if self.csv_file_open:
            # A write is flagged as in progress; wait before the final flush.
            time.sleep(3)
        if self.storage_queue:
            self.save_to_csv()

def get_scrapeops_url(url, location="us"):
    """Wrap ``url`` in a ScrapeOps proxy request URL.

    The module-level ``API_KEY``, the target ``url`` and ``location`` (sent
    as ``country``) are URL-encoded into the query string.
    """
    query = urlencode({"api_key": API_KEY, "url": url, "country": location})
    return "https://proxy.scrapeops.io/v1/?" + query

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
    """Fetch up to ``limit`` posts from r/``feed`` and queue them on ``data_pipeline``.

    The subreddit's ``.json`` listing is requested through the ScrapeOps
    proxy in a fresh headless Chrome session. Each entry of
    ``resp["data"]["children"]`` becomes a ``SearchData`` item. One initial
    attempt plus up to ``retries`` retries are made.

    Raises:
        ValueError: if no ``data_pipeline`` was supplied.
        Exception: if no attempt succeeded.
    """
    if data_pipeline is None:
        # Fail before launching a browser: without a pipeline every attempt
        # would die on ``None.add_data``.
        raise ValueError("get_posts() requires a data_pipeline to store results")

    url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
    tries = 0
    success = False

    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(get_scrapeops_url(url, location=location))
            # The JSON body is read from the page's <pre> element.
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            success = True
            for child in resp["data"]["children"]:
                data = child["data"]
                data_pipeline.add_data(
                    SearchData(
                        name=data["title"],
                        author=data["author"],
                        permalink=data["permalink"],
                        upvote_ratio=data["upvote_ratio"],
                    )
                )
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        # Surface the failure instead of returning silently, matching
        # process_post().
        raise Exception(f"Max retries exceeded {retries}")

def process_post(post_object, location="us", retries=3):
    """Scrape the comments of one post into its own CSV file.

    ``post_object`` is a mapping with a ``"permalink"`` key (a row of the
    crawler CSV). The post's ``.json`` page is fetched through the ScrapeOps
    proxy, and every entry of ``comments[1]["data"]["children"]`` whose kind
    is not ``"more"`` is stored as a ``CommentData`` row.

    NOTE(review): DataPipeline de-duplicates on ``name``, which here is the
    comment author, so only the first comment per author is kept.

    Raises:
        Exception: if every attempt failed.
    """
    permalink = post_object["permalink"]
    r_url = f"https://www.reddit.com{permalink}.json"

    # [-2] assumes the permalink ends with "/" so the slug is the
    # second-to-last segment.
    filename = permalink.split("/")[-2].replace(" ", "-")

    comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")

    tries = 0
    success = False
    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(get_scrapeops_url(r_url, location=location))
            raw_json = driver.find_element(By.TAG_NAME, "pre").text
            if not raw_json:
                # raw_json is a plain string with no status_code attribute,
                # so report the empty body directly.
                raise Exception("Failed response: empty response body")
            comments = json.loads(raw_json)

            comments_list = comments[1]["data"]["children"]

            for comment in comments_list:
                if comment["kind"] == "more":
                    continue
                data = comment["data"]
                comment_pipeline.add_data(
                    CommentData(
                        name=data["author"],
                        body=data["body"],
                        upvotes=data["ups"],
                    )
                )
            comment_pipeline.close_pipeline()
            success = True
        except Exception as e:
            logger.warning(f"Failed to retrieve comment:\n{e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        raise Exception(f"Max retries exceeded {retries}")


#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
    """Run ``process_post`` concurrently for every row of ``csv_file``.

    A post that still fails after its retries is logged and skipped, so one
    bad post does not stop the rest of the batch.
    """
    # Read with the same encoding DataPipeline uses when writing the file.
    with open(csv_file, newline="", encoding="utf-8") as csvfile:
        rows = list(csv.DictReader(csvfile))

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        jobs = [
            (row, executor.submit(process_post, row, location, retries))
            for row in rows
        ]
        # Collect every result explicitly. With executor.map() and no
        # iteration over the results, worker exceptions are never raised or
        # reported.
        for row, job in jobs:
            try:
                job.result()
            except Exception as e:
                logger.warning(f"Failed to process post {row.get('permalink')}: {e}")


########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 100
    MAX_THREADS = 11

    AGGREGATED_FEEDS = []

    # Stage 1: crawl each subreddit into its own CSV file.
    for feed in FEEDS:
        feed_filename = feed.replace(" ", "-")
        csv_name = f"{feed_filename}.csv"
        feed_pipeline = DataPipeline(csv_filename=csv_name)
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        # Flush whatever is still buffered below the queue limit.
        feed_pipeline.close_pipeline()
        AGGREGATED_FEEDS.append(csv_name)

    # Stage 2: scrape the comments of every post listed in each crawl file.
    for individual_file in AGGREGATED_FEEDS:
        process_posts(individual_file, max_workers=MAX_THREADS)

This scraper is production ready.

  • To customize your feeds, simply add the subreddits you want into the FEEDS array.
  • To add or remove threads, change the MAX_THREADS constant (it is set to 11 by default). If your CPU takes advantage of hyperthreading, you can actually use more threads than you have cores.
  • For instance, the code for this article was written using a 6-core i3 CPU with hyperthreading. With hyperthreading, the CPU is capable of running 12 threads.
  • We chose to use 11, so the rest of the PC can run smoothly on one thread while the scraper operates. BATCH_SIZE is set to 100, so the scraper fetches the top 100 posts; if you only want the top 10, change BATCH_SIZE to 10.

How To Architect Our Reddit Scraper

Our Reddit scraper needs to be able to do the following things:

  1. Fetch a batch of Reddit posts
  2. Retrieve data about each individual post
  3. Clean Our data
  4. Save our data to a CSV file

We're going to go through the process of building a scraper that does all these things from start to finish using Selenium. While Selenium isn't exactly lightweight, it is highly optimized. Selenium is equipped with more than enough to handle the simple JSON parsing with power and speed.

We can scrape Reddit with Selenium and gather a ton of data...FAST.


Understanding How To Scrape Reddit

Step 1: How To Request Reddit Pages

When you look up a feed on Reddit, you're given content in small batches. Whenever you scroll down, the site automatically fetches more content for you to read.

Infinite scrollers like this often present challenges when trying to scrape. While Selenium gives us the ability to scroll the page, there's actually a faster way. We'll dig into that soon enough. Take a look at the image below.

Reddit r/news Subreddit

As you can see, the url looks like this:

https://www.reddit.com/r/news/?rdt=51809

By tweaking just a couple of things, we can completely change the result.

Let's change the url to:

https://www.reddit.com/r/news.json

Take a look at the result below.

Reddit feed

By simply changing /?rdt=51809 to .json, we've turned Reddit into a full-blown feed!


Step 2: How To Extract Data From Reddit Feeds

Since our data is stored in the form of JSON, to retrieve it, we just need to know how to handle simple dictionaries in Python. Python's dictionaries are key-value pairs just like JSON objects, and the syntax is nearly identical. One thing to keep in mind: JSON keys must always be double-quoted strings, whereas the JavaScript object literals that JSON is derived from allow unquoted keys.

Take the following example below:

Dictionary (also valid JSON)

{"name": "John Doe", "age": 30}

The same object written as a JavaScript object literal would be:

JavaScript

{name: "John Doe", age: 30}

In order to index a dict in Python we simply use its keys. Our entire list of content comes in our resp or response object. With a standard HTTP client like Python Requests, we would simply use resp.json().

With Selenium, we get the page and then we use json.loads() to convert the text string into a dict. To access each article in the list, all we have to do is change our index:

  • First Article: resp["data"]["children"][0]
  • Second Article: resp["data"]["children"][1]
  • Third Article: resp["data"]["children"][2]

We can follow this method all the way up to the last child and we'll be finished collecting articles.


Step 3: How To Control Pagination

Think back to the link we looked at previously,

https://www.reddit.com/r/news.json

We can add a limit parameter to this url for finer control of our results.

If we want 100 news results, our url would look like this:

https://www.reddit.com/r/news.json?limit=100

This doesn't give us actual pages to sort through, but we do get long lists of results that we can control. All we have to do is pass the text into json.loads().


Setting Up Our Reddit Scraper Project

Let's get started with our project. First, we'll make a new project folder. I'll call mine reddit-scraper.

mkdir reddit-scraper

Next, we'll make a new virtual environment.

python -m venv venv

Activate the new virtual environment.

source venv/bin/activate

We only have two dependencies to install. Make sure you have the latest version of Chromedriver installed. Then you can install Selenium using pip.

pip install selenium

If you have any issues with installing Chromedriver and have an executable path error, take a look at our article about how to tackle this problem.


Build A Reddit Crawler

When scraping Reddit, we actually need to build two scrapers.

  1. We need a crawler to identify all of the different posts. This crawler will fetch lists of posts, extract data from them, and save that data to a CSV file.

  2. After our data is saved into the CSV file, it will be read by an individual post scraper. The post scraper will go and fetch individual posts along with comments and some other metadata.


Step 1: Create A Reddit Data Parser

First, we need a simple parser for our Reddit data.

  • In the script below, we have one function, get_posts(). This function takes one argument, feed and a kwarg, retries which is set to 3 by default.
  • While we still have retries left, we try to fetch the json feed from Reddit. If we run out of retries, we raise an Exception and allow the crawler to crash.
  • Our json data comes nested inside of a pre tag, so we simply use Selenium to find the pre tag before loading our text.
  • In our json response, we have an array of json objects called children. Each item in this array represents an individual Reddit post.
  • Each post contains a data field. From that data field, we pull the title, author_fullname, permalink, and upvote_ratio
  • Later on, these items will make up the data we wish to save from our search, but we'll just print them for now.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

# Headless Chrome configured with a desktop browser user agent.
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#get posts from a subreddit
def get_posts(feed, retries=3):
    """Fetch the JSON listing for r/``feed`` and print each post's key fields.

    The page is loaded in a fresh headless Chrome session and the JSON text
    is read from its ``<pre>`` element. One initial attempt plus up to
    ``retries`` retries are made.

    Raises:
        Exception: if no attempt succeeded.
    """
    url = f"https://www.reddit.com/r/{feed}.json"
    tries = 0
    success = False

    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(url)
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                # resp is parsed JSON, not an HTTP response object, so it has
                # no status_code to report.
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            success = True
            for child in resp["data"]["children"]:
                data = child["data"]

                #extract individual fields from the site data
                name = data["title"]
                author = data["author_fullname"]
                permalink = data["permalink"]
                upvote_ratio = data["upvote_ratio"]

                #print the extracted data
                print(f"Name: {name}")
                print(f"Author: {author}")
                print(f"Permalink: {permalink}")
                print(f"Upvote Ratio: {upvote_ratio}")
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        # Out of retries: raise so the crawler stops instead of returning
        # silently.
        raise Exception(f"Max retries exceeded {retries}")


########### MAIN FUNCTION #############

if __name__ == "__main__":

    # Subreddits to crawl; get_posts() is called once per entry.
    FEEDS = ["news"]

    for feed in FEEDS:
        get_posts(feed)

If you run this script, you should get an output similar to this:

Reddit Parser Output

As you can see in the image above, we extract the following from each post:

  • Name
  • Author
  • Permalink
  • Upvote Ratio

Step 2: Add Pagination

Now that we're getting results, we need finer control over our results. If we want 100 results, we should get 100. If we only want 10 results, we should get 10.

We can accomplish this by adding the limit parameter to our url. Let's refactor our get_posts() function to take an additional keyword, limit.

Taking our limit into account, our url will now look like this:

https://www.reddit.com/r/{feed}.json?limit={limit}

Here is the updated script:

from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3):
    """Fetch up to ``limit`` posts from r/``feed`` and print their key fields.

    The page is loaded in a fresh headless Chrome session and the JSON text
    is read from its ``<pre>`` element. One initial attempt plus up to
    ``retries`` retries are made.

    Raises:
        Exception: if no attempt succeeded.
    """
    url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
    tries = 0
    success = False

    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(url)
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                # resp is parsed JSON, not an HTTP response object, so it has
                # no status_code to report.
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            success = True
            for child in resp["data"]["children"]:
                data = child["data"]

                #extract individual fields from the site data
                name = data["title"]
                author = data["author_fullname"]
                permalink = data["permalink"]
                upvote_ratio = data["upvote_ratio"]

                #print the extracted data
                print(f"Name: {name}")
                print(f"Author: {author}")
                print(f"Permalink: {permalink}")
                print(f"Upvote Ratio: {upvote_ratio}")
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        # Out of retries: raise so the crawler stops instead of returning
        # silently.
        raise Exception(f"Max retries exceeded {retries}")


########### MAIN FUNCTION #############

if __name__ == "__main__":

    # Subreddits to crawl and the number of posts requested from each.
    FEEDS = ["news"]
    BATCH_SIZE = 2

    for feed in FEEDS:
        get_posts(feed, limit=BATCH_SIZE)
  • In the code above, we now add a limit parameter to get_posts().
  • We also declare a new constant in our main, BATCH_SIZE. We pass our BATCH_SIZE into get_posts() to control the size of our results.

Feel free to try changing the batch size and examining your results. limit is incredibly important. We don't want to scrape through hundreds of results if we only need 10... and we certainly don't want to try scraping hundreds of results when we're only limited to 10!


Step 3: Storing the Scraped Data

Now that we're retrieving the proper data, we need to be able to store that data. To store this data, we're going to add a SearchData class and a DataPipeline class as well.

SearchData is going to be relatively simple, all it's going to do is hold individual data. DataPipeline will be doing the real heavy lifting.

Our DataPipeline class will be doing all the work of removing duplicates and saving our SearchData objects to CSV.

In this example, we utilize SearchData to hold the data we've extracted and we then pass it into the DataPipeline.

from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchData:
    """One Reddit post as captured by the crawler.

    String fields are whitespace-stripped on construction; any string field
    that ends up empty is replaced with the placeholder ``"No <field name>"``.
    """

    name: str = ""
    author: str = ""
    permalink: str = ""
    upvote_ratio: float = 0.0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Normalise every ``str`` attribute in place."""
        # The loop variable is deliberately not called ``field`` so it does
        # not shadow the ``dataclasses.field`` helper imported by the module.
        for spec in fields(self):
            value = getattr(self, spec.name)
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also gets the placeholder
            # instead of being stored as an empty string.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

class DataPipeline:
    """Buffer scraped dataclass items, drop duplicates and append them to a CSV.

    Items are de-duplicated on their ``name`` attribute and flushed to
    ``csv_filename`` whenever the buffer reaches ``storage_queue_limit``.
    Call ``close_pipeline()`` when finished to flush whatever is left.
    """

    def __init__(self, csv_filename='', storage_queue_limit=50):
        self.names_seen = []
        self.storage_queue = []
        self.storage_queue_limit = storage_queue_limit
        self.csv_filename = csv_filename
        self.csv_file_open = False

    def save_to_csv(self):
        """Append every queued item to the CSV file and empty the queue.

        The header row is written only when the file is missing or empty.
        """
        data_to_save = list(self.storage_queue)
        self.storage_queue.clear()
        if not data_to_save:
            # Nothing to write: return before touching csv_file_open so the
            # flag cannot get stuck on True.
            return

        self.csv_file_open = True
        try:
            keys = [spec.name for spec in fields(data_to_save[0])]
            file_exists = (
                os.path.isfile(self.csv_filename)
                and os.path.getsize(self.csv_filename) > 0
            )
            with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
                writer = csv.DictWriter(output_file, fieldnames=keys)
                if not file_exists:
                    writer.writeheader()
                writer.writerows(asdict(item) for item in data_to_save)
        finally:
            # Reset the flag even when the write fails; otherwise add_data()
            # would never trigger another flush.
            self.csv_file_open = False

    def is_duplicate(self, input_data):
        """Return True if ``input_data.name`` was seen before, else record it."""
        if input_data.name in self.names_seen:
            logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
            return True
        self.names_seen.append(input_data.name)
        return False

    def add_data(self, scraped_data):
        """Queue a non-duplicate item and flush once the queue limit is reached."""
        if self.is_duplicate(scraped_data):
            return
        self.storage_queue.append(scraped_data)
        if len(self.storage_queue) >= self.storage_queue_limit and not self.csv_file_open:
            self.save_to_csv()

    def close_pipeline(self):
        """Flush any remaining queued items to the CSV file."""
        if self.csv_file_open:
            # A write is flagged as in progress; wait before the final flush.
            time.sleep(3)
        if self.storage_queue:
            self.save_to_csv()

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None):
    """Fetch up to ``limit`` posts from r/``feed`` and queue them on ``data_pipeline``.

    The subreddit's ``.json`` listing is loaded in a fresh headless Chrome
    session. Each entry of ``resp["data"]["children"]`` becomes a
    ``SearchData`` item. One initial attempt plus up to ``retries`` retries
    are made.

    Raises:
        ValueError: if no ``data_pipeline`` was supplied.
        Exception: if no attempt succeeded.
    """
    if data_pipeline is None:
        # Fail before launching a browser: without a pipeline every attempt
        # would die on ``None.add_data``.
        raise ValueError("get_posts() requires a data_pipeline to store results")

    url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
    tries = 0
    success = False

    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(url)
            # The JSON body is read from the page's <pre> element.
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            success = True
            for child in resp["data"]["children"]:
                data = child["data"]
                data_pipeline.add_data(
                    SearchData(
                        name=data["title"],
                        author=data["author"],
                        permalink=data["permalink"],
                        upvote_ratio=data["upvote_ratio"],
                    )
                )
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        # Out of retries: raise so the crawler stops instead of returning
        # silently.
        raise Exception(f"Max retries exceeded {retries}")


########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 2

    for feed in FEEDS:
        # One CSV per subreddit, named after the feed with spaces -> dashes.
        feed_filename = feed.replace(" ", "-")
        csv_name = f"{feed_filename}.csv"
        feed_pipeline = DataPipeline(csv_filename=csv_name)
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        # Flush whatever is still buffered below the queue limit.
        feed_pipeline.close_pipeline()
  • In the code example above, we create a DataPipeline and pass it into get_posts().
  • From inside get_posts(), we get our post data and turn it into a SearchData object.
  • This object then gets passed into the DataPipeline which removes duplicates and saves everything to a CSV.
  • Once we've gone through and processed the posts, we go through and close the pipeline.
  • We also replaced author_fullname with author. This allows us to see the actual display name of each poster.

Step 4: Bypassing Anti-Bots

Anti-bots are used to detect malicious software. While our crawler is not malicious, we are requesting json data with custom batches and it does make us look a bit abnormal.

In order to avoid getting blocked, we're going to pass all of these requests through the ScrapeOps Proxy API. This API gives us the benefit of rotating IP addresses, and it always selects the best proxy available.

In this code snippet, we create a simple function, get_scrapeops_url(). This function takes in a regular url and uses simple string formatting to create a proxied url.

def get_scrapeops_url(url, location="us"):
    """Wrap ``url`` in a ScrapeOps proxy request URL.

    The module-level ``API_KEY``, the target ``url`` and ``location`` (sent
    as ``country``) are URL-encoded into the query string.
    """
    query = urlencode({"api_key": API_KEY, "url": url, "country": location})
    return "https://proxy.scrapeops.io/v1/?" + query

Here is the full code example.

from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchData:
    """One Reddit post as captured by the crawler.

    String fields are whitespace-stripped on construction; any string field
    that ends up empty is replaced with the placeholder ``"No <field name>"``.
    """

    name: str = ""
    author: str = ""
    permalink: str = ""
    upvote_ratio: float = 0.0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Normalise every ``str`` attribute in place."""
        # The loop variable is deliberately not called ``field`` so it does
        # not shadow the ``dataclasses.field`` helper imported by the module.
        for spec in fields(self):
            value = getattr(self, spec.name)
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also gets the placeholder
            # instead of being stored as an empty string.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

class DataPipeline:
    """Buffer scraped dataclass items, drop duplicates and append them to a CSV.

    Items are de-duplicated on their ``name`` attribute and flushed to
    ``csv_filename`` whenever the buffer reaches ``storage_queue_limit``.
    Call ``close_pipeline()`` when finished to flush whatever is left.
    """

    def __init__(self, csv_filename='', storage_queue_limit=50):
        self.names_seen = []
        self.storage_queue = []
        self.storage_queue_limit = storage_queue_limit
        self.csv_filename = csv_filename
        self.csv_file_open = False

    def save_to_csv(self):
        """Append every queued item to the CSV file and empty the queue.

        The header row is written only when the file is missing or empty.
        """
        data_to_save = list(self.storage_queue)
        self.storage_queue.clear()
        if not data_to_save:
            # Nothing to write: return before touching csv_file_open so the
            # flag cannot get stuck on True.
            return

        self.csv_file_open = True
        try:
            keys = [spec.name for spec in fields(data_to_save[0])]
            file_exists = (
                os.path.isfile(self.csv_filename)
                and os.path.getsize(self.csv_filename) > 0
            )
            with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
                writer = csv.DictWriter(output_file, fieldnames=keys)
                if not file_exists:
                    writer.writeheader()
                writer.writerows(asdict(item) for item in data_to_save)
        finally:
            # Reset the flag even when the write fails; otherwise add_data()
            # would never trigger another flush.
            self.csv_file_open = False

    def is_duplicate(self, input_data):
        """Return True if ``input_data.name`` was seen before, else record it."""
        if input_data.name in self.names_seen:
            logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
            return True
        self.names_seen.append(input_data.name)
        return False

    def add_data(self, scraped_data):
        """Queue a non-duplicate item and flush once the queue limit is reached."""
        if self.is_duplicate(scraped_data):
            return
        self.storage_queue.append(scraped_data)
        if len(self.storage_queue) >= self.storage_queue_limit and not self.csv_file_open:
            self.save_to_csv()

    def close_pipeline(self):
        """Flush any remaining queued items to the CSV file."""
        if self.csv_file_open:
            # A write is flagged as in progress; wait before the final flush.
            time.sleep(3)
        if self.storage_queue:
            self.save_to_csv()

def get_scrapeops_url(url, location="us"):
    """Wrap ``url`` in a ScrapeOps proxy request URL.

    The module-level ``API_KEY``, the target ``url`` and ``location`` (sent
    as ``country``) are URL-encoded into the query string.
    """
    query = urlencode({"api_key": API_KEY, "url": url, "country": location})
    return "https://proxy.scrapeops.io/v1/?" + query

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
    """Fetch up to ``limit`` posts from r/``feed`` and queue them on ``data_pipeline``.

    The subreddit's ``.json`` listing is requested through the ScrapeOps
    proxy in a fresh headless Chrome session. Each entry of
    ``resp["data"]["children"]`` becomes a ``SearchData`` item. One initial
    attempt plus up to ``retries`` retries are made.

    Raises:
        ValueError: if no ``data_pipeline`` was supplied.
        Exception: if no attempt succeeded.
    """
    if data_pipeline is None:
        # Fail before launching a browser: without a pipeline every attempt
        # would die on ``None.add_data``.
        raise ValueError("get_posts() requires a data_pipeline to store results")

    url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
    tries = 0
    success = False

    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(get_scrapeops_url(url, location=location))
            # The JSON body is read from the page's <pre> element.
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            success = True
            for child in resp["data"]["children"]:
                data = child["data"]
                data_pipeline.add_data(
                    SearchData(
                        name=data["title"],
                        author=data["author"],
                        permalink=data["permalink"],
                        upvote_ratio=data["upvote_ratio"],
                    )
                )
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        # Out of retries: raise so the crawler stops instead of returning
        # silently.
        raise Exception(f"Max retries exceeded {retries}")


########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 2

    for feed in FEEDS:
        # One CSV per subreddit, named after the feed with spaces -> dashes.
        feed_filename = feed.replace(" ", "-")
        csv_name = f"{feed_filename}.csv"
        feed_pipeline = DataPipeline(csv_filename=csv_name)
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        # Flush whatever is still buffered below the queue limit.
        feed_pipeline.close_pipeline()

In the example above, we create a proxy_url by passing our url into get_scrapeops_url(). We then pass the result directly into driver.get() so Selenium will navigate to the proxied url instead of the regular one.


Step 5: Production Run

Now that we've got a working crawler, let's give it a production run. I'm going to change our batch size to 100.

########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 100

    for feed in FEEDS:
        feed_filename = feed.replace(" ", "-")
        feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
        # The keyword must be spelled ``data_pipeline``; the misspelled
        # ``data_pipline`` raises TypeError (unexpected keyword argument).
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        # Flush whatever is still buffered below the queue limit.
        feed_pipeline.close_pipeline()

Now let's take a look at the output file.

CSV Export of Reddit Scraped Results


Build A Reddit Post Scraper

Now it's time to build our post scraper. The goal of this scraper is quite simple. It needs to use multithreading to do the following:

  1. Read a row from a CSV file
  2. Fetch the individual post data from each row in the CSV
  3. Extract relevant data from the post
  4. Save that data to a new CSV file... a file unique to each post that we're scraping

Step 1: Create Simple Reddit Post Data Parser

Here is our parsing function for posts. We're once again retrieving json blobs and extracting important information from them. This function takes the permalink from post objects we created earlier in our crawler.

We're not ready to run this code yet, we need to be able to read the CSV file we created earlier. If we can't read the file, our scraper won't know which posts to process.

def process_post(post_object, location="us", retries=3):
    """Fetch one post's comment JSON and print each comment found.

    ``post_object`` is a mapping with a ``"permalink"`` key. Every entry of
    ``comments[1]["data"]["children"]`` whose kind is not ``"more"`` is
    printed as a dict with ``name``, ``body`` and ``upvotes``. ``location``
    is accepted but not used by this version.

    Raises:
        Exception: if every attempt failed.
    """
    permalink = post_object["permalink"]
    r_url = f"https://www.reddit.com{permalink}.json"

    # Derived from the permalink; not used further in this version.
    link_array = permalink.split("/")
    filename = link_array[-2].replace(" ", "-")

    tries = 0
    success = False
    while tries <= retries and not success:
        driver = webdriver.Chrome(options=options)

        try:
            driver.get(r_url)
            raw_json = driver.find_element(By.TAG_NAME, "pre").text
            if not raw_json:
                # raw_json is a plain string with no status_code attribute,
                # so report the empty body directly.
                raise Exception("Failed response: empty response body")
            comments = json.loads(raw_json)

            comments_list = comments[1]["data"]["children"]

            for comment in comments_list:
                if comment["kind"] == "more":
                    continue
                data = comment["data"]
                comment_data = {
                    "name": data["author"],
                    "body": data["body"],
                    "upvotes": data["ups"]
                }
                print(f"Comment: {comment_data}")
            success = True
        except Exception as e:
            logger.warning(f"Failed to retrieve comment:\n{e}")
            tries += 1

        finally:
            driver.quit()

    if not success:
        raise Exception(f"Max retries exceeded {retries}")

As long as our comment data comes back in the form of a list, we can then go through and parse the comments. If our comment["kind"] is not "more", we assume that these are comments we want to process.

We pull the author, body, and upvotes for each individual comment. If someone wants to look at this data in a large scope, they can then compare accurately to see which types of comments get the best reactions from people.


Step 2: Loading URLs To Scrape

In order to use the parsing function we just created, we need to read the data from our CSV. To do this, we'll use csv.DictReader(), which allows us to read individual rows from the CSV file. We'll call process_post() on each row we read from the file.

Here is the full code example that reads rows from the CSV file and processes them. We have an additional function, process_posts(). It uses a for loop as just a placeholder for now, but later on, this function will be rewritten to use multithreading.

from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchData:
    """One Reddit post as captured by the crawler.

    String fields are whitespace-stripped on construction; any string field
    that ends up empty is replaced with the placeholder ``"No <field name>"``.
    """

    name: str = ""
    author: str = ""
    permalink: str = ""
    upvote_ratio: float = 0.0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Normalise every ``str`` attribute in place."""
        # The loop variable is deliberately not called ``field`` so it does
        # not shadow the ``dataclasses.field`` helper imported by the module.
        for spec in fields(self):
            value = getattr(self, spec.name)
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also gets the placeholder
            # instead of being stored as an empty string.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

class DataPipeline:
    """Buffer scraped dataclass items, drop duplicates and append them to a CSV.

    Items are de-duplicated on their ``name`` attribute and flushed to
    ``csv_filename`` whenever the buffer reaches ``storage_queue_limit``.
    Call ``close_pipeline()`` when finished to flush whatever is left.
    """

    def __init__(self, csv_filename='', storage_queue_limit=50):
        self.names_seen = []
        self.storage_queue = []
        self.storage_queue_limit = storage_queue_limit
        self.csv_filename = csv_filename
        self.csv_file_open = False

    def save_to_csv(self):
        """Append every queued item to the CSV file and empty the queue.

        The header row is written only when the file is missing or empty.
        """
        data_to_save = list(self.storage_queue)
        self.storage_queue.clear()
        if not data_to_save:
            # Nothing to write: return before touching csv_file_open so the
            # flag cannot get stuck on True.
            return

        self.csv_file_open = True
        try:
            keys = [spec.name for spec in fields(data_to_save[0])]
            file_exists = (
                os.path.isfile(self.csv_filename)
                and os.path.getsize(self.csv_filename) > 0
            )
            with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
                writer = csv.DictWriter(output_file, fieldnames=keys)
                if not file_exists:
                    writer.writeheader()
                writer.writerows(asdict(item) for item in data_to_save)
        finally:
            # Reset the flag even when the write fails; otherwise add_data()
            # would never trigger another flush.
            self.csv_file_open = False

    def is_duplicate(self, input_data):
        """Return True if ``input_data.name`` was seen before, else record it."""
        if input_data.name in self.names_seen:
            logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
            return True
        self.names_seen.append(input_data.name)
        return False

    def add_data(self, scraped_data):
        """Queue a non-duplicate item and flush once the queue limit is reached."""
        if self.is_duplicate(scraped_data):
            return
        self.storage_queue.append(scraped_data)
        if len(self.storage_queue) >= self.storage_queue_limit and not self.csv_file_open:
            self.save_to_csv()

    def close_pipeline(self):
        """Flush any remaining queued items to the CSV file."""
        if self.csv_file_open:
            # A write is flagged as in progress; wait before the final flush.
            time.sleep(3)
        if self.storage_queue:
            self.save_to_csv()

def get_scrapeops_url(url, location="us"):
    """Return the ScrapeOps proxy URL that wraps ``url``.

    The API key, the target URL and the ``location`` country code are sent as
    URL-encoded query parameters.
    """
    query = urlencode({"api_key": API_KEY, "url": url, "country": location})
    return "https://proxy.scrapeops.io/v1/?" + query

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
    """Load a subreddit's JSON listing through the proxy and queue each post.

    Args:
        feed: Subreddit name, inserted into https://www.reddit.com/r/<feed>.json.
        limit: Value sent as the ``limit`` query parameter.
        retries: Number of extra attempts after the first one fails.
        data_pipeline: Pipeline whose add_data() receives one SearchData per post.
        location: Country code forwarded to get_scrapeops_url().

    Raises:
        ValueError: If no data_pipeline is given.

    A failed attempt is logged and retried. When every attempt fails, an error
    is logged and the function returns without raising.
    """
    # Fail fast instead of launching a browser only to crash on add_data().
    if data_pipeline is None:
        raise ValueError("get_posts() requires a data_pipeline")

    tries = 0
    success = False

    while tries <= retries and not success:
        # A fresh browser per attempt; it is always closed in `finally`.
        driver = webdriver.Chrome(options=options)
        try:
            url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
            driver.get(get_scrapeops_url(url, location=location))
            # The JSON text is read from the page's <pre> element.
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            for child in resp["data"]["children"]:
                data = child["data"]

                article_data = SearchData(
                    name=data["title"],
                    author=data["author"],
                    permalink=data["permalink"],
                    upvote_ratio=data["upvote_ratio"]
                )

                data_pipeline.add_data(article_data)

            # Mark success only after every post was queued, so a failure
            # halfway through the listing triggers a retry. Posts queued before
            # the failure are dropped as duplicates by the pipeline on retry.
            success = True
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        logger.error(f"Max retries exceeded while getting posts from: {feed}")

def process_post(post_object, location="us", retries=3):
    """Load one post's JSON from Reddit and print its top-level comments.

    Args:
        post_object: Mapping with a "permalink" key, e.g. a row read from the subreddit CSV.
        location: Accepted for interface compatibility; unused because this
            version requests Reddit directly.
        retries: Number of extra attempts after the first one fails.

    Raises:
        Exception: "Max retries exceeded ..." when every attempt failed.
    """
    tries = 0
    success = False

    permalink = post_object["permalink"]
    r_url = f"https://www.reddit.com{permalink}.json"

    # Second-to-last path segment of the permalink; computed but not used in this version.
    link_array = permalink.split("/")
    filename = link_array[-2].replace(" ", "-")

    while tries <= retries and not success:
        # A fresh browser per attempt; it is always closed in `finally`.
        driver = webdriver.Chrome(options=options)

        try:
            driver.get(r_url)
            # The JSON text is read from the page's <pre> element.
            raw_json = driver.find_element(By.TAG_NAME, "pre").text
            if not raw_json:
                # The page text is a plain string and has no status code to report.
                raise Exception("Failed response: empty page body")
            comments = json.loads(raw_json)

            # The comment listing is read from the second element of the response.
            comments_list = comments[1]["data"]["children"]

            for comment in comments_list:
                # Entries of kind "more" are skipped.
                if comment["kind"] != "more":
                    data = comment["data"]
                    comment_data = {
                        "name": data["author"],
                        "body": data["body"],
                        "upvotes": data["ups"]
                    }
                    print(f"Comment: {comment_data}")
            success = True
        except Exception as e:
            logger.warning(f"Failed to retrieve comment:\n{e}")
            tries += 1

        finally:
            driver.quit()
    if not success:
        raise Exception(f"Max retries exceeded {retries}")


#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us"):
    """Run process_post() on every row of a subreddit CSV, one after another.

    Args:
        csv_file: Path of the CSV file to read.
        max_workers: Not used by this sequential version.
        location: Forwarded to process_post().
    """
    # Read with the same encoding DataPipeline.save_to_csv() writes with.
    with open(csv_file, newline="", encoding="utf-8") as csvfile:
        reader = list(csv.DictReader(csvfile))

    for row in reader:
        process_post(row, location=location)


########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 10

    AGGREGATED_FEEDS = []

    # Crawl pass: write each feed's posts to "<feed>.csv" and remember the file.
    for feed in FEEDS:
        feed_csv = "{}.csv".format(feed.replace(" ", "-"))
        feed_pipeline = DataPipeline(csv_filename=feed_csv)
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        feed_pipeline.close_pipeline()
        AGGREGATED_FEEDS.append(feed_csv)

    # Scrape pass: process the posts listed in each CSV written above.
    for individual_file in AGGREGATED_FEEDS:
        process_posts(individual_file)

In the code above, we call process_posts() to read all the data from a Subreddit CSV. This function runs process_post() on each individual post so we can extract important comment data from the post.


Step 3: Storing the Scraped Data

We've already done most of the work as far as data storage. We just need one new class, CommentData. Similar to SearchData, the purpose of this class is to simply hold the data that we want to scrape.

Once we've got CommentData, we pass it straight into a DataPipeline.

@dataclass
class CommentData:
    """One comment scraped from a post.

    String fields are normalized on construction: surrounding whitespace is
    stripped, and a field that ends up empty is replaced by "No <field name>".
    """
    name: str = ""
    body: str = ""
    upvotes: int = 0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Strip every string field and fill empty ones with a default text."""
        for spec in fields(self):
            value = getattr(self, spec.name)
            # Only string fields are normalized; numbers are left untouched.
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also receives the default.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

Here is process_post() rewritten to store our data.

def process_post(post_object, location="us", retries=3):
    """Load one post's JSON from Reddit and store its top-level comments in a CSV.

    The CSV is named after the second-to-last path segment of the permalink.

    Args:
        post_object: Mapping with a "permalink" key, e.g. a row read from the subreddit CSV.
        location: Accepted for interface compatibility; unused because this
            version requests Reddit directly.
        retries: Number of extra attempts after the first one fails.

    Raises:
        Exception: "Max retries exceeded ..." when every attempt failed.
    """
    tries = 0
    success = False

    permalink = post_object["permalink"]
    r_url = f"https://www.reddit.com{permalink}.json"

    link_array = permalink.split("/")
    filename = link_array[-2].replace(" ", "-")

    # One pipeline, and therefore one CSV file, per post.
    comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")

    while tries <= retries and not success:
        # A fresh browser per attempt; it is always closed in `finally`.
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(r_url)
            # The JSON text is read from the page's <pre> element.
            raw_json = driver.find_element(By.TAG_NAME, "pre").text
            if not raw_json:
                # The page text is a plain string and has no status code to report.
                raise Exception("Failed response: empty page body")
            comments = json.loads(raw_json)

            # The comment listing is read from the second element of the response.
            comments_list = comments[1]["data"]["children"]

            for comment in comments_list:
                # Entries of kind "more" are skipped.
                if comment["kind"] != "more":
                    data = comment["data"]
                    # NOTE(review): the pipeline de-duplicates on `name`, which is the
                    # comment author here, so only the first comment of each author
                    # is stored - verify this is intended.
                    comment_data = CommentData(
                        name=data["author"],
                        body=data["body"],
                        upvotes=data["ups"]
                    )
                    comment_pipeline.add_data(comment_data)
            comment_pipeline.close_pipeline()
            success = True
        except Exception as e:
            logger.warning(f"Failed to retrieve comment:\n{e}")
            tries += 1

        finally:
            driver.quit()
    if not success:
        raise Exception(f"Max retries exceeded {retries}")

In addition to some better error handling, this function now opens a DataPipeline of its own. Each post gets its own DataPipeline, so the comments from each post are written to their own CSV file. Be aware that this code might get you blocked.

Selenium is faster than any human could possibly be and Reddit will notice abnormalities. After we add concurrency, we're going to add proxy support into our scraper as well.


Step 4: Adding Concurrency

To add concurrency, we're going to use ThreadPoolExecutor. This allows us to open a new pool of threads with however many max_workers we want to specify.

This is actually going to increase our likelihood of getting blocked, so adding proxy support in the next section is super important! The reason it increases our likelihood of getting blocked is simple. Our scraper was already really fast. Now it processes several posts at the same time, so it's many times faster!

Here is our new process_posts():

#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
    """Run process_post() on every row of a subreddit CSV using a thread pool.

    Args:
        csv_file: Path of the CSV file to read.
        max_workers: Maximum number of worker threads.
        location: Passed to every process_post() call.
        retries: Passed to every process_post() call.
    """
    # Read with the same encoding DataPipeline.save_to_csv() writes with.
    with open(csv_file, newline="", encoding="utf-8") as csvfile:
        reader = list(csv.DictReader(csvfile))

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map() pairs each row with one location and one retries value.
        # NOTE: the iterator returned by map() is never consumed, so an exception
        # raised by process_post() is not re-raised here.
        executor.map(process_post, reader, [location] * len(reader), [retries] * len(reader))

executor.map() takes process_post() as its first argument and then it passes all of our other arguments into it as lists. This opens up a new thread to process each post and save its data to its own individual CSV file.

For instance, if we have an article called "Headline Here" in the CSV generated earlier, we'll now have a separate CSV file specifically for comments and metadata from the "Headline Here" article, and we'll have it fast.


Step 5: Bypassing Anti-Bots

We already created our function for proxied urls earlier. To add a proxy to process_post(), we only need to change one line: driver.get(get_scrapeops_url(r_url, location=location)).

We once again pass get_scrapeops_url() directly into driver.get() so our scraper navigates directly to the proxied url.

Here is our final Python script that makes full use of both the crawler and scraper.

from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor

# User-agent string handed to Chrome so the browser identifies as desktop Chrome on Windows.
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
# Shared Chrome options reused by every webdriver.Chrome() instance created below.
options = ChromeOptions()
# Run Chrome without opening a browser window.
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")

# Base endpoint of the ScrapeOps proxy; get_scrapeops_url() builds the full proxied URL itself.
proxy_url = "https://proxy.scrapeops.io/v1/"
# Placeholder value: replace with your own ScrapeOps API key before running.
API_KEY = "YOUR-SUPER-SECRET-API-KEY"

# Log at INFO level and above; `logger` is the module-wide logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchData:
    """One post taken from a subreddit listing.

    String fields are normalized on construction: surrounding whitespace is
    stripped, and a field that ends up empty is replaced by "No <field name>".
    """
    name: str = ""
    author: str = ""
    permalink: str = ""
    upvote_ratio: float = 0.0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Strip every string field and fill empty ones with a default text."""
        for spec in fields(self):
            value = getattr(self, spec.name)
            # Only string fields are normalized; numbers are left untouched.
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also receives the default.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

@dataclass
class CommentData:
    """One comment scraped from a post.

    String fields are normalized on construction: surrounding whitespace is
    stripped, and a field that ends up empty is replaced by "No <field name>".
    """
    name: str = ""
    body: str = ""
    upvotes: int = 0

    def __post_init__(self):
        self.check_string_fields()

    def check_string_fields(self):
        """Strip every string field and fill empty ones with a default text."""
        for spec in fields(self):
            value = getattr(self, spec.name)
            # Only string fields are normalized; numbers are left untouched.
            if not isinstance(value, str):
                continue
            # Strip first so whitespace-only input also receives the default.
            cleaned = value.strip()
            setattr(self, spec.name, cleaned or f"No {spec.name}")

class DataPipeline:
    """Buffer scraped dataclass items, drop repeats by ``name``, and append them to a CSV file."""

    def __init__(self, csv_filename='', storage_queue_limit=50):
        """
        Args:
            csv_filename: Path of the CSV file that rows are appended to.
            storage_queue_limit: Queue length at which add_data() flushes to disk.
        """
        # A set gives constant-time duplicate checks.
        self.names_seen = set()
        self.storage_queue = []
        self.storage_queue_limit = storage_queue_limit
        self.csv_filename = csv_filename
        self.csv_file_open = False

    def save_to_csv(self):
        """Append every queued item to the CSV file and empty the queue.

        The header row is written only when the file does not exist yet or is empty.
        """
        # Nothing to write: return before touching the "file open" flag so it
        # can never be left stuck at True.
        if not self.storage_queue:
            return

        self.csv_file_open = True
        try:
            data_to_save = list(self.storage_queue)
            self.storage_queue.clear()

            # Column names come from the dataclass fields of the first item.
            keys = [spec.name for spec in fields(data_to_save[0])]
            file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
            with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
                writer = csv.DictWriter(output_file, fieldnames=keys)

                if not file_exists:
                    writer.writeheader()

                for item in data_to_save:
                    writer.writerow(asdict(item))
        finally:
            # Reset the flag even if writing failed.
            self.csv_file_open = False

    def is_duplicate(self, input_data):
        """Return True (and log a warning) if an item with the same ``name`` was already seen."""
        if input_data.name in self.names_seen:
            logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
            return True
        self.names_seen.add(input_data.name)
        return False

    def add_data(self, scraped_data):
        """Queue ``scraped_data`` unless it is a duplicate; flush once the queue limit is reached."""
        if not self.is_duplicate(scraped_data):
            self.storage_queue.append(scraped_data)
        if len(self.storage_queue) >= self.storage_queue_limit and not self.csv_file_open:
            self.save_to_csv()

    def close_pipeline(self):
        """Write whatever is still queued; call once after the last add_data()."""
        # Wait briefly if the flag says a write is still in progress.
        if self.csv_file_open:
            time.sleep(3)
        if len(self.storage_queue) > 0:
            self.save_to_csv()

def get_scrapeops_url(url, location="us"):
    """Return the ScrapeOps proxy URL that wraps ``url``.

    The API key, the target URL and the ``location`` country code are sent as
    URL-encoded query parameters.
    """
    query = urlencode({"api_key": API_KEY, "url": url, "country": location})
    return "https://proxy.scrapeops.io/v1/?" + query

#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
    """Load a subreddit's JSON listing through the proxy and queue each post.

    Args:
        feed: Subreddit name, inserted into https://www.reddit.com/r/<feed>.json.
        limit: Value sent as the ``limit`` query parameter.
        retries: Number of extra attempts after the first one fails.
        data_pipeline: Pipeline whose add_data() receives one SearchData per post.
        location: Country code forwarded to get_scrapeops_url().

    Raises:
        ValueError: If no data_pipeline is given.

    A failed attempt is logged and retried. When every attempt fails, an error
    is logged and the function returns without raising.
    """
    # Fail fast instead of launching a browser only to crash on add_data().
    if data_pipeline is None:
        raise ValueError("get_posts() requires a data_pipeline")

    tries = 0
    success = False

    while tries <= retries and not success:
        # A fresh browser per attempt; it is always closed in `finally`.
        driver = webdriver.Chrome(options=options)
        try:
            url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
            driver.get(get_scrapeops_url(url, location=location))
            # The JSON text is read from the page's <pre> element.
            json_text = driver.find_element(By.TAG_NAME, "pre").text
            resp = json.loads(json_text)

            if not resp:
                logger.warning(f"Failed response from server, tries left: {retries-tries}")
                raise Exception("Failed to get posts")

            for child in resp["data"]["children"]:
                data = child["data"]

                article_data = SearchData(
                    name=data["title"],
                    author=data["author"],
                    permalink=data["permalink"],
                    upvote_ratio=data["upvote_ratio"]
                )

                data_pipeline.add_data(article_data)

            # Mark success only after every post was queued, so a failure
            # halfway through the listing triggers a retry. Posts queued before
            # the failure are dropped as duplicates by the pipeline on retry.
            success = True
        except Exception as e:
            driver.save_screenshot(f"error-{tries}.png")
            logger.warning(f"Exeception, failed to get posts: {e}")
            tries += 1
        finally:
            driver.quit()

    if not success:
        logger.error(f"Max retries exceeded while getting posts from: {feed}")

def process_post(post_object, location="us", retries=3):
    """Load one post's JSON through the proxy and store its top-level comments in a CSV.

    The CSV is named after the second-to-last path segment of the permalink.

    Args:
        post_object: Mapping with a "permalink" key, e.g. a row read from the subreddit CSV.
        location: Country code forwarded to get_scrapeops_url().
        retries: Number of extra attempts after the first one fails.

    Raises:
        Exception: "Max retries exceeded ..." when every attempt failed.
    """
    tries = 0
    success = False

    permalink = post_object["permalink"]
    r_url = f"https://www.reddit.com{permalink}.json"

    link_array = permalink.split("/")
    filename = link_array[-2].replace(" ", "-")

    # One pipeline, and therefore one CSV file, per post.
    comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")

    while tries <= retries and not success:
        # A fresh browser per attempt; it is always closed in `finally`.
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(get_scrapeops_url(r_url, location=location))
            # The JSON text is read from the page's <pre> element.
            raw_json = driver.find_element(By.TAG_NAME, "pre").text
            if not raw_json:
                # The page text is a plain string and has no status code to report.
                raise Exception("Failed response: empty page body")
            comments = json.loads(raw_json)

            # The comment listing is read from the second element of the response.
            comments_list = comments[1]["data"]["children"]

            for comment in comments_list:
                # Entries of kind "more" are skipped.
                if comment["kind"] != "more":
                    data = comment["data"]
                    # NOTE(review): the pipeline de-duplicates on `name`, which is the
                    # comment author here, so only the first comment of each author
                    # is stored - verify this is intended.
                    comment_data = CommentData(
                        name=data["author"],
                        body=data["body"],
                        upvotes=data["ups"]
                    )
                    comment_pipeline.add_data(comment_data)
            comment_pipeline.close_pipeline()
            success = True
        except Exception as e:
            logger.warning(f"Failed to retrieve comment:\n{e}")
            tries += 1

        finally:
            driver.quit()
    if not success:
        raise Exception(f"Max retries exceeded {retries}")


#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
    """Run process_post() on every row of a subreddit CSV using a thread pool.

    Args:
        csv_file: Path of the CSV file to read.
        max_workers: Maximum number of worker threads.
        location: Passed to every process_post() call.
        retries: Passed to every process_post() call.
    """
    # Read with the same encoding DataPipeline.save_to_csv() writes with.
    with open(csv_file, newline="", encoding="utf-8") as csvfile:
        reader = list(csv.DictReader(csvfile))

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map() pairs each row with one location and one retries value.
        # NOTE: the iterator returned by map() is never consumed, so an exception
        # raised by process_post() is not re-raised here.
        executor.map(process_post, reader, [location] * len(reader), [retries] * len(reader))


########### MAIN FUNCTION #############

if __name__ == "__main__":

    FEEDS = ["news"]
    BATCH_SIZE = 10
    MAX_THREADS = 11

    AGGREGATED_FEEDS = []

    # Crawl pass: write each feed's posts to "<feed>.csv" and remember the file.
    for feed in FEEDS:
        feed_csv = "{}.csv".format(feed.replace(" ", "-"))
        feed_pipeline = DataPipeline(csv_filename=feed_csv)
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
        feed_pipeline.close_pipeline()
        AGGREGATED_FEEDS.append(feed_csv)

    # Scrape pass: process the posts listed in each CSV with MAX_THREADS workers.
    for individual_file in AGGREGATED_FEEDS:
        process_posts(individual_file, max_workers=MAX_THREADS)


Step 6: Production Run

Take a look at the constants in this block:

if __name__ == "__main__":

    # Subreddits to crawl.
    FEEDS = ["news"]
    # Country code passed to the proxy for both the crawl and the scrape.
    LOCATION = "us"
    # Value sent as the `limit` query parameter for each subreddit.
    BATCH_SIZE = 100
    # Number of worker threads used to scrape posts.
    MAX_THREADS = 11

    AGGREGATED_FEEDS = []

    # Crawl pass: write each feed's posts to "<feed>.csv" and remember the file.
    for feed in FEEDS:
        feed_filename = feed.replace(" ", "-")
        feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
        # The keyword must be spelled `data_pipeline`; a misspelling makes
        # get_posts() raise TypeError for an unexpected keyword argument.
        get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline, location=LOCATION)
        feed_pipeline.close_pipeline()
        AGGREGATED_FEEDS.append(f"{feed_filename}.csv")

    # Scrape pass: process the posts listed in each CSV written above.
    for individual_file in AGGREGATED_FEEDS:
        process_posts(individual_file, location=LOCATION, max_workers=MAX_THREADS)

To change your output, you can change any of these constants. If you want 4 threads, change MAX_THREADS to 4. If you want a BATCH_SIZE of 10, change it to 10. If you'd like to scrape a different Subreddit, just add it to the FEEDS list.

In the production run, we generated 100 CSV files all full of processed comments and metadata. It took 1 minute and 57 seconds to create our article list and generate all 100 of the reports.

Benchmarks


When scraping, always pay attention to the Terms of Service and robots.txt. You can view Reddit's Terms here. You can view their robots.txt here. Reddit reserves the right to block, ban, or delete your account if they believe you are responsible for malicious activity.

It's typically legal to collect any public data. Public data is data that is not protected by a login. If you don't have to login, you are generally alright to scrape the data.

If you have concerns or aren't sure whether it's legal to scrape the data you're after, consult an attorney. Attorneys are best equipped to give you legal advice on the data you're scraping.


Conclusion

You've made it to the end!!! Go build something! You now know how to extract JSON data using Selenium and you have a solid grasp on how to retrieve specific items from JSON blobs. You understand that scraping Reddit requires a crawler to gather a list of posts as well as an individual post scraper for gathering specific data about each post.

Here are links to software docs used in this article:


More Python Web Scraping Guides

Wanna learn more? Here at ScrapeOps, we have loads of resources for you to learn from. Take a look at some of these other ScrapeOps Guides.