diff --git a/safeshare/requirements.txt b/safeshare/requirements.txt
index 7d1d420..1558671 100644
--- a/safeshare/requirements.txt
+++ b/safeshare/requirements.txt
@@ -24,12 +24,14 @@ install==1.3.5
isort==5.12.0
mysqlclient==2.2.0
packaging==23.2
+protobuf==4.24.4
psycopg2-binary==2.9.9
pycparser==2.21
pydantic==2.4.2
pydantic_core==2.10.1
PyJWT==1.7.1
PyMySQL==1.1.0
+python-magic==0.4.27
pytz==2023.3.post1
PyYAML==6.0.1
redis==5.0.1
diff --git a/safeshare/safeshare-frontend/src/pages/downloadFile.js b/safeshare/safeshare-frontend/src/pages/downloadFile.js
index 7b509c8..5494738 100644
--- a/safeshare/safeshare-frontend/src/pages/downloadFile.js
+++ b/safeshare/safeshare-frontend/src/pages/downloadFile.js
@@ -43,6 +43,7 @@ function DownloadFile() {
axios.get(`http://127.0.0.1:8000/api/files/${passcode}/`, {responseType: 'blob'})
.then(response => {
let filename = 'downloaded_file'; // Default filename
+ let mimeType = 'application/octet-stream'; // Default MIME type
// Check if the Content-Disposition header exists
if (response.headers['content-disposition']) {
@@ -52,9 +53,15 @@ function DownloadFile() {
if (filenameMatch) {
filename = filenameMatch[1];
}
+
+ // Check if the Content-Type header exists
+ if (response.headers['content-type']) {
+ mimeType = response.headers['content-type'];
+ console.log(mimeType);
+ }
}
- const blob = new Blob([response.data], {type: 'application/octet-stream'});
+ const blob = new Blob([response.data], {type: mimeType});
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
@@ -71,9 +78,9 @@ function DownloadFile() {
})
.catch(error => {
console.log(error);
- openModal()
- // change the error message once error msg add into response
- setErrorcode("File not found")
+ openModal();
+ // Change the error message once error message is added to the response
+ setErrorcode("File not found");
});
}
};
@@ -94,7 +101,9 @@ function DownloadFile() {
{errorMsg}
-
+
diff --git a/safeshare/safeshare/settings.py b/safeshare/safeshare/settings.py
index daef073..9f89d03 100644
--- a/safeshare/safeshare/settings.py
+++ b/safeshare/safeshare/settings.py
@@ -98,6 +98,7 @@ CORS_ALLOWED_ORIGINS = [
CORS_EXPOSE_HEADERS = [
'Content-Disposition',
+ 'Content-Type',
]
TEMPLATES = [
diff --git a/safeshare/safeshare_app/views/file.py b/safeshare/safeshare_app/views/file.py
index a7a29b5..eacff12 100644
--- a/safeshare/safeshare_app/views/file.py
+++ b/safeshare/safeshare_app/views/file.py
@@ -1,186 +1,135 @@
import hashlib
import os
+import sys
import threading
import uuid
from urllib.parse import quote
-from cryptography.fernet import Fernet
+import magic
from django.conf import settings
from django.core.cache import cache
from django.http import HttpResponse
+from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.views import APIView
-import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../utils/safeshare_vdb_client")
import client
class ManageItemsView(APIView):
+ TIMEOUT = 5
+
def post(self, request):
- # Define a timeout value (in seconds)
- timeout = 5
-
- # Get the list of files and the TTL value from the request data
files = request.FILES.getlist('file')
- ttl = request.data.get('ttl')
-
- if not ttl:
- # Set ttl to 3 days
- ttl = 259200 # 3 * 24 * 60 * 60
+ ttl = request.data.get('ttl') or 259200 # Default TTL is 3 days
try:
- # Convert the TTL to an integer
ttl = int(ttl)
-
if ttl <= 0:
return Response({'msg': 'TTL must be a positive integer'}, status=400)
except ValueError:
return Response({'msg': 'Invalid TTL format'}, status=400)
- def save_file_locally(file):
- key = uuid.uuid4().hex
- filename = file.name
- save_path = os.path.join(settings.MEDIA_ROOT, filename)
- hasher = hashlib.sha256()
-
- # Save the file locally
- with open(save_path, 'wb') as destination:
- for chunk in file.chunks():
- hasher.update(chunk)
- destination.write(chunk)
-
- # Get the hash signature
- hash_signature = hasher.hexdigest()
-
- # If RPC client import fails, skip virus scan
- # Call RPC For virus scan
- try:
- grpc_client = client.Client()
- result = grpc_client.CheckFile(hash_signature)
- except Exception as e:
- result = False
-
- if result:
- response = {
- 'msg': f"File {filename} is infected with a virus"
- }
- os.remove(save_path)
- responses.append(response)
- return Response(responses, status=400)
-
- # Generate a random UUID to use as the encryption key
- encryption_key = Fernet.generate_key()
- cipher_suite = Fernet(encryption_key)
-
- # Encrypted Data Buffer
- encrypted_data = b""
-
- # Encrypt the filename
- encrypted_filename = cipher_suite.encrypt(filename.encode())
-
- # Reopen the file to encrypt it with the encryption key and Fernet algorithm
- with open(save_path, 'rb') as source_file:
- for chunk in source_file:
- encrypted_chunk = cipher_suite.encrypt(chunk)
- encrypted_data += encrypted_chunk
-
- # New save path
- save_path = os.path.join(settings.MEDIA_ROOT, str(encrypted_filename))
-
- # Overwrite the file with the encrypted data
- with open(save_path, 'wb') as destination:
- destination.write(encrypted_data)
-
-
- # Store the file path and encryption key in the cache with the provided TTL
- cache.set(key,
- {
- 'filename': encrypted_filename,
- 'path': save_path,
- 'encryption_key': encryption_key,
- },
- timeout=ttl)
-
- response = {
- 'key': key,
- 'filename': encrypted_filename,
- 'msg': f"{key} successfully set to {filename} with TTL {ttl} seconds",
- }
-
- # Append the response to the shared responses list
- responses.append(response)
-
- # Create a list to store the responses for each file
responses = []
+ threads = []
- # Create a thread for each file
- file_threads = []
for file in files:
- file_thread = threading.Thread(target=save_file_locally, args=(file,))
- file_threads.append(file_thread)
+ thread = threading.Thread(target=self._save_file, args=(file, ttl, responses))
+ threads.append(thread)
- # Start all file-saving threads
- for file_thread in file_threads:
- file_thread.start()
+ for thread in threads:
+ thread.start()
- # Use a Timer to add a timeout
timeout_event = threading.Event()
- timeout_timer = threading.Timer(timeout, lambda: timeout_event.set())
+ timeout_timer = threading.Timer(self.TIMEOUT, lambda: timeout_event.set())
try:
- # Start the timer
timeout_timer.start()
+ for thread in threads:
+ thread.join()
- # Wait for all file-saving threads to complete
- for file_thread in file_threads:
- file_thread.join()
-
- # Check if the threads completed without a timeout
if not timeout_event.is_set():
return Response(responses, status=201)
else:
return Response({'msg': 'File saving timed out'}, status=500)
finally:
- # Always cancel the timer to prevent it from firing after the threads complete
timeout_timer.cancel()
+ def _save_file(self, file, ttl, responses):
+ key = uuid.uuid4().hex
+ filename = file.name
+ save_path = os.path.join(settings.MEDIA_ROOT, filename)
+ hasher = hashlib.sha256()
+
+ with open(save_path, 'wb') as destination:
+ for chunk in file.chunks():
+ hasher.update(chunk)
+ destination.write(chunk)
+
+ hash_signature = hasher.hexdigest()
+
+ try:
+ grpc_client = client.Client()
+ result = grpc_client.CheckFile(hash_signature)
+ except Exception as e:
+ result = False
+
+ if result:
+ response = {
+ 'msg': f"File {filename} is infected with a virus"
+ }
+ os.remove(save_path)
+ responses.append(response)
+ return
+
+ # Determine the MIME type of the file using python-magic
+ file_type = magic.Magic()
+ mime_type = file_type.from_file(save_path)
+
+ # Store the file path, filename, MIME type, and other information in the cache
+ cache.set(key, {
+ 'filename': filename,
+ 'path': save_path,
+ 'mime_type': mime_type, # Store the MIME type
+ }, timeout=ttl)
+
+ response = {
+ 'key': key,
+ 'filename': filename,
+ 'mime_type': mime_type, # Include the MIME type in the response
+ 'msg': f"{key} successfully set to {filename} with TTL {ttl} seconds",
+ }
+ responses.append(response)
+
class ManageItemView(APIView):
def get(self, request, key):
value = cache.get(key)
if not value:
- return Response({'msg': 'Not found'}, status=404)
+ raise NotFound("Key not found")
if 'path' not in value:
- return Response({'msg': 'File not found'}, status=404)
+ raise NotFound("File not found")
file_path = value['path']
if not os.path.exists(file_path):
- return Response({'msg': 'File not found'}, status=404)
+ raise NotFound("File not found")
- # Retrieve the encryption key from the cache
- encryption_key = value.get('encryption_key')
-
- if not encryption_key:
- return Response({'msg': 'Encryption key not found'}, status=404)
-
- # Decrypt the filename
- cipher_suite = Fernet(encryption_key)
- decrypted_filename = cipher_suite.decrypt(value['filename']).decode()
-
- # Decrypt the file content
with open(file_path, 'rb') as f:
- encrypted_data = f.read()
- decrypted_data = cipher_suite.decrypt(encrypted_data)
+ file_data = f.read()
- response = HttpResponse(decrypted_data, content_type='application/octet-stream')
+ # Retrieve the MIME type from the cache
+ mime_type = value.get('mime_type', 'application/octet-stream')
+
+ response = HttpResponse(file_data, content_type=mime_type)
+
+ # Set the Content-Disposition with the original filename
+ response['Content-Disposition'] = f'attachment; filename="{quote(os.path.basename(file_path))}"'
- # Set the Content-Disposition with the decrypted filename
- response['Content-Disposition'] = f'attachment; filename="{quote(decrypted_filename)}"'
return response
def delete(self, request, key):