Safe_Share/safeshare/safeshare_app/views/file.py

200 lines
6.6 KiB
Python

import hashlib
import os
import sys
import threading
import uuid
import logging
from urllib.parse import quote
import magic
from django.conf import settings
from django.core.cache import cache
from django.http import HttpResponse
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.views import APIView
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../utils/safeshare_vdb_client")
import client
logger = logging.getLogger(__name__)
class ManageItemsView(APIView):
TIMEOUT = 20 # seconds
def post(self, request):
files = request.FILES.getlist('file')
ttl = request.data.get('ttl') or 259200 # Default TTL is 3 days
try:
ttl = int(ttl)
if ttl <= 0:
return Response({'msg': 'TTL must be a positive integer'}, status=400)
except ValueError:
return Response({'msg': 'Invalid TTL format'}, status=400)
responses = []
threads = []
client_ip = get_client_ip(request)
logger.info(f"{request.method} request received from IP: {client_ip}")
for file in files:
thread = threading.Thread(target=self._save_file, args=(file, ttl, responses))
threads.append(thread)
for thread in threads:
thread.start()
timeout_event = threading.Event()
timeout_timer = threading.Timer(self.TIMEOUT, lambda: timeout_event.set())
try:
timeout_timer.start()
for thread in threads:
thread.join()
if not timeout_event.is_set():
return Response(responses, status=201)
else:
logger.error('File saving timed out')
return Response({'msg': 'File saving timed out'}, status=500)
finally:
timeout_timer.cancel()
def _save_file(self, file, ttl, responses):
generated_uuid = uuid.uuid4().hex # Generate a UUID
filename = file.name # Get the original filename
save_path = os.path.join(settings.MEDIA_ROOT, generated_uuid) # Use the UUID as the new filename
hasher = hashlib.sha256()
with open(save_path, 'wb') as destination:
for chunk in file.chunks():
hasher.update(chunk)
destination.write(chunk)
hash_signature = hasher.hexdigest()
logger.info(f'File {filename} saved to {save_path} with hash signature {hash_signature}')
try:
grpc_client = client.Client()
print(f'Checking file {filename} with hash signature {hash_signature}')
result = grpc_client.CheckFile(hash_signature)
except Exception as e:
print(f'Error checking file: {str(e)}')
result = False
if result:
response = {
'msg': f"File {filename} is infected with a virus"
}
os.remove(save_path)
responses.append(response)
logger.warning(f'File {filename} is infected with a virus')
return
# Determine the MIME type of the file using python-magic
try:
file_type = magic.Magic()
mime_type = file_type.from_file(save_path)
except Exception as e:
logger.warning(f'Error detecting MIME type: {str(e)}')
mime_type = 'application/octet-stream'
# Store the file path, generated UUID, original filename, MIME type, and other information in the cache
cache.set(generated_uuid, {
'filename': filename, # Original filename
'generated_uuid': generated_uuid, # Generated UUID
'path': save_path,
'mime_type': mime_type,
}, timeout=ttl)
response = {
'key': generated_uuid, # Return the generated UUID
'filename': filename, # Return the original filename
'mime_type': mime_type,
'msg': f"{generated_uuid} successfully set to {filename} with TTL {ttl} seconds",
}
responses.append(response)
logger.info(f'File {filename} successfully saved to cache with key {generated_uuid} and TTL {ttl} seconds')
class ManageItemView(APIView):
def get(self, request, key):
value = cache.get(key)
if not value:
logger.warning(f'Key {key} not found')
raise NotFound("Key not found")
if 'path' not in value:
logger.warning(f'File not found')
raise NotFound("File not found")
file_path = value['path']
if not os.path.exists(file_path):
logger.warning(f'File not found')
raise NotFound("File not found")
with open(file_path, 'rb') as f:
file_data = f.read()
# Retrieve the MIME type from the cache
mime_type = value.get('mime_type', 'application/octet-stream')
response = HttpResponse(file_data, content_type=mime_type)
# Set the Content-Disposition with the original filename
original_filename = value.get('filename', 'file') # Get the original filename from the cache
response['Content-Disposition'] = f'attachment; filename="{quote(original_filename)}"'
logger.info(f'File {file_path} successfully retrieved from cache with key {key}')
return response
def delete(self, request, key):
value = cache.get(key)
if not value:
logger.warning(f'Key {key} not found')
return Response({'msg': 'Not found'}, status=404)
if 'path' in value and os.path.exists(value['path']):
os.remove(value['path'])
cache.delete(key)
logger.info(f'File {value["path"]} successfully deleted from cache with key {key}')
return Response({'msg': f"{key} successfully deleted"}, status=200)
logger.warning(f'File not found')
return Response({'msg': 'File not found'}, status=404)
PRIVATE_IPS_PREFIX = ('10.', '172.', '192.')
def get_client_ip(request):
"""
Get the client's IP address from the request.
This function extracts the client's IP address from various request headers,
taking into account possible proxy servers.
:param request: The HTTP request object.
:return: The client's IP address.
"""
remote_address = request.META.get('HTTP_X_FORWARDED_FOR') or request.META.get('REMOTE_ADDR')
ip = remote_address
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
proxies = x_forwarded_for.split(',')
while proxies and proxies[0].startswith(PRIVATE_IPS_PREFIX):
proxies.pop(0)
if proxies:
ip = proxies[0]
return ip