import re
from yarl import URL
from datetime import timedelta
from backup.logger import getLogger
from backup.config import Setting, Config
from backup.time import Time
from backup.creds import KEY_CLIENT_SECRET, KEY_CLIENT_ID, KEY_ACCESS_TOKEN, KEY_TOKEN_EXPIRY
from aiohttp.web import (HTTPBadRequest, HTTPNotFound,
HTTPUnauthorized, Request, Response, delete, get,
json_response, patch, post, put, HTTPSeeOther)
from injector import inject, singleton
from .base_server import BaseServer, bytesPattern, intPattern
from .ports import Ports
from typing import Any, Dict
from asyncio import Event
from backup.creds import Creds
logger = getLogger(__name__)
# Query shapes understood by SimulatedGoogle._query: "mimeType='<type>'" and "'<id>' in parents".
mimeTypeQueryPattern = re.compile("^mimeType='.*'$")
parentsQueryPattern = re.compile("^'.*' in parents$")
# Content-Range of a resumable-upload status probe, e.g. "bytes */1048576".
resumeBytesPattern = re.compile("^bytes \\*/\\d+$")
# URL-path regexes for this simulator's endpoints. Not referenced in the code
# shown here; presumably used elsewhere (e.g. by tests) to match requests — verify.
URL_MATCH_DRIVE_API = "^.*drive.*$"
URL_MATCH_UPLOAD = "^/upload/drive/v3/files/$"
URL_MATCH_UPLOAD_PROGRESS = "^/upload/drive/v3/files/progress/.*$"
# NOTE(review): identical to URL_MATCH_UPLOAD_PROGRESS, although the create
# route registered in routes() is POST /drive/v3/files/. Confirm intent
# before changing, since callers may depend on the current value.
URL_MATCH_CREATE = "^/upload/drive/v3/files/progress/.*$"
URL_MATCH_FILE = "^/drive/v3/files/.*$"
URL_MATCH_DEVICE_CODE = "^/device/code$"
URL_MATCH_TOKEN = "^/token$"
@singleton
class SimulatedGoogle(BaseServer):
@inject
def __init__(self, config: Config, time: Time, ports: Ports):
    """Create an empty simulated Drive with fresh authentication state.

    Args:
        config: addon configuration; the default Drive client id/secret are
            read from (and overridden in) it by other methods.
        time: clock used for token expiry and item modification times.
        ports: provides ``server``, the port this simulator is reachable on.
    """
    self.config = config
    self._time = time
    self._port = ports.server

    # --- OAuth / credential state ---
    self._auth_token = ""
    self._refresh_token = "test_refresh_token"
    self._drive_auth_code = "drive_auth_code"
    self._client_id_hack = None
    self._custom_drive_client_id = self.generateId(5)
    self._custom_drive_client_secret = self.generateId(5)
    self._custom_drive_client_expiration = None
    self.device_auth_params = {}
    self._device_code_accepted = None

    # --- Drive contents and quota ---
    self.items = {}
    self.lostPermission = []
    self.usage = 0
    self.space_available = 5 * 1024 ** 3

    # --- Resumable-upload bookkeeping ---
    self._upload_info: Dict[str, Any] = {}
    self.chunks = []
    self._current_chunk = 1
    self._waitOnChunk = 0
    self._upload_chunk_wait = Event()
    self._upload_chunk_trigger = Event()
def setDriveSpaceAvailable(self, bytes_available):
self.space_available = bytes_available
def generateNewAccessToken(self):
new_token = self.generateId(20)
self._auth_token = new_token
def generateNewRefreshToken(self):
new_token = self.generateId(20)
self._refresh_token = new_token
def expireCreds(self):
self.generateNewAccessToken()
self.generateNewRefreshToken()
def expireRefreshToken(self):
self.generateNewRefreshToken()
def resetDriveAuth(self):
    """Invalidate issued tokens and install a new random default client id and secret.

    The new values are written into config via ``override`` (id first, then
    secret), so clients holding the old defaults are rejected afterwards.
    """
    self.expireCreds()
    for setting in (Setting.DEFAULT_DRIVE_CLIENT_ID, Setting.DEFAULT_DRIVE_CLIENT_SECRET):
        self.config.override(setting, self.generateId(5))
