How to Scrape Reddit With Selenium
Reddit is a platform where people can share news, opinions, ideas and all sorts of other things. In fact, Reddit has a reputation as the "Front Page of the Internet".
Today, we're going to learn how to scrape data from Reddit:
- TLDR: How to Scrape Reddit
- How To Architect Our Scraper
- Understanding How To Scrape Reddit
- Setting Up Our Reddit Scraper
- Build A Reddit Crawler
- Build A Reddit Post Scraper
- Legal and Ethical Considerations
- Conclusion
- More Cool Articles
Need help scraping the web?
Then check out ScrapeOps, the complete toolkit for web scraping.
TLDR: How to Scrape Reddit
- We can actually fetch a batch of posts from Reddit by adding .json to the end of our url.
- In the example below, we've got a production-ready Reddit scraper:
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchData:
name: str = ""
author: str = ""
permalink: str = ""
upvote_ratio: float = 0.0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
@dataclass
class CommentData:
name: str = ""
body: str = ""
upvotes: int = 0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
class DataPipeline:
def __init__(self, csv_filename='', storage_queue_limit=50):
self.names_seen = []
self.storage_queue = []
self.storage_queue_limit = storage_queue_limit
self.csv_filename = csv_filename
self.csv_file_open = False
def save_to_csv(self):
self.csv_file_open = True
data_to_save = []
data_to_save.extend(self.storage_queue)
self.storage_queue.clear()
if not data_to_save:
return
keys = [field.name for field in fields(data_to_save[0])]
file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
writer = csv.DictWriter(output_file, fieldnames=keys)
if not file_exists:
writer.writeheader()
for item in data_to_save:
writer.writerow(asdict(item))
self.csv_file_open = False
def is_duplicate(self, input_data):
if input_data.name in self.names_seen:
logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
return True
self.names_seen.append(input_data.name)
return False
def add_data(self, scraped_data):
if self.is_duplicate(scraped_data) == False:
self.storage_queue.append(scraped_data)
if len(self.storage_queue) >= self.storage_queue_limit and self.csv_file_open == False:
self.save_to_csv()
def close_pipeline(self):
if self.csv_file_open:
time.sleep(3)
if len(self.storage_queue) > 0:
self.save_to_csv()
def get_scrapeops_url(url, location="us"):
payload = {
"api_key": API_KEY,
"url": url,
"country": location
}
proxy_url = "https://proxy.scrapeops.io/v1/?" + urlencode(payload)
return proxy_url
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(get_scrapeops_url(url, location=location))
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
article_data = SearchData(
name=data["title"],
author=data["author"],
permalink=data["permalink"],
upvote_ratio=data["upvote_ratio"]
)
data_pipeline.add_data(article_data)
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
def process_post(post_object, location="us", retries=3):
tries = 0
success = False
permalink = post_object["permalink"]
r_url = f"https://www.reddit.com{permalink}.json"
link_array = permalink.split("/")
filename = link_array[-2].replace(" ", "-")
comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
driver.get(get_scrapeops_url(r_url, location=location))
comment_data = driver.find_element(By.TAG_NAME, "pre").text
if not comment_data:
raise Exception("Failed response: no comment data found")
comments = json.loads(comment_data)
comments_list = comments[1]["data"]["children"]
for comment in comments_list:
if comment["kind"] != "more":
data = comment["data"]
comment_data = CommentData(
name=data["author"],
body=data["body"],
upvotes=data["ups"]
)
comment_pipeline.add_data(comment_data)
comment_pipeline.close_pipeline()
success = True
except Exception as e:
logger.warning(f"Failed to retrieve comment:\n{e}")
tries += 1
finally:
driver.quit()
if not success:
raise Exception(f"Max retries exceeded {retries}")
#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
with open(csv_file, newline="") as csvfile:
reader = list(csv.DictReader(csvfile))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(process_post, reader, [location] * len(reader), [retries] * len(reader))
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 100
MAX_THREADS = 11
AGGREGATED_FEEDS = []
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
AGGREGATED_FEEDS.append(f"{feed_filename}.csv")
for individual_file in AGGREGATED_FEEDS:
process_posts(individual_file, max_workers=MAX_THREADS)
This scraper is production ready.
- To customize your feeds, simply add the subreddits you want to the FEEDS array.
- To add or remove threads, change the MAX_THREADS constant (it is set to 11 by default). If your CPU takes advantage of hyperthreading, you can actually use more threads than you have cores.
- For instance, the code for this article was written on a 6-core i3 CPU with hyperthreading, so it can run 12 threads.
- We chose 11 so the rest of the PC can run smoothly on the remaining thread while the scraper operates.
- To control how many posts you pull from each feed, change BATCH_SIZE (it is set to 100 by default).
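For example, if you wanted to crawl a couple of subreddits with a smaller batch and fewer threads, you could change the constants in the main block like this (the subreddit names below are just placeholder examples, not part of the original script):
# Hypothetical configuration for the main block above
FEEDS = ["news", "worldnews"]   # any subreddits you want to crawl
BATCH_SIZE = 50                 # how many posts to request per feed
MAX_THREADS = 4                 # worker threads used when scraping individual posts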
How To Architect Our Reddit Scraper
Our Reddit scraper needs to be able to do the following things:
- Fetch a batch of Reddit posts
- Retrieve data about each individual post
- Clean Our data
- Save our data to a CSV file
We're going to build a scraper that does all of these things from start to finish using Selenium. While Selenium isn't exactly lightweight, it is more than capable of handling the simple JSON parsing we need here quickly and reliably.
We can scrape Reddit with Selenium and gather a ton of data...FAST.
Understanding How To Scrape Reddit
Step 1: How To Request Reddit Pages
When you look up a feed on Reddit, you're given content in small batches. Whenever you scroll down, the site automatically fetches more content for you to read.
Infinite scrollers like this often present challenges when trying to scrape. While Selenium gives us the ability to scroll the page, there's actually a faster way. We'll dig into that soon enough. Take a look at the image below.
As you can see, the url looks like this:
https://www.reddit.com/r/news/?rdt=51809
By tweaking just a couple of things, we can completely change the result.
Lets change the url to:
https://www.reddit.com/r/news.json
Take a look at the result below.
By simply changing /?rdt=51809 to .json, we've turned Reddit into a full-blown JSON feed!
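If you want to try this yourself before we build the full crawler, here's a minimal sketch (assuming Selenium and ChromeDriver are installed) that loads the JSON feed in a headless browser and prints the start of the raw text Reddit returns:
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By

options = ChromeOptions()
options.add_argument("--headless")

driver = webdriver.Chrome(options=options)
try:
    # Reddit's raw JSON response gets rendered inside a <pre> tag in the browser
    driver.get("https://www.reddit.com/r/news.json")
    print(driver.find_element(By.TAG_NAME, "pre").text[:500])
finally:
    driver.quit()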
Step 2: How To Extract Data From Reddit Feeds
Since our data comes back as JSON, to retrieve it we just need to know how to handle simple dictionaries in Python. Python's dictionaries are key-value pairs just like JSON objects. The main difference is that JSON is plain text: its keys must be double-quoted strings, and it uses null, true, and false where Python uses None, True, and False.
Take the following example below:
Dictionary
{'name': 'John Doe', 'age': 30}
The same object as JSON text would be:
JSON
{"name": "John Doe", "age": 30}
In order to index a dict in Python, we simply use its keys. Our entire list of content comes in our resp, or response object. With a standard HTTP client like Python Requests, we would simply use resp.json().
With Selenium, we get the page and then use json.loads() to convert the text string into a dict. To access each article in the list, all we have to do is change our index:
- First Article: resp["data"]["children"][0]
- Second Article: resp["data"]["children"][1]
- Third Article: resp["data"]["children"][2]
We can follow this method all the way up to the last child, and then we'll have collected every article.
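To make this concrete, here's a minimal, runnable sketch of the parsing step. The json_text string below is a made-up, heavily trimmed example of the shape Reddit returns, not real feed data:
import json

# Made-up, trimmed example of the feed structure
json_text = '{"data": {"children": [{"data": {"title": "First post"}}, {"data": {"title": "Second post"}}]}}'

resp = json.loads(json_text)           # now a regular Python dict
children = resp["data"]["children"]    # list of posts in the feed

print(children[0]["data"]["title"])    # first article
print(children[1]["data"]["title"])    # second article

# or simply loop over every child
for child in children:
    print(child["data"]["title"])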
Step 3: How To Control Pagination
Think back to the link we looked at previously,
https://www.reddit.com/r/news.json
We can add a limit parameter to this url for finer control of our results.
If we want 100 news results, our url would look like this:
https://www.reddit.com/r/news.json?limit=100
This doesn't give us actual pages to sort through, but we do get long lists of results whose size we can control. All we have to do is pass the text into json.loads().
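As a quick sketch, building the url for any batch size is just string formatting. The subreddit name and limits below are only examples:
feed = "news"   # any subreddit name

# Each of these urls asks Reddit for a different batch size
for limit in (10, 25, 100):
    print(f"https://www.reddit.com/r/{feed}.json?limit={limit}")

# After loading one of these urls with Selenium and parsing the <pre> text
# with json.loads(), len(resp["data"]["children"]) should be at most `limit`.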
Setting Up Our Reddit Scraper Project
Let's get started with our project. First, we'll make a new project folder. I'll call mine reddit-scraper.
mkdir reddit-scraper
Next, we'll make a new virtual environment.
python -m venv venv
Activate the new virtual environment.
source venv/bin/activate
We only have two dependencies to install. Make sure you have the latest version of ChromeDriver installed. Then you can install Selenium using pip.
pip install selenium
If you have any issues with installing Chromedriver and have an executable path error, take a look at our article about how to tackle this problem.
Build A Reddit Crawler
When scraping Reddit, we actually need to build two scrapers.
- We need a crawler to identify all of the different posts. This crawler will fetch lists of posts, extract data from them, and save that data to a CSV file.
- After our data is saved to the CSV file, it will be read by an individual post scraper. The post scraper will fetch each individual post along with its comments and some other metadata.
Step 1: Create A Reddit Data Parser
First, we need a simple parser for our Reddit data.
- In the script below, we have one function, get_posts(). This function takes one argument, feed, and a kwarg, retries, which is set to 3 by default.
- While we still have retries left, we try to fetch the json feed from Reddit. If we run out of retries, we raise an Exception and allow the crawler to crash.
- Our json data comes nested inside of a pre tag, so we simply use Selenium to find the pre tag before loading our text.
- In our json response, we have an array of json objects called children. Each item in this array represents an individual Reddit post.
- Each post contains a data field. From that data field, we pull the title, author_fullname, permalink, and upvote_ratio.
- Later on, these items will make up the data we wish to save from our search, but we'll just print them for now.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
#get posts from a subreddit
def get_posts(feed, retries=3):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json"
driver.get(url)
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
#extract individual fields from the site data
name = data["title"]
author = data["author_fullname"]
permalink = data["permalink"]
upvote_ratio = data["upvote_ratio"]
#print the extracted data
print(f"Name: {name}")
print(f"Author: {author}")
print(f"Permalink: {permalink}")
print(f"Upvote Ratio: {upvote_ratio}")
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
for feed in FEEDS:
get_posts(feed)
If you run this script, you should get an output similar to this:
As you can see in the image above, we extract the following from each post:
- Name
- Author
- Permalink
- Upvote Ratio
Step 2: Add Pagination
Now that we're getting results, we need finer control over our results. If we want 100 results, we should get 100. If we only want 10 results, we should get 10.
We can accomplish this by adding the limit parameter to our url. Let's refactor our get_posts() function to take an additional keyword argument, limit.
Taking our limit into account, our url will now look like this:
https://www.reddit.com/r/{feed}.json?limit={limit}
Here is the updated script:
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(url)
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
#extract individual fields from the site data
name = data["title"]
author = data["author_fullname"]
permalink = data["permalink"]
upvote_ratio = data["upvote_ratio"]
#print the extracted data
print(f"Name: {name}")
print(f"Author: {author}")
print(f"Permalink: {permalink}")
print(f"Upvote Ratio: {upvote_ratio}")
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 2
for feed in FEEDS:
get_posts(feed, limit=BATCH_SIZE)
- In the code above, we now add a limit parameter to get_posts().
- We also declare a new constant in our main, BATCH_SIZE. We pass our BATCH_SIZE into get_posts() to control the size of our results.
Feel free to try changing the batch size and examining your results. limit is incredibly important. We don't want to scrape through hundreds of results if we only need 10... and we certainly don't want to try scraping hundreds of results when we're only limited to 10!
Step 3: Storing the Scraped Data
Now that we're retrieving the proper data, we need to be able to store it. To do that, we're going to add a SearchData class and a DataPipeline class.
SearchData is relatively simple: all it does is hold an individual result. DataPipeline does the real heavy lifting.
Our DataPipeline class does all the work of removing duplicates and saving our SearchData objects to CSV.
In this example, we use SearchData to hold the data we've extracted and then pass it into the DataPipeline.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchData:
name: str = ""
author: str = ""
permalink: str = ""
upvote_ratio: float = 0.0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
class DataPipeline:
def __init__(self, csv_filename='', storage_queue_limit=50):
self.names_seen = []
self.storage_queue = []
self.storage_queue_limit = storage_queue_limit
self.csv_filename = csv_filename
self.csv_file_open = False
def save_to_csv(self):
self.csv_file_open = True
data_to_save = []
data_to_save.extend(self.storage_queue)
self.storage_queue.clear()
if not data_to_save:
return
keys = [field.name for field in fields(data_to_save[0])]
file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
writer = csv.DictWriter(output_file, fieldnames=keys)
if not file_exists:
writer.writeheader()
for item in data_to_save:
writer.writerow(asdict(item))
self.csv_file_open = False
def is_duplicate(self, input_data):
if input_data.name in self.names_seen:
logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
return True
self.names_seen.append(input_data.name)
return False
def add_data(self, scraped_data):
if self.is_duplicate(scraped_data) == False:
self.storage_queue.append(scraped_data)
if len(self.storage_queue) >= self.storage_queue_limit and self.csv_file_open == False:
self.save_to_csv()
def close_pipeline(self):
if self.csv_file_open:
time.sleep(3)
if len(self.storage_queue) > 0:
self.save_to_csv()
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(url)
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
article_data = SearchData(
name=data["title"],
author=data["author"],
permalink=data["permalink"],
upvote_ratio=data["upvote_ratio"]
)
data_pipeline.add_data(article_data)
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 2
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
- In the code example above, we create a DataPipeline and pass it into get_posts().
- From inside get_posts(), we take our post data and turn it into a SearchData object.
- This object then gets passed into the DataPipeline, which removes duplicates and saves everything to a CSV.
- Once we've processed the posts, we close the pipeline.
- We also replaced author_fullname with author. This allows us to see the actual display name of each poster.
Step 4: Bypassing Anti-Bots
Anti-bots are used to detect and block automated traffic. While our crawler is not malicious, requesting json data in custom batch sizes does make us look a bit abnormal.
To keep from getting blocked, we're going to pass all of these requests through the ScrapeOps Proxy API Aggregator. This API gives us the benefit of rotating IP addresses, and it always selects the best proxy available.
In this code snippet, we create a simple function, get_scrapeops_url(). This function takes in a regular url and uses simple string formatting to create a proxied url.
def get_scrapeops_url(url, location="us"):
payload = {
"api_key": API_KEY,
"url": url,
"country": location
}
proxy_url = "https://proxy.scrapeops.io/v1/?" + urlencode(payload)
return proxy_url
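As a quick sanity check, calling the function above on a Reddit url should give you back a proxy.scrapeops.io url with your API key, the original url, and the country encoded as query parameters. A rough example of what that looks like (assuming API_KEY has been set):
# Example usage of get_scrapeops_url() defined above
url = "https://www.reddit.com/r/news.json?limit=100"
proxied = get_scrapeops_url(url, location="us")
print(proxied)
# Roughly: https://proxy.scrapeops.io/v1/?api_key=...&url=https%3A%2F%2Fwww.reddit.com%2Fr%2Fnews.json%3Flimit%3D100&country=us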
Here is the full code example.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchData:
name: str = ""
author: str = ""
permalink: str = ""
upvote_ratio: float = 0.0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
class DataPipeline:
def __init__(self, csv_filename='', storage_queue_limit=50):
self.names_seen = []
self.storage_queue = []
self.storage_queue_limit = storage_queue_limit
self.csv_filename = csv_filename
self.csv_file_open = False
def save_to_csv(self):
self.csv_file_open = True
data_to_save = []
data_to_save.extend(self.storage_queue)
self.storage_queue.clear()
if not data_to_save:
return
keys = [field.name for field in fields(data_to_save[0])]
file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
writer = csv.DictWriter(output_file, fieldnames=keys)
if not file_exists:
writer.writeheader()
for item in data_to_save:
writer.writerow(asdict(item))
self.csv_file_open = False
def is_duplicate(self, input_data):
if input_data.name in self.names_seen:
logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
return True
self.names_seen.append(input_data.name)
return False
def add_data(self, scraped_data):
if self.is_duplicate(scraped_data) == False:
self.storage_queue.append(scraped_data)
if len(self.storage_queue) >= self.storage_queue_limit and self.csv_file_open == False:
self.save_to_csv()
def close_pipeline(self):
if self.csv_file_open:
time.sleep(3)
if len(self.storage_queue) > 0:
self.save_to_csv()
def get_scrapeops_url(url, location="us"):
payload = {
"api_key": API_KEY,
"url": url,
"country": location
}
proxy_url = "https://proxy.scrapeops.io/v1/?" + urlencode(payload)
return proxy_url
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(get_scrapeops_url(url, location=location))
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
article_data = SearchData(
name=data["title"],
author=data["author"],
permalink=data["permalink"],
upvote_ratio=data["upvote_ratio"]
)
data_pipeline.add_data(article_data)
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 2
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
In the example above, we create a proxy_url by passing our url into get_scrapeops_url(). We then pass the result directly into driver.get() so Selenium navigates to the proxied url instead of the regular one.
Step 5: Production Run
Now that we've got a working crawler, let's give it a production run. I'm going to change our batch size to 100.
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 100
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
Now let's take a look at the output file.
Build A Reddit Post Scraper
Now it's time to build our post scraper. The goal of this scraper is quite simple. It needs to use multithreading to do the following:
- Read a row from a CSV file
- Fetch the individual post data from each row in the CSV
- Extract relevant data from the post
- Save that data to a new CSV file... a file unique to each post that we're scraping
Step 1: Create Simple Reddit Post Data Parser
Here is our parsing function for posts. We're once again retrieving json blobs and extracting important information from them. This function takes the permalink from each post object we saved earlier in our crawler.
We're not ready to run this code yet; we first need to be able to read the CSV file we created earlier. If we can't read the file, our scraper won't know which posts to process.
def process_post(post_object, location="us", retries=3):
tries = 0
success = False
permalink = post_object["permalink"]
r_url = f"https://www.reddit.com{permalink}.json"
link_array = permalink.split("/")
filename = link_array[-2].replace(" ", "-")
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
driver.get(r_url)
comment_data = driver.find_element(By.TAG_NAME, "pre").text
if not comment_data:
raise Exception("Failed response: no comment data found")
comments = json.loads(comment_data)
comments_list = comments[1]["data"]["children"]
for comment in comments_list:
if comment["kind"] != "more":
data = comment["data"]
comment_data = {
"name": data["author"],
"body": data["body"],
"upvotes": data["ups"]
}
print(f"Comment: {comment_data}")
success = True
except Exception as e:
logger.warning(f"Failed to retrieve comment:\n{e}")
tries += 1
finally:
driver.quit()
if not success:
raise Exception(f"Max retries exceeded {retries}")
As long as our comment data comes back in the form of a list, we can go through and parse the comments. If a comment's "kind" is not "more", we assume it is a comment we want to process.
We pull the author, body, and upvotes for each individual comment. Anyone looking at this data at scale can then compare which types of comments get the best reactions from people.
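To make the structure we're navigating a little more concrete, here is a made-up, heavily trimmed sketch of what a post's .json endpoint returns: a two-element list where the second element holds the comment listing. The values below are illustrative only:
import json

# Made-up, trimmed example of the shape of a post's .json response
raw = """
[
  {"data": {"children": [{"kind": "t3", "data": {"title": "The post itself"}}]}},
  {"data": {"children": [
    {"kind": "t1", "data": {"author": "some_user", "body": "Great article!", "ups": 42}},
    {"kind": "more", "data": {"count": 12}}
  ]}}
]
"""

comments = json.loads(raw)
for comment in comments[1]["data"]["children"]:
    if comment["kind"] != "more":     # skip the "load more comments" stubs
        data = comment["data"]
        print(data["author"], data["ups"], data["body"])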
Step 2: Loading URLs To Scrape
In order to use the parsing function we just created, we need to read the data from our CSV. To do this, we'll use csv.DictReader(), which allows us to read individual rows from the CSV file. We'll call process_post() on each row we read from the file.
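Each row that csv.DictReader() yields is a plain dict keyed by the CSV's header row, which is why process_post() can simply do post_object["permalink"]. A minimal sketch, assuming the news.csv file from our crawler already exists:
import csv

# Assumes the crawler has already produced news.csv with a `permalink` column
with open("news.csv", newline="") as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        # each row looks like {"name": ..., "author": ..., "permalink": ..., "upvote_ratio": ...}
        print(row["permalink"])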
Here is the full code example that reads rows from the CSV file and processes them. We have an additional function, process_posts(). For now it uses a plain for loop, but later on this function will be rewritten to use multithreading.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchData:
name: str = ""
author: str = ""
permalink: str = ""
upvote_ratio: float = 0.0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
class DataPipeline:
def __init__(self, csv_filename='', storage_queue_limit=50):
self.names_seen = []
self.storage_queue = []
self.storage_queue_limit = storage_queue_limit
self.csv_filename = csv_filename
self.csv_file_open = False
def save_to_csv(self):
self.csv_file_open = True
data_to_save = []
data_to_save.extend(self.storage_queue)
self.storage_queue.clear()
if not data_to_save:
return
keys = [field.name for field in fields(data_to_save[0])]
file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
writer = csv.DictWriter(output_file, fieldnames=keys)
if not file_exists:
writer.writeheader()
for item in data_to_save:
writer.writerow(asdict(item))
self.csv_file_open = False
def is_duplicate(self, input_data):
if input_data.name in self.names_seen:
logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
return True
self.names_seen.append(input_data.name)
return False
def add_data(self, scraped_data):
if self.is_duplicate(scraped_data) == False:
self.storage_queue.append(scraped_data)
if len(self.storage_queue) >= self.storage_queue_limit and self.csv_file_open == False:
self.save_to_csv()
def close_pipeline(self):
if self.csv_file_open:
time.sleep(3)
if len(self.storage_queue) > 0:
self.save_to_csv()
def get_scrapeops_url(url, location="us"):
payload = {
"api_key": API_KEY,
"url": url,
"country": location
}
proxy_url = "https://proxy.scrapeops.io/v1/?" + urlencode(payload)
return proxy_url
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(get_scrapeops_url(url, location=location))
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
article_data = SearchData(
name=data["title"],
author=data["author"],
permalink=data["permalink"],
upvote_ratio=data["upvote_ratio"]
)
data_pipeline.add_data(article_data)
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
def process_post(post_object, location="us", retries=3):
tries = 0
success = False
permalink = post_object["permalink"]
r_url = f"https://www.reddit.com{permalink}.json"
link_array = permalink.split("/")
filename = link_array[-2].replace(" ", "-")
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
driver.get(r_url)
comment_data = driver.find_element(By.TAG_NAME, "pre").text
if not comment_data:
raise Exception("Failed response: no comment data found")
comments = json.loads(comment_data)
comments_list = comments[1]["data"]["children"]
for comment in comments_list:
if comment["kind"] != "more":
data = comment["data"]
comment_data = {
"name": data["author"],
"body": data["body"],
"upvotes": data["ups"]
}
print(f"Comment: {comment_data}")
success = True
except Exception as e:
logger.warning(f"Failed to retrieve comment:\n{e}")
tries += 1
finally:
driver.quit()
if not success:
raise Exception(f"Max retries exceeded {retries}")
#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us"):
with open(csv_file, newline="") as csvfile:
reader = list(csv.DictReader(csvfile))
for row in reader:
process_post(row)
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 10
AGGREGATED_FEEDS = []
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
AGGREGATED_FEEDS.append(f"{feed_filename}.csv")
for individual_file in AGGREGATED_FEEDS:
process_posts(individual_file)
In the code above, we call process_posts() to read all the data from a subreddit CSV. This function runs process_post() on each individual post so we can extract important comment data from it.
Step 3: Storing the Scraped Data
We've already done most of the work as far as data storage goes. We just need one new class, CommentData. Similar to SearchData, the purpose of this class is simply to hold the data that we want to scrape.
Once we've got a CommentData object, we pass it straight into a DataPipeline.
@dataclass
class CommentData:
name: str = ""
body: str = ""
upvotes: int = 0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
Here is process_post() rewritten to store our data.
def process_post(post_object, location="us", retries=3):
tries = 0
success = False
permalink = post_object["permalink"]
r_url = f"https://www.reddit.com{permalink}.json"
link_array = permalink.split("/")
filename = link_array[-2].replace(" ", "-")
comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
driver.get(r_url)
comment_data = driver.find_element(By.TAG_NAME, "pre").text
if not comment_data:
raise Exception("Failed response: no comment data found")
comments = json.loads(comment_data)
comments_list = comments[1]["data"]["children"]
for comment in comments_list:
if comment["kind"] != "more":
data = comment["data"]
comment_data = CommentData(
name=data["author"],
body=data["body"],
upvotes=data["ups"]
)
comment_pipeline.add_data(comment_data)
comment_pipeline.close_pipeline()
success = True
except Exception as e:
logger.warning(f"Failed to retrieve comment:\n{e}")
tries += 1
finally:
driver.quit()
if not success:
raise Exception(f"Max retries exceeded {retries}")
In addition to some better error handling, this function now opens a DataPipeline of its own. Each post gets its own DataPipeline so we can store its comment data safely and efficiently. Be warned: this code might get you blocked.
Selenium is faster than any human could possibly be, and Reddit will notice the abnormalities. After we add concurrency, we're going to add proxy support to our scraper as well.
Step 4: Adding Concurrency
To add concurrency, we're going to use ThreadPoolExecutor. This allows us to open a pool of threads with however many max_workers we want to specify.
This is actually going to increase our likelihood of getting blocked, so adding proxy support in the next section is super important! The reason is simple: our scraper was already really fast, and now it's many times faster.
Here is our new process_posts():
#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
with open(csv_file, newline="") as csvfile:
reader = list(csv.DictReader(csvfile))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(process_post, reader, [location] * len(reader), [retries] * len(reader))
executor.map() takes process_post() as its first argument and then passes all of our other arguments into it as lists. This opens up a new thread to process each post and save its data to its own individual CSV file.
For instance, if the CSV we generated earlier contains an article called "Headline Here", we'll now get a separate CSV file holding the comments and metadata from the "Headline Here" post, and we'll get it fast.
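If the list-of-arguments pattern looks unfamiliar, here's a tiny, self-contained sketch of the same idea with a toy function standing in for process_post():
from concurrent.futures import ThreadPoolExecutor

# Toy stand-in for process_post(post_object, location, retries)
def greet(name, location, retries):
    return f"{name} ({location}, retries={retries})"

names = ["post-1", "post-2", "post-3"]

with ThreadPoolExecutor(max_workers=3) as executor:
    # Each positional argument gets its own iterable, just like in process_posts()
    results = executor.map(greet, names, ["us"] * len(names), [3] * len(names))
    for result in results:
        print(result)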
Step 5: Bypassing Anti-Bots
We already created our function for proxied urls earlier. To add a proxy to process_post(), we only need to change one line: driver.get(get_scrapeops_url(r_url, location=location)).
We once again pass the result of get_scrapeops_url() directly into driver.get() so our scraper navigates to the proxied url.
Here is our final Python script that makes full use of both the crawler and scraper.
from selenium import webdriver
from selenium.webdriver import ChromeOptions
from selenium.webdriver.common.by import By
from urllib.parse import urlencode
import csv, json, time
import logging, os
from dataclasses import dataclass, field, fields, asdict
from concurrent.futures import ThreadPoolExecutor
user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.3'
options = ChromeOptions()
options.add_argument("--headless")
options.add_argument(f"user-agent={user_agent}")
proxy_url = "https://proxy.scrapeops.io/v1/"
API_KEY = "YOUR-SUPER-SECRET-API-KEY"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchData:
name: str = ""
author: str = ""
permalink: str = ""
upvote_ratio: float = 0.0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
@dataclass
class CommentData:
name: str = ""
body: str = ""
upvotes: int = 0
def __post_init__(self):
self.check_string_fields()
def check_string_fields(self):
for field in fields(self):
# Check string fields
if isinstance(getattr(self, field.name), str):
# If empty set default text
if getattr(self, field.name) == '':
setattr(self, field.name, f"No {field.name}")
continue
# Strip any trailing spaces, etc.
value = getattr(self, field.name)
setattr(self, field.name, value.strip())
class DataPipeline:
def __init__(self, csv_filename='', storage_queue_limit=50):
self.names_seen = []
self.storage_queue = []
self.storage_queue_limit = storage_queue_limit
self.csv_filename = csv_filename
self.csv_file_open = False
def save_to_csv(self):
self.csv_file_open = True
data_to_save = []
data_to_save.extend(self.storage_queue)
self.storage_queue.clear()
if not data_to_save:
return
keys = [field.name for field in fields(data_to_save[0])]
file_exists = os.path.isfile(self.csv_filename) and os.path.getsize(self.csv_filename) > 0
with open(self.csv_filename, mode='a', newline='', encoding='utf-8') as output_file:
writer = csv.DictWriter(output_file, fieldnames=keys)
if not file_exists:
writer.writeheader()
for item in data_to_save:
writer.writerow(asdict(item))
self.csv_file_open = False
def is_duplicate(self, input_data):
if input_data.name in self.names_seen:
logger.warning(f"Duplicate item found: {input_data.name}. Item dropped.")
return True
self.names_seen.append(input_data.name)
return False
def add_data(self, scraped_data):
if self.is_duplicate(scraped_data) == False:
self.storage_queue.append(scraped_data)
if len(self.storage_queue) >= self.storage_queue_limit and self.csv_file_open == False:
self.save_to_csv()
def close_pipeline(self):
if self.csv_file_open:
time.sleep(3)
if len(self.storage_queue) > 0:
self.save_to_csv()
def get_scrapeops_url(url, location="us"):
payload = {
"api_key": API_KEY,
"url": url,
"country": location
}
proxy_url = "https://proxy.scrapeops.io/v1/?" + urlencode(payload)
return proxy_url
#get posts from a subreddit
def get_posts(feed, limit=100, retries=3, data_pipeline=None, location="us"):
tries = 0
success = False
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
url = f"https://www.reddit.com/r/{feed}.json?limit={limit}"
driver.get(get_scrapeops_url(url, location=location))
json_text = driver.find_element(By.TAG_NAME, "pre").text
resp = json.loads(json_text)
if resp:
success = True
children = resp["data"]["children"]
for child in children:
data = child["data"]
article_data = SearchData(
name=data["title"],
author=data["author"],
permalink=data["permalink"],
upvote_ratio=data["upvote_ratio"]
)
data_pipeline.add_data(article_data)
else:
logger.warning(f"Failed response from server, tries left: {retries-tries}")
raise Exception("Failed to get posts")
except Exception as e:
driver.save_screenshot(f"error-{tries}.png")
logger.warning(f"Exception, failed to get posts: {e}")
tries += 1
finally:
driver.quit()
def process_post(post_object, location="us", retries=3):
tries = 0
success = False
permalink = post_object["permalink"]
r_url = f"https://www.reddit.com{permalink}.json"
link_array = permalink.split("/")
filename = link_array[-2].replace(" ", "-")
comment_pipeline = DataPipeline(csv_filename=f"{filename}.csv")
while tries <= retries and not success:
driver = webdriver.Chrome(options=options)
try:
driver.get(get_scrapeops_url(r_url, location=location))
comment_data = driver.find_element(By.TAG_NAME, "pre").text
if not comment_data:
raise Exception("Failed response: no comment data found")
comments = json.loads(comment_data)
comments_list = comments[1]["data"]["children"]
for comment in comments_list:
if comment["kind"] != "more":
data = comment["data"]
comment_data = CommentData(
name=data["author"],
body=data["body"],
upvotes=data["ups"]
)
comment_pipeline.add_data(comment_data)
comment_pipeline.close_pipeline()
success = True
except Exception as e:
logger.warning(f"Failed to retrieve comment:\n{e}")
tries += 1
finally:
driver.quit()
if not success:
raise Exception(f"Max retries exceeded {retries}")
#process a batch of posts
def process_posts(csv_file, max_workers=5, location="us", retries=3):
with open(csv_file, newline="") as csvfile:
reader = list(csv.DictReader(csvfile))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(process_post, reader, [location] * len(reader), [retries] * len(reader))
########### MAIN FUNCTION #############
if __name__ == "__main__":
FEEDS = ["news"]
BATCH_SIZE = 10
MAX_THREADS = 11
AGGREGATED_FEEDS = []
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
AGGREGATED_FEEDS.append(f"{feed_filename}.csv")
for individual_file in AGGREGATED_FEEDS:
process_posts(individual_file, max_workers=MAX_THREADS)
Step 6: Production Run
Take a look at the constants in this block:
if __name__ == "__main__":
FEEDS = ["news"]
LOCATION = "us"
BATCH_SIZE = 100
MAX_THREADS = 11
AGGREGATED_FEEDS = []
for feed in FEEDS:
feed_filename = feed.replace(" ", "-")
feed_pipeline = DataPipeline(csv_filename=f"{feed_filename}.csv")
get_posts(feed, limit=BATCH_SIZE, data_pipeline=feed_pipeline)
feed_pipeline.close_pipeline()
AGGREGATED_FEEDS.append(f"{feed_filename}.csv")
for individual_file in AGGREGATED_FEEDS:
process_posts(individual_file, location=LOCATION, max_workers=MAX_THREADS)
To change your output, you can change any of these constants. If you want 4 threads, change MAX_THREADS to 4. If you want a BATCH_SIZE of 10, change it to 10. If you'd like to scrape a different subreddit, just add it to the FEEDS list.
In the production run, we generated 100 CSV files all full of processed comments and metadata. It took 1 minute and 57 seconds to create our article list and generate all 100 of the reports.
Legal and Ethical Considerations
When scraping, always pay attention to the Terms of Service and robots.txt. You can view Reddit's Terms here. You can view their robots.txt here. Reddit reserves the right to block, ban, or delete your account if they believe you are responsible for malicious activity.
It's typically legal to collect public data. Public data is data that is not gated behind a login. If you don't have to log in to view the data, you are generally alright to scrape it.
If you have concerns or aren't sure whether it's legal to scrape the data you're after, consult an attorney. Attorneys are best equipped to give you legal advice on the data you're scraping.
Conclusion
You've made it to the end!!! Go build something! You now know how to extract JSON data using Selenium, and you have a solid grasp on how to retrieve specific items from JSON blobs. You understand that scraping Reddit requires a crawler to gather a list of posts as well as an individual post scraper for gathering specific data about each post.
Here are links to software docs used in this article:
More Python Web Scraping Guides
Wanna learn more? Here at ScrapeOps, we have loads of resources for you to learn from. Take a look at some of these other ScrapeOps Guides.