diff --git a/CTFd/api/v1/files.py b/CTFd/api/v1/files.py index 0a042fee6765040b034f3044159c79a20f1a1f35..11a9621e0ffc50433631524fc5073f90f984d364 100644 --- a/CTFd/api/v1/files.py +++ b/CTFd/api/v1/files.py @@ -336,7 +336,7 @@ class FilesList(Resource): for f in files: # uploads.upload_file(file=f, chalid=req.get('challenge')) try: - obj = uploads.upload_file(file=f, **form_args) + obj = uploads.old_upload_file(file=f, **form_args) except ValueError as e: return { "success": False, diff --git a/CTFd/utils/uploads/__init__.py b/CTFd/utils/uploads/__init__.py index a5d7a9e44f78fe2b5c72aa82e6d2e922ec989d7e..cd9bb2314b42b5146a0a1285add24192568d3f34 100644 --- a/CTFd/utils/uploads/__init__.py +++ b/CTFd/utils/uploads/__init__.py @@ -112,6 +112,58 @@ def upload_file(*args, **kwargs): return file_row +def old_upload_file(*args, **kwargs): + file_obj = kwargs.get("file") + challenge_id = kwargs.get("challenge_id") or kwargs.get("challenge") + page_id = kwargs.get("page_id") or kwargs.get("page") + file_type = kwargs.get("type", "standard") + location = kwargs.get("location") + + # Validate location and default filename to uploaded file's name + parent = None + filename = file_obj.filename + if location: + path = Path(location) + if len(path.parts) != 2: + raise ValueError( + "Location must contain two parts, a directory and a filename" + ) + # Allow location to override the directory and filename + parent = path.parts[0] + filename = path.parts[1] + location = parent + "/" + filename + + model_args = {"type": file_type, "location": location} + + model = Files + if file_type == "challenge": + model = ChallengeFiles + model_args["challenge_id"] = challenge_id + if file_type == "page": + model = PageFiles + model_args["page_id"] = page_id + + # Hash is calculated before upload since S3 file upload closes file object + sha1sum = hash_file(fp=file_obj) + + uploader = get_uploader() + location = uploader.upload(file_obj=file_obj, filename=filename, path=parent) + + model_args["location"] = location + model_args["sha1sum"] = sha1sum + + existing_file = Files.query.filter_by(location=location).first() + if existing_file: + for k, v in model_args.items(): + setattr(existing_file, k, v) + db.session.commit() + file_row = existing_file + else: + file_row = model(**model_args) + db.session.add(file_row) + db.session.commit() + return file_row + def hash_file(fp, algo="sha1"): fp.seek(0) if algo == "sha1":