def creds(self):
    """Build a Creds object holding the tokens this server currently accepts.

    Uses the configured default client id and an expiration one hour after
    the simulated clock's current time.
    """
    client_id = self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID)
    expires_at = self._time.now() + timedelta(hours=1)
    return Creds(self._time,
                 id=client_id,
                 expiration=expires_at,
                 access_token=self._auth_token,
                 refresh_token=self._refresh_token)
def routes(self):
    """Return the aiohttp route definitions served by this simulator.

    Covers Drive v3 file CRUD and resumable upload, the OAuth2 consent and
    token endpoints, the device-authorization flow and a debug dump. Entries
    are produced in a fixed order.
    """
    table = (
        (put, '/upload/drive/v3/files/progress/{id}', self._uploadProgress),
        (post, '/upload/drive/v3/files/', self._upload),
        (post, '/drive/v3/files/', self._create),
        (get, '/drive/v3/files/', self._query),
        (delete, '/drive/v3/files/{id}/', self._delete),
        (patch, '/drive/v3/files/{id}/', self._update),
        (get, '/drive/v3/files/{id}/', self._get),
        (post, '/oauth2/v4/token', self._oauth2Token),
        (get, '/o/oauth2/v2/auth', self._oAuth2Authorize),
        (get, '/drive/customcreds', self._getCustomCred),
        (get, '/drive/v3/about', self._driveAbout),
        (post, '/device/code', self._deviceCode),
        (get, '/device', self._device),
        (get, '/debug/google', self._debug),
        (post, '/token', self._driveToken),
    )
    return [register(path, handler) for register, path, handler in table]
async def _debug(self, request: Request):
    """GET /debug/google - dump the custom client credentials and latest device-auth parameters."""
    payload = dict(
        custom_drive_client_id=self._custom_drive_client_id,
        custom_drive_client_secret=self._custom_drive_client_secret,
        device_auth_params=self.device_auth_params,
    )
    return json_response(payload)
async def _checkDriveHeaders(self, request: Request):
if request.headers.get("Authorization", "") != "Bearer " + self._auth_token:
raise HTTPUnauthorized()
async def _deviceCode(self, request: Request):
    """POST /device/code - start a simulated OAuth device-authorization flow.

    The form must carry ``client_id`` equal to the custom client id and
    ``scope`` equal to the drive.file scope, otherwise HTTPUnauthorized is
    raised. Issues a new device code and user code, resets any previous
    acceptance and returns the parameters as JSON.
    """
    form = await request.post()
    client_id, scope = form['client_id'], form['scope']
    if client_id != self._custom_drive_client_id or scope != 'https://www.googleapis.com/auth/drive.file':
        raise HTTPUnauthorized()
    verification_url = URL("http://localhost").with_port(self._port).with_path("device")
    self.device_auth_params = {
        'device_code': self.generateId(10),
        'expires_in': 60,
        'interval': 1,
        'user_code': self.generateId(8),
        'verification_url': str(verification_url),
    }
    self._device_code_accepted = None
    return json_response(self.device_auth_params)
async def _device(self, request: Request):
    """GET /device - simulated verification page for the device flow.

    With ``?code=<user_code>`` matching the last issued user code, marks the
    device code as accepted and rotates both tokens; a non-matching code
    gets "Wrong code". Without a code a short instruction page is returned.
    """
    code = request.query.get('code')
    if not code:
        body = "\nSimulated Drive Device Authorization\nEnter the device code provided below\n"
    elif self.device_auth_params.get('user_code', "dfsdfsdfsdfs") == code:
        self._device_code_accepted = True
        self.generateNewRefreshToken()
        self.generateNewAccessToken()
        body = "Accepted"
    else:
        body = "Wrong code"
    return Response(body=body, content_type="text/html")
