import time
import urllib
import urllib.parse
import urllib.request
from io import BytesIO

import requests
from PIL import Image
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
# Path to the local chromedriver binary. Raw string: the backslashes are
# Windows path separators, not escape sequences.
chromedriver = r"C:\PythonProjects\chromedriver.exe"
# Selenium 4 takes the driver path through a Service object; passing it as the
# first positional argument is deprecated and unsupported in newer releases.
driver = webdriver.Chrome(service=Service(chromedriver))
driver.implicitly_wait(1)
img_name = "plastic bottle"  # search query; also the file-name prefix for saved images
label = 0  # class id written as the first field of every label file
# Class list: ['metal', 'glass', 'styrofoam', 'paper', 'plastic']
# NOTE(review): with this ordering id 0 is 'metal', yet the query above is for
# plastic bottles -- confirm the id matches the class being collected.

# Open the Google image search (tbm=isch) for img_name. The query is
# URL-encoded so spaces and non-ASCII search terms are transmitted correctly.
driver.get(
    "https://www.google.com/search?"
    + urllib.parse.urlencode({"q": img_name, "tbm": "isch"})
)
save_path = r"C:\PythonProjects\output/"  # output directory; must end with a separator
element = driver.find_element(By.TAG_NAME, "body")  # receives the PAGE_DOWN key presses
def find_coord(img_path, threshold=255):
    """Return the image size and the bounding box of its non-white content.

    The image is converted to 8-bit grayscale ("L" mode). Any pixel darker
    than ``threshold`` counts as foreground. With the default of 255 only
    pure-white pixels are background, which matches the historical behaviour.
    Lowering it (e.g. 240) tolerates the near-white noise that JPEG
    compression leaves around objects.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        threshold: Grayscale value at or above which a pixel is background.

    Returns:
        ``[[w, h], [x1, x2, y1, y2]]``. Here ``x1``/``y1`` are the first
        foreground column/row and ``x2``/``y2`` are one past the last one.
        So ``x2 - x1`` and ``y2 - y1`` are the box width and height in pixels.

    Raises:
        ValueError: if the image contains no foreground pixel at all.
    """
    with Image.open(img_path) as src:
        gray = src.convert("L")
    # Map foreground to 255 and background to 0. getbbox() bounds the non-zero
    # region and reports right/lower edges as exclusive, for both axes alike.
    mask = gray.point(lambda p: 255 if p < threshold else 0)
    bbox = mask.getbbox()
    if bbox is None:
        raise ValueError(f"no non-white pixels found in {img_path!r}")
    x1, y1, x2, y2 = bbox
    w, h = gray.size
    return [[w, h], [x1, x2, y1, y2]]
def convert(size, coord_list):
    """Turn a pixel-space box into normalized center/size values.

    This is the layout used by YOLO-style label files.

    Args:
        size: ``(width, height)`` of the image in pixels.
        coord_list: ``[x_min, x_max, y_min, y_max]`` in pixels.

    Returns:
        Tuple ``(center_x, center_y, box_w, box_h)``. The x values are scaled
        by ``1 / width`` and the y values by ``1 / height``.
    """
    # Multiply by precomputed reciprocals (rather than dividing) so the float
    # results are bit-for-bit what this function has always produced.
    inv_width = 1. / size[0]
    inv_height = 1. / size[1]
    left, right = coord_list[0], coord_list[1]
    top, bottom = coord_list[2], coord_list[3]
    center_x = (left + right) / 2.0 * inv_width
    center_y = (top + bottom) / 2.0 * inv_height
    box_w = (right - left) * inv_width
    box_h = (bottom - top) * inv_height
    return (center_x, center_y, box_w, box_h)
try:
    # Scroll the results page in bursts of PAGE_DOWN presses, then click the
    # input element at the XPath below (presumably Google's "Show more
    # results" button; the absolute XPath is tied to the current page layout).
    # Stop once that element can no longer be found or clicked.
    while True:
        # If collection seems to stop early, try raising this count by 100.
        for _ in range(100):
            element.send_keys(Keys.PAGE_DOWN)
            time.sleep(0.1)
        try:
            driver.find_element(By.XPATH, "/html/body/div[2]/c-wiz/div[3]/div[1]/div/div/div/div/div[1]/div[2]/div[2]/input").click()
        except WebDriverException:
            break

    # Collect thumbnail URLs. Missing src attributes and anything that is not
    # an https URL (e.g. an inline data: URI) are skipped.
    links = []
    for image in driver.find_elements(By.CSS_SELECTOR, "img.rg_i.Q4LuWd"):
        src = image.get_attribute("src")
        if src is not None and src.startswith("https://"):
            links.append(src)
finally:
    # quit() ends the session and the chromedriver process; close() would only
    # close the window. The browser is not needed past this point.
    driver.quit()

# Keep only images whose four corner pixels are pure white in grayscale. This
# is a cheap proxy for "object on a white background", which find_coord()
# relies on to locate the object.
count = 0
for link in links:
    try:
        res = requests.get(link, timeout=10)  # seconds; avoid hanging forever
        res.raise_for_status()
        with Image.open(BytesIO(res.content)) as downloaded:
            gray = downloaded.convert("L")
    except (requests.RequestException, OSError):
        # Network error, HTTP error status, or bytes Pillow cannot decode:
        # skip this thumbnail rather than abort the whole run.
        continue
    w, h = gray.size
    px = gray.load()
    corners = ((0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1))
    if all(px[xy] == 255 for xy in corners):
        # Write the bytes already in hand instead of downloading the URL again.
        # The .jpg suffix is kept even if the server sent another format;
        # Pillow identifies the format from the file content when reopening.
        with open(f"{save_path}{img_name}{count}.jpg", "wb") as out:
            out.write(res.content)
        count += 1

# Write one label file per saved image, with the same base name and a .txt
# suffix. It holds a single line: "<label> <cx> <cy> <w> <h>", with the box
# values normalized by convert().
for idx in range(count):
    base = f"{save_path}{img_name}{idx}"
    size, box = find_coord(base + ".jpg")
    output = [label, *convert(size, box)]
    with open(base + ".txt", "w") as save_txt:
        save_txt.write(" ".join(map(str, output)))