async def _oAuth2Authorize(self, request: Request):
    """GET /o/oauth2/v2/auth - simulated Google consent screen.

    Every OAuth parameter must match what the addon sends (client id is
    either the default or the custom one; drive.file scope; response_type
    code; offline access; consent prompt; state and redirect_uri present),
    otherwise HTTPUnauthorized is raised. For the out-of-band redirect URI the
    auth code is returned as JSON; otherwise the caller is redirected (303)
    to ``redirect_uri`` with ``code`` and ``state`` query parameters.
    """
    query = request.query
    allowed_clients = (self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID), self._custom_drive_client_id)
    required_values = {
        'scope': 'https://www.googleapis.com/auth/drive.file',
        'response_type': 'code',
        'include_granted_scopes': 'true',
        'access_type': 'offline',
        'prompt': 'consent',
    }
    if query.get('client_id') not in allowed_clients:
        raise HTTPUnauthorized()
    if any(query.get(key) != value for key, value in required_values.items()):
        raise HTTPUnauthorized()
    if 'state' not in query or 'redirect_uri' not in query:
        raise HTTPUnauthorized()

    redirect_uri = query.get('redirect_uri')
    if redirect_uri == 'urn:ietf:wg:oauth:2.0:oob':
        return json_response({"code": self._drive_auth_code})
    target = URL(redirect_uri).with_query({'code': self._drive_auth_code, 'state': query.get('state')})
    raise HTTPSeeOther(str(target))
async def _getCustomCred(self, request: Request):
    """GET /drive/customcreds - expose the custom OAuth client id/secret this simulator accepts."""
    payload = {"client_id": self._custom_drive_client_id}
    payload["client_secret"] = self._custom_drive_client_secret
    return json_response(payload)
async def _driveToken(self, request: Request):
    """POST /token - simulated Google token endpoint.

    Client id/secret must pass ``_checkClientIdandSecret`` (401 otherwise).
    Supported grant types:
      * ``authorization_code``: ``redirect_uri`` must be this server's
        /drive/authorize page or the out-of-band URN, and ``code`` must equal
        the code handed out by the consent endpoint (401 otherwise).
      * ``urn:ietf:params:oauth:grant-type:device_code``: ``device_code`` must
        equal the last code issued by /device/code (401 otherwise, including
        when no device flow was started); answers 428
        ``authorization_pending`` until the user code has been entered.
    Any other grant type is a 400.

    On success a new refresh token is generated and returned together with
    the current access token, the caller's client id and the default client
    secret.
    """
    data = await request.post()
    if not self._checkClientIdandSecret(data.get('client_id'), data.get('client_secret')):
        raise HTTPUnauthorized()

    grant_type = data.get('grant_type')
    if grant_type == 'authorization_code':
        allowed_redirects = [
            "http://localhost:{}/drive/authorize".format(self._port),
            'urn:ietf:wg:oauth:2.0:oob',
        ]
        if data.get('redirect_uri') not in allowed_redirects:
            raise HTTPUnauthorized()
        if data.get('code') != self._drive_auth_code:
            raise HTTPUnauthorized()
    elif grant_type == 'urn:ietf:params:oauth:grant-type:device_code':
        # Use .get so a token request made before any /device/code call is
        # rejected with 401 instead of crashing with KeyError (HTTP 500).
        expected_code = self.device_auth_params.get('device_code')
        if expected_code is None or data.get('device_code') != expected_code:
            raise HTTPUnauthorized()
        if self._device_code_accepted is None:
            # The user has not entered the user code on /device yet.
            return json_response({
                "error": "authorization_pending",
                "error_description": "Precondition Required"
            }, status=428)
        if self._device_code_accepted is False:
            raise HTTPUnauthorized()
    else:
        raise HTTPBadRequest()

    self.generateNewRefreshToken()
    resp = {
        'access_token': self._auth_token,
        'refresh_token': self._refresh_token,
        KEY_CLIENT_ID: data.get('client_id'),
        # NOTE(review): always the *default* client secret, even when the
        # request authenticated with the custom client — confirm intended.
        KEY_CLIENT_SECRET: self.config.get(Setting.DEFAULT_DRIVE_CLIENT_SECRET),
        KEY_TOKEN_EXPIRY: self.timeToRfc3339String(self._time.now()),
    }
    if self._custom_drive_client_expiration is not None:
        # Lets callers override the expiry reported to the client.
        resp[KEY_TOKEN_EXPIRY] = self.timeToRfc3339String(self._custom_drive_client_expiration)
    return json_response(resp)
def _checkClientIdandSecret(self, client_id: str, client_secret: str) -> bool:
    """Return True if ``client_id``/``client_secret`` is a pair this server accepts.

    Accepted pairs:
      * the custom client id and custom client secret,
      * the configured default client id and default client secret,
      * ``_client_id_hack`` (when set) together with the default client secret.
    """
    if client_id == self._custom_drive_client_id and client_secret == self._custom_drive_client_secret:
        return True
    # The original chained comparison compared client_id against the default
    # id twice (`client_id == X == client_id`); a single comparison suffices.
    if (client_id == self.config.get(Setting.DEFAULT_DRIVE_CLIENT_ID)
            and client_secret == self.config.get(Setting.DEFAULT_DRIVE_CLIENT_SECRET)):
        return True
    if self._client_id_hack is not None:
        if client_id == self._client_id_hack and client_secret == self.config.get(Setting.DEFAULT_DRIVE_CLIENT_SECRET):
            return True
    return False
async def _oauth2Token(self, request: Request):
    """POST /oauth2/v4/token - simulated legacy Google token endpoint.

    Client id/secret must pass ``_checkClientIdandSecret`` (401 otherwise).
      * ``refresh_token`` grant: ``refresh_token`` must equal the current
        refresh token; a new access token is generated and returned.
      * device-code grant: ``device_code`` must equal the last code issued by
        /device/code; answers 428 ``authorization_pending`` until the user
        code has been accepted, then returns the current access token.
    Missing fields, mismatches and unknown grant types all yield 401.
    """
    params = await request.post()
    if not self._checkClientIdandSecret(params.get('client_id'), params.get('client_secret')):
        raise HTTPUnauthorized()

    grant_type = params.get('grant_type')
    if grant_type == 'refresh_token':
        # Checked only for this grant: validating it for every request made the
        # device-code branch unreachable (device requests carry no refresh token).
        if params.get('refresh_token') != self._refresh_token:
            raise HTTPUnauthorized()
        self.generateNewAccessToken()
        return json_response({
            'access_token': self._auth_token,
            'expires_in': 3600,
            'token_type': 'doesn\'t matter'
        })
    if grant_type == 'urn:ietf:params:oauth:grant-type:device_code':
        # .get: reject (401) rather than KeyError when no device flow was started.
        expected_code = self.device_auth_params.get('device_code')
        if expected_code is None or params.get('device_code') != expected_code:
            raise HTTPUnauthorized()
        if not self._device_code_accepted:
            return json_response({
                "error": "authorization_pending",
                "error_description": "Precondition Required"
            }, status=428)
        return json_response({
            'access_token': self._auth_token,
            'expires_in': 3600,
            'token_type': 'doesn\'t matter'
        })
    raise HTTPUnauthorized()
def filter_fields(self, item: Dict[str, Any], fields) -> Dict[str, Any]:
ret = {}
for field in fields:
if field in item:
ret[field] = item[field]
return ret
def parseFields(self, source: str):
    """Split a Drive ``fields`` selector into plain field names.

    The ``files(...)`` wrapper used by list requests is unwrapped, e.g.
    ``"nextPageToken,files(id,name)"`` -> ``["nextPageToken", "id", "name"]``.
    Only that wrapper is recognised; other nested selectors are not parsed.

    Args:
        source: comma-separated selector string from the request.

    Returns:
        List of field names in the order they appear.
    """
    fields = []
    for field in source.split(","):
        # Strip the prefix and the suffix independently so a single-field
        # selector such as "files(id)" yields "id" rather than "id)".
        if field.startswith("files("):
            field = field[len("files("):]
        if field.endswith(")"):
            field = field[:-1]
        fields.append(field)
    return fields
def formatItem(self, base, id):
caps = base.get('capabilites', {})
if 'capabilities' not in base:
base['capabilities'] = caps
if 'canAddChildren' not in caps:
caps['canAddChildren'] = True
if 'canListChildren' not in caps:
caps['canListChildren'] = True
if 'canDeleteChildren' not in caps:
caps['canDeleteChildren'] = True
if 'canTrashChildren' not in caps:
caps['canTrashChildren'] = True
if 'canTrash' not in caps:
caps['canTrash'] = True
if 'canDelete' not in caps:
caps['canDelete'] = True
for parent in base.get("parents", []):
parent_item = self.items[parent]
# This simulates a very simply shared drive permissions structure
if parent_item.get("driveId", None) is not None:
base["driveId"] = parent_item["driveId"]
base["capabilities"] = parent_item["capabilities"]
base['trashed'] = False
base['id'] = id
base['modifiedTime'] = self.timeToRfc3339String(self._time.now())
return base
async def _get(self, request: Request):
    """GET /drive/v3/files/{id}/ - return an item's metadata or, with ``alt=media``, its bytes.

    Requires the bearer token. Unknown ids are 404; ids in ``lostPermission``
    get a Drive-style 403 "forbidden" body; a media request for an item
    without stored bytes is 400. Metadata is filtered to the comma-separated
    ``fields`` query parameter (default "id").
    """
    item_id = request.match_info.get('id')
    await self._checkDriveHeaders(request)
    if item_id not in self.items:
        raise HTTPNotFound()
    if item_id in self.lostPermission:
        return Response(status=403, content_type="application/json",
                        text='{"error": {"errors": [{"reason": "forbidden"}]}}')

    item = self.items[item_id]
    if request.query.get("alt", "metadata") != "media":
        wanted = request.query.get("fields", "id").split(",")
        return json_response(self.filter_fields(item, wanted))
    if 'bytes' not in item:
        raise HTTPBadRequest()
    return self.serve_bytes(request, item['bytes'], include_length=False)
async def _update(self, request: Request):
    """PATCH /drive/v3/files/{id}/ - merge a JSON patch into an existing item.

    Requires the bearer token; unknown ids are 404. For each top-level key in
    the patch, an existing dict value is updated in place (shallow merge);
    any other value is replaced.
    """
    item_id = request.match_info.get('id')
    await self._checkDriveHeaders(request)
    if item_id not in self.items:
        # Raise an instance: the original returned the exception *class*,
        # which is not a valid aiohttp response.
        raise HTTPNotFound()
    item = self.items[item_id]
    update = await request.json()
    for key in update:
        if key in item and isinstance(item[key], dict):
            item[key].update(update[key])
        else:
            item[key] = update[key]
    return Response()
async def _driveAbout(self, request: Request):
    """GET /drive/v3/about - report simulated storage usage/limit and the account email.

    No authorization check is performed.
    """
    quota = {'usage': self.usage, 'limit': self.space_available}
    user = {'emailAddress': "testing@no.where"}
    return json_response({'storageQuota': quota, 'user': user})
async def _delete(self, request: Request):
    """DELETE /drive/v3/files/{id}/ - remove an item; requires the bearer token, 404 if unknown."""
    item_id = request.match_info.get('id')
    await self._checkDriveHeaders(request)
    if item_id in self.items:
        del self.items[item_id]
        return Response()
    raise HTTPNotFound()
async def _query(self, request: Request):
    """GET /drive/v3/files/ - list items, projected onto the requested ``fields``.

    Requires the bearer token. Only three ``q`` forms are understood:
      * ``mimeType='<type>'``: items with exactly that mimeType,
      * ``'<id>' in parents``: children of <id> (404 if unknown, Drive-style
        403 if the parent is in ``lostPermission``),
      * empty or absent: every item.
    Anything else is a 400.
    """
    await self._checkDriveHeaders(request)
    query: str = request.query.get("q", "")
    fields = self.parseFields(request.query.get('fields', 'id'))
    everything = self.items.values()

    if mimeTypeQueryPattern.match(query):
        wanted_mime = query[len("mimeType='"):-1]
        selected = [entry for entry in everything if entry.get('mimeType', '') == wanted_mime]
    elif parentsQueryPattern.match(query):
        parent = query[1:-len("' in parents")]
        if parent not in self.items:
            raise HTTPNotFound()
        if parent in self.lostPermission:
            return Response(
                status=403,
                content_type="application/json",
                text='{"error": {"errors": [{"reason": "forbidden"}]}}')
        selected = [entry for entry in everything if parent in entry.get('parents', [])]
    elif not query:
        selected = list(everything)
    else:
        raise HTTPBadRequest()
    return json_response({'files': [self.filter_fields(entry, fields) for entry in selected]})
async def _create(self, request: Request):
    """POST /drive/v3/files/ - store a metadata-only item and return its new id.

    Requires the bearer token. The JSON body is completed by formatItem and
    stored under an id from ``generateId(30)``.
    """
    await self._checkDriveHeaders(request)
    metadata = await request.json()
    new_id = self.generateId(30)
    item = self.formatItem(metadata, new_id)
    self.items[item['id']] = item
    return json_response({'id': item['id']})
async def _upload(self, request: Request):
    """POST /upload/drive/v3/files/?uploadType=resumable - open a resumable upload session.

    Requires the bearer token, ``uploadType=resumable`` and the
    X-Upload-Content-Type / X-Upload-Content-Length headers (400 otherwise).
    Uploads that would push the summed item sizes past ``space_available``
    get a Drive-style storageQuotaExceeded 400. Parents must exist (404) and
    be accessible (403). The session replaces ``_upload_info`` and the
    response's Location header points at the progress endpoint.
    """
    logger.info("Drive start upload request")
    await self._checkDriveHeaders(request)
    if request.query.get('uploadType') != 'resumable':
        raise HTTPBadRequest()
    mime_type = request.headers.get('X-Upload-Content-Type', None)
    if mime_type is None:
        raise HTTPBadRequest()
    size = int(request.headers.get('X-Upload-Content-Length', -1))
    if size < 0:
        raise HTTPBadRequest()

    projected_total = size + sum(existing.get('size', 0) for existing in self.items.values())
    if projected_total > self.space_available:
        quota_error = {"error": {"errors": [{"reason": "storageQuotaExceeded"}]}}
        return json_response(quota_error, status=400)

    metadata = await request.json()
    upload_id = self.generateId()
    parents = metadata['parents'] if 'parents' in metadata else []
    for parent in parents:
        if parent not in self.items:
            raise HTTPNotFound()
        if parent in self.lostPermission:
            return Response(status=403, content_type="application/json", text='{"error": {"errors": [{"reason": "forbidden"}]}}')

    self._upload_info.update(
        size=size,
        mime=mime_type,
        item=self.formatItem(metadata, upload_id),
        id=upload_id,
        next_start=0,
    )
    # formatItem returns the same dict, so these land on the session item too.
    metadata['bytes'] = bytearray()
    metadata['size'] = size

    resp = Response()
    resp.headers['Location'] = "http://localhost:{}/upload/drive/v3/files/progress/{}".format(self._port, upload_id)
    return resp
async def _uploadProgress(self, request: Request):
    """PUT /upload/drive/v3/files/progress/{id} - receive one chunk of a resumable upload.

    A Content-Range of the form "bytes */<total>" is a status probe and gets
    a 308, with a Range header of the bytes stored so far (omitted when
    nothing has been stored). Any other request must:
      * target the session most recently opened by _upload (400 otherwise),
      * declare the session's total size and start at ``next_start``,
      * be a multiple of 256 KiB unless it is the final chunk,
      * carry exactly Content-Length == end - start + 1 body bytes.
    The final chunk stores the completed item in ``items`` and returns its
    id; any other chunk returns 308 with the Range stored so far.
    """
    # Test hook, evaluated before the auth/session checks: requests are
    # counted in _current_chunk; once the count reaches _waitOnChunk, that
    # request (and, since the counter then stops advancing, every later one
    # unless _current_chunk is reset elsewhere) sets _upload_chunk_trigger and
    # blocks on _upload_chunk_wait.
    if self._waitOnChunk > 0:
        if self._current_chunk == self._waitOnChunk:
            self._upload_chunk_trigger.set()
            await self._upload_chunk_wait.wait()
        else:
            self._current_chunk += 1
    id = request.match_info.get('id')
    await self._checkDriveHeaders(request)
    if self._upload_info.get('id', "") != id:
        raise HTTPBadRequest()
    # NOTE(review): missing Content-Length / Content-Range headers raise KeyError here.
    chunk_size = int(request.headers['Content-Length'])
    info = request.headers['Content-Range']
    if resumeBytesPattern.match(info):
        # Status probe ("bytes */<total>"): report how much has been stored.
        resp = Response(status=308)
        if self._upload_info['next_start'] != 0:
            resp.headers['Range'] = "bytes=0-{0}".format(self._upload_info['next_start'] - 1)
        return resp
    # Presumably bytesPattern matches "bytes <start>-<end>/<total>" and
    # intPattern extracts its three integers (defined in base_server — verify).
    if not bytesPattern.match(info):
        raise HTTPBadRequest()
    numbers = intPattern.findall(info)
    start = int(numbers[0])
    end = int(numbers[1])
    total = int(numbers[2])
    # The chunk must agree with the announced size and continue exactly
    # where the previous chunk stopped.
    if total != self._upload_info['size']:
        raise HTTPBadRequest()
    if start != self._upload_info['next_start']:
        raise HTTPBadRequest()
    # Every chunk except the last must be a multiple of 256 KiB.
    if not (end == total - 1 or chunk_size % (256 * 1024) == 0):
        raise HTTPBadRequest()
    if end > total - 1:
        raise HTTPBadRequest()
    # get the chunk
    received_bytes = await self.readAll(request)
    # validate the chunk
    if len(received_bytes) != chunk_size:
        raise HTTPBadRequest()
    if len(received_bytes) != end - start + 1:
        raise HTTPBadRequest()
    self._upload_info['item']['bytes'].extend(received_bytes)
    # Sanity check that the accumulated buffer now ends exactly at `end`.
    # The bytes were already appended above, even if this check fails.
    if len(self._upload_info['item']['bytes']) != end + 1:
        raise HTTPBadRequest()
    self.usage += len(received_bytes)
    self.chunks.append(len(received_bytes))
    if end == total - 1:
        # upload is complete, so create the item
        completed = self.formatItem(self._upload_info['item'], self._upload_info['id'])
        self.items[completed['id']] = completed
        return json_response({"id": completed['id']})
    else:
        # Return an incomplete response
        # For some reason, the tests like to stop right here
        resp = Response(status=308)
        self._upload_info['next_start'] = end + 1
        resp.headers['Range'] = "bytes=0-{0}".format(end)
        return resp